handle openai

This commit is contained in:
Daniel O'Connell 2025-10-13 11:59:23 +02:00
parent 99d3843f47
commit e68671deb4
5 changed files with 417 additions and 331 deletions

View File

@ -6,7 +6,7 @@ dotenv==0.9.9
voyageai==0.3.2
qdrant-client==1.9.0
anthropic==0.69.0
openai==1.25.0
openai==2.3.0
# Pin the httpx version, as newer versions break the anthropic client
httpx==0.27.0
celery[sqs]==5.3.6

View File

@ -22,10 +22,14 @@ from memory.common.llms.base import (
ToolUseContent,
create_provider,
)
from memory.common.llms.anthropic_provider import AnthropicProvider
from memory.common.llms.openai_provider import OpenAIProvider
from memory.common import tokens
__all__ = [
"BaseLLMProvider",
"AnthropicProvider",
"OpenAIProvider",
"Message",
"MessageRole",
"MessageContent",

View File

@ -2,7 +2,7 @@
import json
import logging
from typing import Any, AsyncIterator, Iterator, Optional
from typing import Any, AsyncIterator, Iterator
import anthropic
@ -14,9 +14,6 @@ from memory.common.llms.base import (
MessageRole,
StreamEvent,
ToolDefinition,
ToolUseContent,
ThinkingContent,
TextContent,
)
logger = logging.getLogger(__name__)
@ -45,7 +42,7 @@ class AnthropicProvider(BaseLLMProvider):
"""
super().__init__(api_key, model)
self.enable_thinking = enable_thinking
self._async_client: Optional[anthropic.AsyncAnthropic] = None
self._async_client: anthropic.AsyncAnthropic | None = None
def _initialize_client(self) -> anthropic.Anthropic:
"""Initialize the Anthropic client."""
@ -336,6 +333,7 @@ class AnthropicProvider(BaseLLMProvider):
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
print(kwargs)
try:
with self.client.messages.stream(**kwargs) as stream:
current_tool_use: dict[str, Any] | None = None
@ -396,56 +394,3 @@ class AnthropicProvider(BaseLLMProvider):
except Exception as e:
logger.error(f"Anthropic streaming error: {e}")
yield StreamEvent(type="error", data=str(e))
def stream_with_tools(
self,
messages: list[Message],
tools: dict[str, ToolDefinition],
settings: LLMSettings | None = None,
system_prompt: str | None = None,
max_iterations: int = 10,
) -> Iterator[StreamEvent]:
if max_iterations <= 0:
return
response = TextContent(text="")
thinking = ThinkingContent(thinking="", signature="")
for event in self.stream(
messages=messages,
system_prompt=system_prompt,
tools=list(tools.values()),
settings=settings,
):
if event.type == "text":
response.text += event.data
yield event
elif event.type == "thinking" and event.signature:
thinking.signature = event.signature
elif event.type == "thinking":
thinking.thinking += event.data
yield event
elif event.type == "tool_use":
yield event
tool_result = self.execute_tool(event.data, tools)
yield StreamEvent(type="tool_result", data=tool_result.to_dict())
messages.append(
Message.assistant(
response,
thinking,
ToolUseContent(
id=event.data["id"],
name=event.data["name"],
input=event.data["input"],
),
)
)
messages.append(Message.user(tool_result=tool_result))
yield from self.stream_with_tools(
messages, tools, settings, system_prompt, max_iterations - 1
)
elif event.type == "tool_result":
yield event
elif event.type == "error":
logger.error(f"LLM error: {event.data}")
raise RuntimeError(f"LLM error: {event.data}")

View File

@ -6,7 +6,7 @@ import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, AsyncIterator, Iterator, Literal, Optional, Union
from typing import Any, AsyncIterator, Iterator, Literal, Union, cast
from PIL import Image
@ -36,6 +36,10 @@ class TextContent:
"""Convert to dictionary format."""
return {"type": "text", "text": self.text}
@property
def valid(self):
return self.text
@dataclass
class ImageContent:
@ -43,13 +47,17 @@ class ImageContent:
type: Literal["image"] = "image"
image: Image.Image = None # type: ignore
detail: Optional[str] = None # For OpenAI: "low", "high", "auto"
detail: str | None = None # For OpenAI: "low", "high", "auto"
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary format."""
# Note: Image will be encoded by provider-specific implementation
return {"type": "image", "image": self.image}
@property
def valid(self):
return self.image
@dataclass
class ToolUseContent:
@ -69,6 +77,10 @@ class ToolUseContent:
"input": self.input,
}
@property
def valid(self):
return self.id and self.name
@dataclass
class ToolResultContent:
@ -88,6 +100,10 @@ class ToolResultContent:
"is_error": self.is_error,
}
@property
def valid(self):
return self.tool_use_id
@dataclass
class ThinkingContent:
@ -105,6 +121,10 @@ class ThinkingContent:
"signature": self.signature,
}
@property
def valid(self):
return self.thinking and self.signature
MessageContent = Union[
TextContent, ImageContent, ToolUseContent, ToolResultContent, ThinkingContent
@ -135,18 +155,8 @@ class Message:
return {"role": self.role.value, "content": content_list}
@staticmethod
def assistant(
text: TextContent | None = None,
thinking: ThinkingContent | None = None,
tool_use: ToolUseContent | None = None,
) -> "Message":
parts = []
if text:
parts.append(text)
if thinking:
parts.append(thinking)
if tool_use:
parts.append(tool_use)
def assistant(*content: MessageContent) -> "Message":
parts = [c for c in content if c.valid]
return Message(role=MessageRole.ASSISTANT, content=parts)
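# Illustrative sketch (not part of the diff): the variadic assistant() builder
# drops any part whose `valid` property is falsy, so callers can pass
# placeholders unconditionally. Constructors mirror the fields shown in this
# file; defaults for unset fields are assumed.
from memory.common.llms.base import Message, TextContent, ThinkingContent, ToolUseContent

msg = Message.assistant(
    TextContent(text="Looking that up now."),
    ThinkingContent(thinking="", signature=""),  # dropped: empty thinking is not valid
    ToolUseContent(id="call_1", name="search", input={"q": "memory"}),
)
assert len(msg.content) == 2  # only the text and the tool use survive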
@staticmethod
@ -295,13 +305,22 @@ class BaseLLMProvider(ABC):
"""Convert ThinkingContent to provider format. Override for custom format."""
return content.to_dict()
def _convert_message_content(self, content: MessageContent) -> dict[str, Any]:
def _convert_message_content(
self, content: str | MessageContent | list[MessageContent]
) -> dict[str, Any] | list[dict[str, Any]]:
"""
Convert a MessageContent item to provider format.
Dispatches to type-specific converters that can be overridden.
"""
if isinstance(content, TextContent):
if isinstance(content, str):
return self._convert_text_content(TextContent(text=content))
elif isinstance(content, list):
return [
cast(dict[str, Any], self._convert_message_content(item))
for item in content
]
elif isinstance(content, TextContent):
return self._convert_text_content(content)
elif isinstance(content, ImageContent):
return self._convert_image_content(content)
@ -318,9 +337,16 @@ class BaseLLMProvider(ABC):
"""
Convert a Message to provider format.
Can be overridden for provider-specific handling (e.g., filtering system messages).
Handles both string content and list[MessageContent], using provider-specific
content converters for each content item.
Can be overridden for provider-specific handling (e.g., OpenAI's tool results).
"""
return message.to_dict()
# Delegate to the content converter, which handles str, single items, and lists
return {
"role": message.role.value,
"content": self._convert_message_content(message.content),
}
def _should_include_message(self, message: Message) -> bool:
"""
@ -362,7 +388,7 @@ class BaseLLMProvider(ABC):
def _convert_tools(
self, tools: list[ToolDefinition] | None
) -> Optional[list[dict[str, Any]]]:
) -> list[dict[str, Any]] | None:
"""Convert tool definitions to provider format."""
if not tools:
return None
@ -456,7 +482,6 @@ class BaseLLMProvider(ABC):
"""
pass
@abstractmethod
def stream_with_tools(
self,
messages: list[Message],
@ -465,7 +490,75 @@ class BaseLLMProvider(ABC):
system_prompt: str | None = None,
max_iterations: int = 10,
) -> Iterator[StreamEvent]:
pass
"""
Stream response with automatic tool execution.
This method handles the tool call loop automatically, executing tools
and sending results back to the LLM until it produces a final response
or max_iterations is reached.
Args:
messages: Conversation history
tools: Dict mapping tool names to ToolDefinition handlers
settings: Optional settings for the generation
system_prompt: Optional system prompt
max_iterations: Maximum number of tool call iterations
Yields:
StreamEvent objects for text, tool calls, and tool results
"""
if max_iterations <= 0:
return
response = TextContent(text="")
thinking = ThinkingContent(thinking="")
for event in self.stream(
messages=messages,
system_prompt=system_prompt,
tools=list(tools.values()),
settings=settings,
):
if event.type == "text":
response.text += event.data
yield event
elif event.type == "thinking":
thinking.thinking += event.data
yield event
elif event.type == "tool_use":
yield event
# Execute the tool and yield the result
tool_result = self.execute_tool(event.data, tools)
yield StreamEvent(type="tool_result", data=tool_result.to_dict())
# Add assistant message with tool call
messages.append(
Message.assistant(
response,
thinking,
ToolUseContent(
id=event.data["id"],
name=event.data["name"],
input=event.data["input"],
),
)
)
# Add user message with tool result
messages.append(Message.user(tool_result=tool_result))
# Recursively continue the conversation with reduced iterations
yield from self.stream_with_tools(
messages, tools, settings, system_prompt, max_iterations - 1
)
return # Exit after recursive call completes
elif event.type == "error":
logger.error(f"LLM error: {event.data}")
raise RuntimeError(f"LLM error: {event.data}")
elif event.type == "done":
# Stream completed without tool calls
yield event
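# A hedged usage sketch for the default tool loop above. ToolDefinition's
# constructor is not shown in this diff, so the field names (name, description,
# input_schema, handler) and the "anthropic/..." model string are assumptions
# based on the updated error message in create_provider.
import json
from memory.common.llms.base import Message, MessageRole, ToolDefinition, create_provider

def get_weather(args: dict) -> str:
    # Dummy handler: return a canned result for whatever city was requested.
    return json.dumps({"city": args.get("city"), "temp_c": 21})

weather_tool = ToolDefinition(
    name="get_weather",
    description="Look up the current weather for a city",
    input_schema={"type": "object", "properties": {"city": {"type": "string"}}},
    handler=get_weather,
)

provider = create_provider("anthropic/claude-sonnet-4")  # hypothetical model string
for event in provider.stream_with_tools(
    messages=[Message(role=MessageRole.USER, content="What's the weather in Oslo?")],
    tools={weather_tool.name: weather_tool},
):
    if event.type == "text":
        print(event.data, end="", flush=True)
    elif event.type == "tool_result":
        print(f"\n[tool result] {event.data}")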
def run_with_tools(
self,
@ -550,12 +643,22 @@ def create_provider(
api_key=api_key, model=model, enable_thinking=enable_thinking
)
# Could add OpenAI support here in the future
# elif "gpt" in model_lower or model.startswith("openai"):
# ...
elif provider == "openai":
# OpenAI models
if api_key is None:
api_key = settings.OPENAI_API_KEY
if not api_key:
raise ValueError(
"OPENAI_API_KEY not found in settings. Please set it in your .env file."
)
from memory.common.llms.openai_provider import OpenAIProvider
return OpenAIProvider(api_key=api_key, model=model)
else:
raise ValueError(
f"Unknown provider for model: {model}. "
f"Supported providers: Anthropic (claude-*)"
f"Supported providers: Anthropic (anthropic/*), OpenAI (openai/*)"
)
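# Minimal usage sketch for the updated factory, assuming the "provider/model"
# string format implied by the error message above and that ANTHROPIC_API_KEY /
# OPENAI_API_KEY are set in settings. Model names are placeholders.
from memory.common.llms import create_provider

claude = create_provider("anthropic/claude-sonnet-4")  # -> AnthropicProvider
gpt = create_provider("openai/gpt-4o")                 # -> OpenAIProvider
print(type(claude).__name__, type(gpt).__name__)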

View File

@ -1,7 +1,8 @@
"""OpenAI LLM provider implementation."""
import json
import logging
from typing import Any, AsyncIterator, Iterator, Optional
from typing import Any, AsyncIterator, Iterator
import openai
@ -10,11 +11,8 @@ from memory.common.llms.base import (
ImageContent,
LLMSettings,
Message,
MessageContent,
MessageRole,
StreamEvent,
TextContent,
ThinkingContent,
ToolDefinition,
ToolResultContent,
ToolUseContent,
@ -26,92 +24,132 @@ logger = logging.getLogger(__name__)
class OpenAIProvider(BaseLLMProvider):
"""OpenAI LLM provider with streaming and tool support."""
# Models that are NOT reasoning models: they keep max_tokens and support
# temperature/top_p. Any model not listed here is treated as a reasoning model.
NON_REASONING_MODELS = {"gpt-4o"}
def __init__(self, api_key: str, model: str):
"""
Initialize the OpenAI provider.
Args:
api_key: OpenAI API key
model: Model identifier
"""
super().__init__(api_key, model)
self._async_client: openai.AsyncOpenAI | None = None
def _is_reasoning_model(self) -> bool:
"""
Check whether the current model is treated as a reasoning model (any model not listed in NON_REASONING_MODELS, e.g. the o1 series).
Reasoning models have different parameter requirements:
- Use max_completion_tokens instead of max_tokens
- Don't support temperature (always uses temperature=1)
- Don't support top_p
- Don't support system messages via system parameter
"""
return self.model.lower() not in self.NON_REASONING_MODELS
def _initialize_client(self) -> openai.OpenAI:
"""Initialize the OpenAI client."""
return openai.OpenAI(api_key=self.api_key)
@property
def async_client(self) -> openai.AsyncOpenAI:
"""Lazy-load the async client."""
if self._async_client is None:
self._async_client = openai.AsyncOpenAI(api_key=self.api_key)
return self._async_client
def _convert_text_content(self, content: TextContent) -> dict[str, Any]:
"""Convert TextContent to OpenAI format."""
return {"type": "text", "text": content.text}
def _convert_image_content(self, content: ImageContent) -> dict[str, Any]:
"""Convert ImageContent to OpenAI image_url format."""
encoded_image = self.encode_image(content.image)
image_part: dict[str, Any] = {
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
}
if content.detail:
image_part["image_url"]["detail"] = content.detail
return image_part
def _convert_tool_use_content(self, content: ToolUseContent) -> dict[str, Any]:
"""Convert ToolUseContent to provider format. Override for custom format."""
return {
"id": content.id,
"type": "function",
"function": {
"name": content.name,
"arguments": json.dumps(content.input),
},
}
def _convert_tool_result_content(
self, content: ToolResultContent
) -> dict[str, Any]:
"""Convert ToolResultContent to provider format. Override for custom format."""
return {
"role": "tool",
"tool_call_id": content.tool_use_id,
"content": content.content,
}
def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
"""
Convert our Message format to OpenAI format.
Convert messages to OpenAI format.
Args:
messages: List of messages in our format
OpenAI has special requirements:
- ToolResultContent creates separate "tool" role messages
- ToolUseContent becomes tool_calls field on assistant messages
- One input Message can produce multiple output messages
Returns:
List of messages in OpenAI format
Flat list of OpenAI-formatted message dicts
"""
openai_messages = []
openai_messages: list[dict[str, Any]] = []
for msg in messages:
if isinstance(msg.content, str):
openai_messages.append({"role": msg.role.value, "content": msg.content})
else:
# Handle multi-part content
content_parts = []
for item in msg.content:
if isinstance(item, TextContent):
content_parts.append({"type": "text", "text": item.text})
elif isinstance(item, ImageContent):
encoded_image = self.encode_image(item.image)
image_part: dict[str, Any] = {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{encoded_image}"
},
}
if item.detail:
image_part["image_url"]["detail"] = item.detail
content_parts.append(image_part)
elif isinstance(item, ToolUseContent):
# OpenAI doesn't have tool_use in content, it's a separate field
# We'll handle this by adding a tool_calls field to the message
pass
elif isinstance(item, ToolResultContent):
# OpenAI handles tool results as separate "tool" role messages
openai_messages.append(
{
"role": "tool",
"tool_call_id": item.tool_use_id,
"content": item.content,
}
)
continue
elif isinstance(item, ThinkingContent):
# OpenAI doesn't have native thinking support in most models
# We can add it as text with a special marker
content_parts.append(
{
"type": "text",
"text": f"[Thinking: {item.thinking}]",
}
)
for message in messages:
# Handle simple string content
if isinstance(message.content, str):
openai_messages.append(
{"role": message.role.value, "content": message.content}
)
continue
# Check if this message has tool calls
tool_calls = [
item for item in msg.content if isinstance(item, ToolUseContent)
]
# Handle multi-part content
content_parts: list[dict[str, Any]] = []
tool_calls_list: list[dict[str, Any]] = []
message_dict: dict[str, Any] = {"role": msg.role.value}
for item in message.content:
if isinstance(item, ToolResultContent):
openai_messages.append(self._convert_tool_result_content(item))
elif isinstance(item, ToolUseContent):
tool_calls_list.append(self._convert_tool_use_content(item))
else:
content_parts.append(self._convert_message_content(item))
if content_parts or tool_calls_list:
msg_dict: dict[str, Any] = {"role": message.role.value}
if content_parts:
message_dict["content"] = content_parts
msg_dict["content"] = content_parts
elif tool_calls_list:
# Assistant messages with tool calls need content field (use empty string)
msg_dict["content"] = ""
if tool_calls:
message_dict["tool_calls"] = [
{
"id": tc.id,
"type": "function",
"function": {"name": tc.name, "arguments": str(tc.input)},
}
for tc in tool_calls
]
if tool_calls_list:
msg_dict["tool_calls"] = tool_calls_list
openai_messages.append(message_dict)
openai_messages.append(msg_dict)
return openai_messages
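# Illustrative only: one assistant Message carrying a tool call becomes an
# OpenAI assistant message with a `tool_calls` field, while the Message holding
# the tool result becomes a separate {"role": "tool", ...} entry.
# Message.user(tool_result=...) is assumed to wrap the result as list content.
from memory.common.llms.base import Message, TextContent, ToolResultContent, ToolUseContent
from memory.common.llms.openai_provider import OpenAIProvider

provider = OpenAIProvider(api_key="sk-test", model="gpt-4o")  # no network call made
converted = provider._convert_messages([
    Message.assistant(
        TextContent(text="Let me check."),
        ToolUseContent(id="call_1", name="get_weather", input={"city": "Oslo"}),
    ),
    Message.user(tool_result=ToolResultContent(tool_use_id="call_1", content="21C")),
])
# converted[0]["tool_calls"][0]["function"]["name"] == "get_weather"
# converted[1] == {"role": "tool", "tool_call_id": "call_1", "content": "21C"}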
def _convert_tools(
self, tools: Optional[list[ToolDefinition]]
self, tools: list[ToolDefinition] | None
) -> Optional[list[dict[str, Any]]]:
"""
Convert our tool definitions to OpenAI format.
@ -137,6 +175,153 @@ class OpenAIProvider(BaseLLMProvider):
for tool in tools
]
def _build_request_kwargs(
self,
messages: list[Message],
system_prompt: str | None,
tools: Optional[list[ToolDefinition]],
settings: LLMSettings,
stream: bool = False,
) -> dict[str, Any]:
"""
Build common request kwargs for API calls.
Args:
messages: Conversation history
system_prompt: Optional system prompt
tools: Optional list of tools
settings: LLM settings
stream: Whether to enable streaming
Returns:
Dictionary of kwargs for OpenAI API call
"""
openai_messages = self._convert_messages(messages)
is_reasoning = self._is_reasoning_model()
# Emit a debug log whenever a reasoning model's parameter handling kicks in
if is_reasoning:
logger.debug(
f"Using reasoning model {self.model}: "
"max_completion_tokens will be used, temperature/top_p ignored"
)
# Reasoning models (o1) don't accept a "system" role message;
# the system prompt must be sent with the "developer" role instead
if system_prompt:
if is_reasoning:
# For o1 models, add system prompt as a developer message
openai_messages.insert(
0, {"role": "developer", "content": system_prompt}
)
else:
# For other models, add as system message
openai_messages.insert(0, {"role": "system", "content": system_prompt})
# Reasoning models use max_completion_tokens instead of max_tokens
max_tokens_key = "max_completion_tokens" if is_reasoning else "max_tokens"
kwargs: dict[str, Any] = {
"model": self.model,
"messages": openai_messages,
max_tokens_key: settings.max_tokens,
}
# Reasoning models don't support temperature or top_p
if not is_reasoning:
kwargs["temperature"] = settings.temperature
kwargs["top_p"] = settings.top_p
if stream:
kwargs["stream"] = True
if settings.stop_sequences:
kwargs["stop"] = settings.stop_sequences
if tools:
kwargs["tools"] = self._convert_tools(tools)
kwargs["tool_choice"] = "auto"
return kwargs
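# A hedged illustration of how _build_request_kwargs diverges per model, based
# only on the branching above. "o1-mini" is a placeholder name; any model not
# listed in NON_REASONING_MODELS takes the reasoning path.
from memory.common.llms.base import LLMSettings, Message, MessageRole
from memory.common.llms.openai_provider import OpenAIProvider

settings = LLMSettings()
msgs = [Message(role=MessageRole.USER, content="hi")]

# gpt-4o (non-reasoning): max_tokens, temperature, top_p, "system" role
OpenAIProvider(api_key="sk-test", model="gpt-4o")._build_request_kwargs(
    msgs, "Be brief.", None, settings
)
# -> {"model": "gpt-4o", "messages": [{"role": "system", ...}, ...],
#     "max_tokens": ..., "temperature": ..., "top_p": ...}

# o1-mini (reasoning): max_completion_tokens, no temperature/top_p,
# and the system prompt delivered as a "developer" message instead
OpenAIProvider(api_key="sk-test", model="o1-mini")._build_request_kwargs(
    msgs, "Be brief.", None, settings
)
# -> {"model": "o1-mini", "messages": [{"role": "developer", ...}, ...],
#     "max_completion_tokens": ...}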
def _parse_and_finalize_tool_call(
self, tool_call: dict[str, Any]
) -> dict[str, Any]:
"""
Parse the accumulated tool call arguments and prepare for yielding.
Args:
tool_call: Tool call dict with 'arguments' field (JSON string)
Returns:
Tool call dict with parsed 'input' field (dict)
"""
try:
tool_call["input"] = json.loads(tool_call["arguments"])
except json.JSONDecodeError as e:
logger.warning(
f"Failed to parse tool arguments '{tool_call['arguments']}': {e}"
)
tool_call["input"] = {}
del tool_call["arguments"]
return tool_call
def _handle_stream_chunk(
self,
chunk: Any,
current_tool_call: dict[str, Any] | None,
) -> tuple[list[StreamEvent], Optional[dict[str, Any]]]:
"""
Handle a single streaming chunk and return events and updated tool state.
Args:
chunk: Streaming chunk from OpenAI
current_tool_call: Current tool call being accumulated (or None)
Returns:
Tuple of (list of StreamEvents to yield, updated current_tool_call)
"""
events: list[StreamEvent] = []
if not chunk.choices:
return events, current_tool_call
delta = chunk.choices[0].delta
# Handle text content
if delta.content:
events.append(StreamEvent(type="text", data=delta.content))
# Handle tool calls
if delta.tool_calls:
for tool_call in delta.tool_calls:
if tool_call.id:
# New tool call starting
if current_tool_call:
# Yield the previous one with parsed input
finalized = self._parse_and_finalize_tool_call(
current_tool_call
)
events.append(StreamEvent(type="tool_use", data=finalized))
current_tool_call = {
"id": tool_call.id,
"name": tool_call.function.name or "",
"arguments": tool_call.function.arguments or "",
}
elif current_tool_call and tool_call.function.arguments:
# Continue building the current tool call
current_tool_call["arguments"] += tool_call.function.arguments
# Check if stream is finished
if chunk.choices[0].finish_reason:
if current_tool_call:
# Parse the final tool call arguments
finalized = self._parse_and_finalize_tool_call(current_tool_call)
events.append(StreamEvent(type="tool_use", data=finalized))
current_tool_call = None
return events, current_tool_call
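# Illustrative only: OpenAI streams tool-call arguments as JSON fragments spread
# over many chunks. The fake chunks below (SimpleNamespace stand-ins for the
# real response objects) show how the helper accumulates fragments and emits a
# single tool_use event once finish_reason arrives.
from types import SimpleNamespace as NS
from memory.common.llms.openai_provider import OpenAIProvider

provider = OpenAIProvider(api_key="sk-test", model="gpt-4o")  # no network call made
chunks = [
    NS(choices=[NS(delta=NS(content=None, tool_calls=[
        NS(id="call_1", function=NS(name="get_weather", arguments='{"ci'))]),
        finish_reason=None)]),
    NS(choices=[NS(delta=NS(content=None, tool_calls=[
        NS(id=None, function=NS(name=None, arguments='ty": "Oslo"}'))]),
        finish_reason=None)]),
    NS(choices=[NS(delta=NS(content=None, tool_calls=None), finish_reason="tool_calls")]),
]

current = None
for chunk in chunks:
    events, current = provider._handle_stream_chunk(chunk, current)
    for event in events:
        print(event.type, event.data)
# -> tool_use {'id': 'call_1', 'name': 'get_weather', 'input': {'city': 'Oslo'}}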
def generate(
self,
messages: list[Message],
@ -146,29 +331,9 @@ class OpenAIProvider(BaseLLMProvider):
) -> str:
"""Generate a non-streaming response."""
settings = settings or LLMSettings()
openai_messages = self._convert_messages(messages)
# Add system prompt as first message if provided
if system_prompt:
openai_messages.insert(
0, {"role": "system", "content": system_prompt}
)
kwargs: dict[str, Any] = {
"model": self.model,
"messages": openai_messages,
"temperature": settings.temperature,
"max_tokens": settings.max_tokens,
"top_p": settings.top_p,
}
if settings.stop_sequences:
kwargs["stop"] = settings.stop_sequences
if tools:
kwargs["tools"] = self._convert_tools(tools)
kwargs["tool_choice"] = "auto"
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, settings, stream=False
)
try:
response = self.client.chat.completions.create(**kwargs)
@ -180,78 +345,25 @@ class OpenAIProvider(BaseLLMProvider):
def stream(
self,
messages: list[Message],
system_prompt: Optional[str] = None,
tools: Optional[list[ToolDefinition]] = None,
settings: Optional[LLMSettings] = None,
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
settings: LLMSettings | None = None,
) -> Iterator[StreamEvent]:
"""Generate a streaming response."""
settings = settings or LLMSettings()
openai_messages = self._convert_messages(messages)
# Add system prompt as first message if provided
if system_prompt:
openai_messages.insert(
0, {"role": "system", "content": system_prompt}
)
kwargs: dict[str, Any] = {
"model": self.model,
"messages": openai_messages,
"temperature": settings.temperature,
"max_tokens": settings.max_tokens,
"top_p": settings.top_p,
"stream": True,
}
if settings.stop_sequences:
kwargs["stop"] = settings.stop_sequences
if tools:
kwargs["tools"] = self._convert_tools(tools)
kwargs["tool_choice"] = "auto"
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, settings, stream=True
)
try:
stream = self.client.chat.completions.create(**kwargs)
current_tool_call: Optional[dict[str, Any]] = None
current_tool_call: dict[str, Any] | None = None
for chunk in stream:
if not chunk.choices:
continue
delta = chunk.choices[0].delta
# Handle text content
if delta.content:
yield StreamEvent(type="text", data=delta.content)
# Handle tool calls
if delta.tool_calls:
for tool_call in delta.tool_calls:
if tool_call.id:
# New tool call starting
if current_tool_call:
# Yield the previous one
yield StreamEvent(
type="tool_use", data=current_tool_call
)
current_tool_call = {
"id": tool_call.id,
"name": tool_call.function.name or "",
"arguments": tool_call.function.arguments or "",
}
elif current_tool_call and tool_call.function.arguments:
# Continue building the current tool call
current_tool_call["arguments"] += (
tool_call.function.arguments
)
# Check if stream is finished
if chunk.choices[0].finish_reason:
if current_tool_call:
yield StreamEvent(type="tool_use", data=current_tool_call)
current_tool_call = None
events, current_tool_call = self._handle_stream_chunk(
chunk, current_tool_call
)
yield from events
yield StreamEvent(type="done")
@ -268,35 +380,12 @@ class OpenAIProvider(BaseLLMProvider):
) -> str:
"""Generate a non-streaming response asynchronously."""
settings = settings or LLMSettings()
# Use async client
async_client = openai.AsyncOpenAI(api_key=self.api_key)
openai_messages = self._convert_messages(messages)
# Add system prompt as first message if provided
if system_prompt:
openai_messages.insert(
0, {"role": "system", "content": system_prompt}
)
kwargs: dict[str, Any] = {
"model": self.model,
"messages": openai_messages,
"temperature": settings.temperature,
"max_tokens": settings.max_tokens,
"top_p": settings.top_p,
}
if settings.stop_sequences:
kwargs["stop"] = settings.stop_sequences
if tools:
kwargs["tools"] = self._convert_tools(tools)
kwargs["tool_choice"] = "auto"
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, settings, stream=False
)
try:
response = await async_client.chat.completions.create(**kwargs)
response = await self.async_client.chat.completions.create(**kwargs)
return response.choices[0].message.content or ""
except Exception as e:
logger.error(f"OpenAI API error: {e}")
@ -311,75 +400,20 @@ class OpenAIProvider(BaseLLMProvider):
) -> AsyncIterator[StreamEvent]:
"""Generate a streaming response asynchronously."""
settings = settings or LLMSettings()
# Use async client
async_client = openai.AsyncOpenAI(api_key=self.api_key)
openai_messages = self._convert_messages(messages)
# Add system prompt as first message if provided
if system_prompt:
openai_messages.insert(
0, {"role": "system", "content": system_prompt}
)
kwargs: dict[str, Any] = {
"model": self.model,
"messages": openai_messages,
"temperature": settings.temperature,
"max_tokens": settings.max_tokens,
"top_p": settings.top_p,
"stream": True,
}
if settings.stop_sequences:
kwargs["stop"] = settings.stop_sequences
if tools:
kwargs["tools"] = self._convert_tools(tools)
kwargs["tool_choice"] = "auto"
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, settings, stream=True
)
try:
stream = await async_client.chat.completions.create(**kwargs)
stream = await self.async_client.chat.completions.create(**kwargs)
current_tool_call: Optional[dict[str, Any]] = None
async for chunk in stream:
if not chunk.choices:
continue
delta = chunk.choices[0].delta
# Handle text content
if delta.content:
yield StreamEvent(type="text", data=delta.content)
# Handle tool calls
if delta.tool_calls:
for tool_call in delta.tool_calls:
if tool_call.id:
# New tool call starting
if current_tool_call:
# Yield the previous one
yield StreamEvent(
type="tool_use", data=current_tool_call
)
current_tool_call = {
"id": tool_call.id,
"name": tool_call.function.name or "",
"arguments": tool_call.function.arguments or "",
}
elif current_tool_call and tool_call.function.arguments:
# Continue building the current tool call
current_tool_call["arguments"] += (
tool_call.function.arguments
)
# Check if stream is finished
if chunk.choices[0].finish_reason:
if current_tool_call:
yield StreamEvent(type="tool_use", data=current_tool_call)
current_tool_call = None
events, current_tool_call = self._handle_stream_chunk(
chunk, current_tool_call
)
for event in events:
yield event
yield StreamEvent(type="done")
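# Async usage sketch. The async streaming method's name is outside this hunk;
# `stream_async` below is an assumed name used purely for illustration, and the
# API key is a placeholder.
import asyncio
from memory.common.llms.base import Message, MessageRole
from memory.common.llms.openai_provider import OpenAIProvider

async def main() -> None:
    provider = OpenAIProvider(api_key="sk-...", model="gpt-4o")
    async for event in provider.stream_async(  # assumed method name
        [Message(role=MessageRole.USER, content="Say hi")]
    ):
        if event.type == "text":
            print(event.data, end="", flush=True)
        elif event.type == "done":
            print()

asyncio.run(main())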