diff --git a/src/memory/common/llms/anthropic_provider.py b/src/memory/common/llms/anthropic_provider.py
index 938d552..0ef8691 100644
--- a/src/memory/common/llms/anthropic_provider.py
+++ b/src/memory/common/llms/anthropic_provider.py
@@ -14,6 +14,7 @@ from memory.common.llms.base import (
     MessageRole,
     StreamEvent,
     ToolDefinition,
+    Usage,
 )
 
 logger = logging.getLogger(__name__)
@@ -255,6 +256,18 @@ class AnthropicProvider(BaseLLMProvider):
             return StreamEvent(type="tool_use", data=tool_data), None
 
         elif event_type == "message_delta":
+            # Handle token usage information (delta usage may omit some fields
+            # and Anthropic does not report a total, so sum the parts)
+            if usage := getattr(event, "usage", None):
+                input_tokens = getattr(usage, "input_tokens", 0) or 0
+                output_tokens = getattr(usage, "output_tokens", 0) or 0
+                self.log_usage(
+                    Usage(
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                        total_tokens=input_tokens + output_tokens,
+                    )
+                )
+
             delta = getattr(event, "delta", None)
             if delta:
                 stop_reason = getattr(delta, "stop_reason", None)
@@ -263,22 +276,6 @@ class AnthropicProvider(BaseLLMProvider):
                         type="error", data="Max tokens reached"
                     ), current_tool_use
 
-            # Handle token usage information
-            usage = getattr(event, "usage", None)
-            if usage:
-                usage_data = {
-                    "input_tokens": getattr(usage, "input_tokens", 0),
-                    "output_tokens": getattr(usage, "output_tokens", 0),
-                    "cache_creation_input_tokens": getattr(
-                        usage, "cache_creation_input_tokens", None
-                    ),
-                    "cache_read_input_tokens": getattr(
-                        usage, "cache_read_input_tokens", None
-                    ),
-                }
-                # Could emit this as a separate event type if needed
-                logger.debug(f"Token usage: {usage_data}")
-
             return None, current_tool_use
 
         elif event_type == "message_stop":
diff --git a/src/memory/common/llms/base.py b/src/memory/common/llms/base.py
index 6238daa..f62eb4d 100644
--- a/src/memory/common/llms/base.py
+++ b/src/memory/common/llms/base.py
@@ -25,6 +25,15 @@ class MessageRole(str, Enum):
     TOOL = "tool"
 
 
+@dataclass
+class Usage:
+    """Usage data for an LLM call."""
+
+    input_tokens: int = 0
+    output_tokens: int = 0
+    total_tokens: int = 0
+
+
 @dataclass
 class TextContent:
     """Text content in a message."""
@@ -219,6 +228,10 @@ class BaseLLMProvider(ABC):
             self._client = self._initialize_client()
         return self._client
 
+    def log_usage(self, usage: Usage):
+        """Log usage data."""
+        logger.debug(f"Token usage: {usage}")
+
     def execute_tool(
         self,
         tool_call: ToolCall,
diff --git a/src/memory/common/llms/openai_provider.py b/src/memory/common/llms/openai_provider.py
index 594b459..811e0d0 100644
--- a/src/memory/common/llms/openai_provider.py
+++ b/src/memory/common/llms/openai_provider.py
@@ -16,6 +16,7 @@ from memory.common.llms.base import (
     ToolDefinition,
     ToolResultContent,
     ToolUseContent,
+    Usage,
 )
 
 logger = logging.getLogger(__name__)
@@ -283,6 +284,17 @@ class OpenAIProvider(BaseLLMProvider):
         """
         events: list[StreamEvent] = []
 
+        # Handle usage information (comes in the final chunk with empty choices)
+        if hasattr(chunk, "usage") and chunk.usage:
+            usage = chunk.usage
+            self.log_usage(
+                Usage(
+                    input_tokens=usage.prompt_tokens,
+                    output_tokens=usage.completion_tokens,
+                    total_tokens=usage.total_tokens,
+                )
+            )
+
         if not chunk.choices:
             return events, current_tool_call
 
@@ -337,6 +349,14 @@ class OpenAIProvider(BaseLLMProvider):
 
         try:
             response = self.client.chat.completions.create(**kwargs)
+            if usage := response.usage:
+                self.log_usage(
+                    Usage(
+                        input_tokens=usage.prompt_tokens,
+                        output_tokens=usage.completion_tokens,
+                        total_tokens=usage.total_tokens,
+                    )
+                )
             return response.choices[0].message.content or ""
         except Exception as e:
             logger.error(f"OpenAI API error: {e}")
@@ -355,6 +375,9 @@ class OpenAIProvider(BaseLLMProvider):
             messages, system_prompt, tools, settings, stream=True
         )
 
+        if kwargs.get("stream"):
+            kwargs["stream_options"] = {"include_usage": True}
+
         try:
             stream = self.client.chat.completions.create(**kwargs)
             current_tool_call: dict[str, Any] | None = None
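Note: with log_usage() as the single funnel for usage reporting in both providers, downstream code can override it to aggregate token counts rather than only logging them. A rough sketch of what that could look like — the MeteredOpenAIProvider name, the pass-through constructor, and the running-total idea are illustrative assumptions, not part of this change:

from memory.common.llms.base import Usage
from memory.common.llms.openai_provider import OpenAIProvider


class MeteredOpenAIProvider(OpenAIProvider):
    """Keeps a running total of tokens on top of the default debug logging."""

    def __init__(self, *args, **kwargs):
        # Constructor signature is assumed; forward everything to the real provider.
        super().__init__(*args, **kwargs)
        self.total_usage = Usage()

    def log_usage(self, usage: Usage) -> None:
        # Preserve the base class's debug logging, then accumulate.
        super().log_usage(usage)
        self.total_usage.input_tokens += usage.input_tokens
        self.total_usage.output_tokens += usage.output_tokens
        # Fall back to summing the parts when a provider reports no total.
        self.total_usage.total_tokens += usage.total_tokens or (
            usage.input_tokens + usage.output_tokens
        )

Because the non-streaming path, the streamed chunks (enabled via stream_options={"include_usage": True}), and the Anthropic message_delta events all route through log_usage(), the same override covers every call path touched by this diff.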