mirror of https://github.com/mruwnik/memory.git (synced 2025-11-13 08:14:05 +01:00)

extract usage

commit c296f3b533, parent 07852f9ee7
@@ -14,6 +14,7 @@ from memory.common.llms.base import (
     MessageRole,
     StreamEvent,
     ToolDefinition,
+    Usage,
 )

 logger = logging.getLogger(__name__)
@@ -255,6 +256,19 @@ class AnthropicProvider(BaseLLMProvider):
                return StreamEvent(type="tool_use", data=tool_data), None

        elif event_type == "message_delta":
+            # Handle token usage information; message_delta usage has no
+            # total_tokens field, so derive the total from input + output
+            if usage := getattr(event, "usage", None):
+                input_tokens = getattr(usage, "input_tokens", 0) or 0
+                output_tokens = getattr(usage, "output_tokens", 0) or 0
+                self.log_usage(
+                    Usage(
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                        total_tokens=input_tokens + output_tokens,
+                    )
+                )
+
            delta = getattr(event, "delta", None)
            if delta:
                stop_reason = getattr(delta, "stop_reason", None)
@@ -263,22 +276,6 @@ class AnthropicProvider(BaseLLMProvider):
                        type="error", data="Max tokens reached"
                    ), current_tool_use

-            # Handle token usage information
-            usage = getattr(event, "usage", None)
-            if usage:
-                usage_data = {
-                    "input_tokens": getattr(usage, "input_tokens", 0),
-                    "output_tokens": getattr(usage, "output_tokens", 0),
-                    "cache_creation_input_tokens": getattr(
-                        usage, "cache_creation_input_tokens", None
-                    ),
-                    "cache_read_input_tokens": getattr(
-                        usage, "cache_read_input_tokens", None
-                    ),
-                }
-                # Could emit this as a separate event type if needed
-                logger.debug(f"Token usage: {usage_data}")
-
            return None, current_tool_use

        elif event_type == "message_stop":
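Note: in Anthropic's streaming protocol, input token counts arrive with the message_start event, while message_delta events carry the cumulative output count; delta usage objects expose no total_tokens field, which is why the hunk above derives it. A minimal sketch of folding a stream's events into a single Usage under those assumptions (the accumulate_usage helper is hypothetical, not part of this commit):

    from memory.common.llms.base import Usage

    def accumulate_usage(events) -> Usage:
        # Hypothetical helper: fold raw Anthropic stream events into one Usage.
        total = Usage()
        for event in events:
            # message_start nests usage under event.message; message_delta
            # exposes it directly as event.usage.
            usage = getattr(event, "usage", None) or getattr(
                getattr(event, "message", None), "usage", None
            )
            if usage is None:
                continue
            if (value := getattr(usage, "input_tokens", None)) is not None:
                total.input_tokens = value
            if (value := getattr(usage, "output_tokens", None)) is not None:
                total.output_tokens = value  # cumulative, so overwrite
        total.total_tokens = total.input_tokens + total.output_tokens
        return total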
@@ -25,6 +25,23 @@ class MessageRole(str, Enum):
     TOOL = "tool"


+@dataclass
+class Usage:
+    """Usage data for an LLM call."""
+
+    input_tokens: int = 0
+    output_tokens: int = 0
+    total_tokens: int = 0
+
+    def to_dict(self) -> dict[str, int]:
+        """Plain-dict form, as expected by BaseLLMProvider.log_usage()."""
+        return {
+            "input_tokens": self.input_tokens,
+            "output_tokens": self.output_tokens,
+            "total_tokens": self.total_tokens,
+        }
+
+
 @dataclass
 class TextContent:
     """Text content in a message."""
@@ -219,6 +236,11 @@ class BaseLLMProvider(ABC):
         self._client = self._initialize_client()
         return self._client

+    def log_usage(self, usage: Usage):
+        """Log usage data."""
+        logger.debug(f"Token usage: {usage.to_dict()}")
+        print(f"Token usage: {usage.to_dict()}")
+
     def execute_tool(
         self,
         tool_call: ToolCall,
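With usage extraction funneled through this single log_usage hook, a subclass can redirect it without touching the per-provider code. A minimal sketch of a mixin that also accumulates per-session totals, assuming the Usage dataclass above (UsageTrackingMixin is a hypothetical name, not part of this commit):

    from memory.common.llms.base import Usage

    class UsageTrackingMixin:
        # Hypothetical mixin for a BaseLLMProvider subclass: every call the
        # provider makes to log_usage() also updates a running total.
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.session_usage = Usage()

        def log_usage(self, usage: Usage):
            super().log_usage(usage)  # keep the base class's debug logging
            self.session_usage.input_tokens += usage.input_tokens
            self.session_usage.output_tokens += usage.output_tokens
            self.session_usage.total_tokens += usage.total_tokens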
@@ -16,6 +16,7 @@ from memory.common.llms.base import (
     ToolDefinition,
     ToolResultContent,
     ToolUseContent,
+    Usage,
 )

 logger = logging.getLogger(__name__)
@@ -283,6 +284,17 @@ class OpenAIProvider(BaseLLMProvider):
         """
         events: list[StreamEvent] = []

+        # Handle usage information (comes in final chunk with empty choices)
+        if hasattr(chunk, "usage") and chunk.usage:
+            usage = chunk.usage
+            self.log_usage(
+                Usage(
+                    input_tokens=usage.prompt_tokens,
+                    output_tokens=usage.completion_tokens,
+                    total_tokens=usage.total_tokens,
+                )
+            )
+
         if not chunk.choices:
             return events, current_tool_call

@@ -337,6 +349,14 @@ class OpenAIProvider(BaseLLMProvider):

         try:
             response = self.client.chat.completions.create(**kwargs)
+            if usage := response.usage:
+                self.log_usage(
+                    Usage(
+                        input_tokens=usage.prompt_tokens,
+                        output_tokens=usage.completion_tokens,
+                        total_tokens=usage.total_tokens,
+                    )
+                )
             return response.choices[0].message.content or ""
         except Exception as e:
             logger.error(f"OpenAI API error: {e}")
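The walrus guard above matters because the SDK types Chat Completions usage as optional. A minimal standalone sketch of reading usage from a non-streaming call, assuming the official openai package and an OPENAI_API_KEY in the environment (the model name is only an example):

    from openai import OpenAI

    client = OpenAI()
    response = client.chat.completions.create(
        model="gpt-4o-mini",  # example model
        messages=[{"role": "user", "content": "hi"}],
    )
    if response.usage is not None:  # typed Optional in the SDK
        print(response.usage.prompt_tokens, response.usage.total_tokens)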
@@ -355,6 +375,9 @@ class OpenAIProvider(BaseLLMProvider):
             messages, system_prompt, tools, settings, stream=True
         )

+        if kwargs.get("stream"):
+            kwargs["stream_options"] = {"include_usage": True}
+
         try:
             stream = self.client.chat.completions.create(**kwargs)
             current_tool_call: dict[str, Any] | None = None
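OpenAI only reports usage for streams when stream_options includes include_usage, and it arrives as one final chunk whose choices list is empty, which is exactly what the chunk handler earlier in this diff checks for. A minimal sketch of the consuming side, assuming the official openai package (the model name is only an example):

    from openai import OpenAI

    client = OpenAI()
    stream = client.chat.completions.create(
        model="gpt-4o-mini",  # example model
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    for chunk in stream:
        if chunk.usage:  # final chunk: empty choices, populated usage
            print(chunk.usage.prompt_tokens, chunk.usage.completion_tokens)
        elif chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="")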