extract usage

Daniel O'Connell 2025-11-01 17:56:20 +01:00
parent 07852f9ee7
commit c296f3b533
3 changed files with 48 additions and 16 deletions

File 1/3: Anthropic provider (AnthropicProvider)

@@ -14,6 +14,7 @@ from memory.common.llms.base import (
     MessageRole,
     StreamEvent,
     ToolDefinition,
+    Usage,
 )
 
 logger = logging.getLogger(__name__)
@@ -255,6 +256,16 @@ class AnthropicProvider(BaseLLMProvider):
            return StreamEvent(type="tool_use", data=tool_data), None
 
        elif event_type == "message_delta":
+            # Handle token usage information
+            if usage := getattr(event, "usage", None):
+                self.log_usage(
+                    Usage(
+                        input_tokens=usage.input_tokens,
+                        output_tokens=usage.output_tokens,
+                        total_tokens=usage.total_tokens,
+                    )
+                )
+
            delta = getattr(event, "delta", None)
            if delta:
                stop_reason = getattr(delta, "stop_reason", None)
@@ -263,22 +274,6 @@ class AnthropicProvider(BaseLLMProvider):
                        type="error", data="Max tokens reached"
                    ), current_tool_use
 
-            # Handle token usage information
-            usage = getattr(event, "usage", None)
-            if usage:
-                usage_data = {
-                    "input_tokens": getattr(usage, "input_tokens", 0),
-                    "output_tokens": getattr(usage, "output_tokens", 0),
-                    "cache_creation_input_tokens": getattr(
-                        usage, "cache_creation_input_tokens", None
-                    ),
-                    "cache_read_input_tokens": getattr(
-                        usage, "cache_read_input_tokens", None
-                    ),
-                }
-                # Could emit this as a separate event type if needed
-                logger.debug(f"Token usage: {usage_data}")
-
            return None, current_tool_use
 
        elif event_type == "message_stop":
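
One behavioural difference between the two handlers: the removed code read every field through getattr with a default, while the new code dereferences input_tokens, output_tokens, and total_tokens directly, which raises AttributeError if the SDK's streaming usage object omits any of them. It also drops the cache_creation_input_tokens / cache_read_input_tokens fields the old code captured. A tolerant variant (a sketch, not part of this commit; usage_from_event is a hypothetical helper) could look like:

    # Hypothetical helper, not in this commit: build a Usage while tolerating
    # missing fields on the SDK's streaming usage object.
    def usage_from_event(raw) -> Usage:
        input_tokens = getattr(raw, "input_tokens", 0) or 0
        output_tokens = getattr(raw, "output_tokens", 0) or 0
        total = getattr(raw, "total_tokens", None)
        return Usage(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            # Fall back to a computed total when the SDK does not report one.
            total_tokens=total if total is not None else input_tokens + output_tokens,
        )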

File 2/3: shared base module (memory.common.llms.base)

@@ -25,6 +25,15 @@ class MessageRole(str, Enum):
    TOOL = "tool"
 
 
+@dataclass
+class Usage:
+    """Usage data for an LLM call."""
+
+    input_tokens: int = 0
+    output_tokens: int = 0
+    total_tokens: int = 0
+
+
 @dataclass
 class TextContent:
    """Text content in a message."""
@@ -219,6 +228,11 @@ class BaseLLMProvider(ABC):
        self._client = self._initialize_client()
        return self._client
 
+    def log_usage(self, usage: Usage):
+        """Log usage data."""
+        logger.debug(f"Token usage: {usage.to_dict()}")
+        print(f"Token usage: {usage.to_dict()}")
+
    def execute_tool(
        self,
        tool_call: ToolCall,
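
As added in the hunk above, Usage is a plain dataclass with no to_dict() method, so log_usage's usage.to_dict() calls will raise AttributeError unless to_dict is defined somewhere outside this diff. A minimal sketch of the missing method, assuming dataclasses.asdict gives the intended dict form:

    from dataclasses import asdict, dataclass

    @dataclass
    class Usage:
        """Usage data for an LLM call."""

        input_tokens: int = 0
        output_tokens: int = 0
        total_tokens: int = 0

        def to_dict(self) -> dict[str, int]:
            # asdict works on any dataclass instance
            return asdict(self)

The print() alongside logger.debug() also reads like leftover debug output.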

File 3/3: OpenAI provider (OpenAIProvider)

@@ -16,6 +16,7 @@ from memory.common.llms.base import (
     ToolDefinition,
     ToolResultContent,
     ToolUseContent,
+    Usage,
 )
 
 logger = logging.getLogger(__name__)
@@ -283,6 +284,17 @@ class OpenAIProvider(BaseLLMProvider):
        """
        events: list[StreamEvent] = []
 
+        # Handle usage information (comes in final chunk with empty choices)
+        if hasattr(chunk, "usage") and chunk.usage:
+            usage = chunk.usage
+            self.log_usage(
+                Usage(
+                    input_tokens=usage.prompt_tokens,
+                    output_tokens=usage.completion_tokens,
+                    total_tokens=usage.total_tokens,
+                )
+            )
+
        if not chunk.choices:
            return events, current_tool_call
@@ -337,6 +349,14 @@ class OpenAIProvider(BaseLLMProvider):
        try:
            response = self.client.chat.completions.create(**kwargs)
 
+            usage = response.usage
+            self.log_usage(
+                Usage(
+                    input_tokens=usage.prompt_tokens,
+                    output_tokens=usage.completion_tokens,
+                    total_tokens=usage.total_tokens,
+                )
+            )
            return response.choices[0].message.content or ""
        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
@@ -355,6 +375,9 @@ class OpenAIProvider(BaseLLMProvider):
            messages, system_prompt, tools, settings, stream=True
        )
 
+        if kwargs.get("stream"):
+            kwargs["stream_options"] = {"include_usage": True}
+
        try:
            stream = self.client.chat.completions.create(**kwargs)
            current_tool_call: dict[str, Any] | None = None
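
The stream_options={"include_usage": True} flag added above is what makes the usage handling in the streaming hunk work: with it set, the OpenAI API appends one final chunk whose choices list is empty and whose usage field carries totals for the whole stream, which is why the chunk handler checks chunk.usage before chunk.choices. A minimal consumption sketch, assuming the official openai Python SDK (v1+) and a placeholder model name:

    from openai import OpenAI

    client = OpenAI()
    stream = client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    for chunk in stream:
        if chunk.usage:  # only the final chunk carries usage; choices is empty there
            print(chunk.usage.prompt_tokens, chunk.usage.completion_tokens)
        for choice in chunk.choices:
            print(choice.delta.content or "", end="")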