handle openai

Daniel O'Connell 2025-10-13 11:59:23 +02:00
parent 99d3843f47
commit e68671deb4
5 changed files with 417 additions and 331 deletions

View File

@@ -6,7 +6,7 @@ dotenv==0.9.9
 voyageai==0.3.2
 qdrant-client==1.9.0
 anthropic==0.69.0
-openai==1.25.0
+openai==2.3.0
 # Pin the httpx version, as newer versions break the anthropic client
 httpx==0.27.0
 celery[sqs]==5.3.6

View File

@@ -22,10 +22,14 @@ from memory.common.llms.base import (
     ToolUseContent,
     create_provider,
 )
+from memory.common.llms.anthropic_provider import AnthropicProvider
+from memory.common.llms.openai_provider import OpenAIProvider
 from memory.common import tokens

 __all__ = [
     "BaseLLMProvider",
+    "AnthropicProvider",
+    "OpenAIProvider",
     "Message",
     "MessageRole",
     "MessageContent",

View File

@@ -2,7 +2,7 @@

 import json
 import logging
-from typing import Any, AsyncIterator, Iterator, Optional
+from typing import Any, AsyncIterator, Iterator

 import anthropic
@@ -14,9 +14,6 @@ from memory.common.llms.base import (
     MessageRole,
     StreamEvent,
     ToolDefinition,
-    ToolUseContent,
-    ThinkingContent,
-    TextContent,
 )

 logger = logging.getLogger(__name__)
@@ -45,7 +42,7 @@ class AnthropicProvider(BaseLLMProvider):
         """
         super().__init__(api_key, model)
         self.enable_thinking = enable_thinking
-        self._async_client: Optional[anthropic.AsyncAnthropic] = None
+        self._async_client: anthropic.AsyncAnthropic | None = None

     def _initialize_client(self) -> anthropic.Anthropic:
         """Initialize the Anthropic client."""
@@ -336,6 +333,7 @@
         settings = settings or LLMSettings()
         kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
+        print(kwargs)

         try:
             with self.client.messages.stream(**kwargs) as stream:
                 current_tool_use: dict[str, Any] | None = None
@@ -396,56 +394,3 @@
         except Exception as e:
             logger.error(f"Anthropic streaming error: {e}")
             yield StreamEvent(type="error", data=str(e))
-
-    def stream_with_tools(
-        self,
-        messages: list[Message],
-        tools: dict[str, ToolDefinition],
-        settings: LLMSettings | None = None,
-        system_prompt: str | None = None,
-        max_iterations: int = 10,
-    ) -> Iterator[StreamEvent]:
-        if max_iterations <= 0:
-            return
-
-        response = TextContent(text="")
-        thinking = ThinkingContent(thinking="", signature="")
-        for event in self.stream(
-            messages=messages,
-            system_prompt=system_prompt,
-            tools=list(tools.values()),
-            settings=settings,
-        ):
-            if event.type == "text":
-                response.text += event.data
-                yield event
-            elif event.type == "thinking" and event.signature:
-                thinking.signature = event.signature
-            elif event.type == "thinking":
-                thinking.thinking += event.data
-                yield event
-            elif event.type == "tool_use":
-                yield event
-                tool_result = self.execute_tool(event.data, tools)
-                yield StreamEvent(type="tool_result", data=tool_result.to_dict())
-                messages.append(
-                    Message.assistant(
-                        response,
-                        thinking,
-                        ToolUseContent(
-                            id=event.data["id"],
-                            name=event.data["name"],
-                            input=event.data["input"],
-                        ),
-                    )
-                )
-                messages.append(Message.user(tool_result=tool_result))
-                yield from self.stream_with_tools(
-                    messages, tools, settings, system_prompt, max_iterations - 1
-                )
-            elif event.type == "tool_result":
-                yield event
-            elif event.type == "error":
-                logger.error(f"LLM error: {event.data}")
-                raise RuntimeError(f"LLM error: {event.data}")

View File

@@ -6,7 +6,7 @@ import logging
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, AsyncIterator, Iterator, Literal, Optional, Union
+from typing import Any, AsyncIterator, Iterator, Literal, Union, cast

 from PIL import Image
@@ -36,6 +36,10 @@ class TextContent:
         """Convert to dictionary format."""
         return {"type": "text", "text": self.text}

+    @property
+    def valid(self):
+        return self.text
+

 @dataclass
 class ImageContent:
@@ -43,13 +47,17 @@
     type: Literal["image"] = "image"
     image: Image.Image = None  # type: ignore
-    detail: Optional[str] = None  # For OpenAI: "low", "high", "auto"
+    detail: str | None = None  # For OpenAI: "low", "high", "auto"

     def to_dict(self) -> dict[str, Any]:
         """Convert to dictionary format."""
         # Note: Image will be encoded by provider-specific implementation
         return {"type": "image", "image": self.image}

+    @property
+    def valid(self):
+        return self.image
+

 @dataclass
 class ToolUseContent:
@@ -69,6 +77,10 @@
             "input": self.input,
         }

+    @property
+    def valid(self):
+        return self.id and self.name
+

 @dataclass
 class ToolResultContent:
@@ -88,6 +100,10 @@
             "is_error": self.is_error,
         }

+    @property
+    def valid(self):
+        return self.tool_use_id
+

 @dataclass
 class ThinkingContent:
@@ -105,6 +121,10 @@
             "signature": self.signature,
         }

+    @property
+    def valid(self):
+        return self.thinking and self.signature
+

 MessageContent = Union[
     TextContent, ImageContent, ToolUseContent, ToolResultContent, ThinkingContent
@@ -135,18 +155,8 @@ class Message:
         return {"role": self.role.value, "content": content_list}

     @staticmethod
-    def assistant(
-        text: TextContent | None = None,
-        thinking: ThinkingContent | None = None,
-        tool_use: ToolUseContent | None = None,
-    ) -> "Message":
-        parts = []
-        if text:
-            parts.append(text)
-        if thinking:
-            parts.append(thinking)
-        if tool_use:
-            parts.append(tool_use)
+    def assistant(*content: MessageContent) -> "Message":
+        parts = [c for c in content if c.valid]
         return Message(role=MessageRole.ASSISTANT, content=parts)

     @staticmethod
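A minimal sketch of the new variadic constructor, assuming the content dataclasses defined earlier in this file: parts whose `valid` property is falsy are silently dropped.

msg = Message.assistant(
    TextContent(text="Looking that up now."),
    ThinkingContent(thinking=""),  # no thinking text/signature -> .valid is falsy, dropped
    ToolUseContent(id="call_1", name="search", input={"query": "weather"}),
)
# msg.content keeps only the TextContent and ToolUseContent parts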
@@ -295,13 +305,22 @@ class BaseLLMProvider(ABC):
         """Convert ThinkingContent to provider format. Override for custom format."""
         return content.to_dict()

-    def _convert_message_content(self, content: MessageContent) -> dict[str, Any]:
+    def _convert_message_content(
+        self, content: str | MessageContent | list[MessageContent]
+    ) -> dict[str, Any] | list[dict[str, Any]]:
         """
         Convert a MessageContent item to provider format.

         Dispatches to type-specific converters that can be overridden.
         """
-        if isinstance(content, TextContent):
+        if isinstance(content, str):
+            return self._convert_text_content(TextContent(text=content))
+        elif isinstance(content, list):
+            return [
+                cast(dict[str, Any], self._convert_message_content(item))
+                for item in content
+            ]
+        elif isinstance(content, TextContent):
             return self._convert_text_content(content)
         elif isinstance(content, ImageContent):
             return self._convert_image_content(content)
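A rough sketch of what the widened dispatcher now accepts (hypothetical `provider` instance; output shapes follow the default converters above):

provider._convert_message_content("hi")
# -> {"type": "text", "text": "hi"}

provider._convert_message_content([TextContent(text="a"), TextContent(text="b")])
# -> [{"type": "text", "text": "a"}, {"type": "text", "text": "b"}]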
@@ -318,9 +337,16 @@
         """
         Convert a Message to provider format.

-        Can be overridden for provider-specific handling (e.g., filtering system messages).
+        Handles both string content and list[MessageContent], using provider-specific
+        content converters for each content item.
+
+        Can be overridden for provider-specific handling (e.g., OpenAI's tool results).
         """
-        return message.to_dict()
+        # Handle simple string content
+        return {
+            "role": message.role.value,
+            "content": self._convert_message_content(message.content),
+        }

     def _should_include_message(self, message: Message) -> bool:
         """
@@ -362,7 +388,7 @@
     def _convert_tools(
         self, tools: list[ToolDefinition] | None
-    ) -> Optional[list[dict[str, Any]]]:
+    ) -> list[dict[str, Any]] | None:
         """Convert tool definitions to provider format."""
         if not tools:
             return None
@@ -456,7 +482,6 @@
         """
         pass

-    @abstractmethod
     def stream_with_tools(
         self,
         messages: list[Message],
@@ -465,7 +490,75 @@
         system_prompt: str | None = None,
         max_iterations: int = 10,
     ) -> Iterator[StreamEvent]:
-        pass
+        """
+        Stream response with automatic tool execution.
+
+        This method handles the tool call loop automatically, executing tools
+        and sending results back to the LLM until it produces a final response
+        or max_iterations is reached.
+
+        Args:
+            messages: Conversation history
+            tools: Dict mapping tool names to ToolDefinition handlers
+            settings: Optional settings for the generation
+            system_prompt: Optional system prompt
+            max_iterations: Maximum number of tool call iterations
+
+        Yields:
+            StreamEvent objects for text, tool calls, and tool results
+        """
+        if max_iterations <= 0:
+            return
+
+        response = TextContent(text="")
+        thinking = ThinkingContent(thinking="")
+        for event in self.stream(
+            messages=messages,
+            system_prompt=system_prompt,
+            tools=list(tools.values()),
+            settings=settings,
+        ):
+            if event.type == "text":
+                response.text += event.data
+                yield event
+            elif event.type == "thinking":
+                thinking.thinking += event.data
+                yield event
+            elif event.type == "tool_use":
+                yield event
+                # Execute the tool and yield the result
+                tool_result = self.execute_tool(event.data, tools)
+                yield StreamEvent(type="tool_result", data=tool_result.to_dict())
+
+                # Add assistant message with tool call
+                messages.append(
+                    Message.assistant(
+                        response,
+                        thinking,
+                        ToolUseContent(
+                            id=event.data["id"],
+                            name=event.data["name"],
+                            input=event.data["input"],
+                        ),
+                    )
+                )
+                # Add user message with tool result
+                messages.append(Message.user(tool_result=tool_result))
+
+                # Recursively continue the conversation with reduced iterations
+                yield from self.stream_with_tools(
+                    messages, tools, settings, system_prompt, max_iterations - 1
+                )
+                return  # Exit after recursive call completes
+            elif event.type == "error":
+                logger.error(f"LLM error: {event.data}")
+                raise RuntimeError(f"LLM error: {event.data}")
+            elif event.type == "done":
+                # Stream completed without tool calls
+                yield event

     def run_with_tools(
         self,
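A rough consumption sketch for the new default loop (hypothetical `provider`, `history`, and `tools` dict; `ToolDefinition` construction is not shown in this diff):

for event in provider.stream_with_tools(messages=history, tools=tools):
    if event.type == "text":
        print(event.data, end="")
    elif event.type == "tool_result":
        print(f"\n[tool result] {event.data}")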
@@ -550,12 +643,22 @@ def create_provider(
             api_key=api_key, model=model, enable_thinking=enable_thinking
         )

-    # Could add OpenAI support here in the future
-    # elif "gpt" in model_lower or model.startswith("openai"):  # OpenAI models
-    #     ...
+    elif provider == "openai":
+        if api_key is None:
+            api_key = settings.OPENAI_API_KEY
+        if not api_key:
+            raise ValueError(
+                "OPENAI_API_KEY not found in settings. Please set it in your .env file."
+            )
+
+        from memory.common.llms.openai_provider import OpenAIProvider
+
+        return OpenAIProvider(api_key=api_key, model=model)

     else:
         raise ValueError(
             f"Unknown provider for model: {model}. "
-            f"Supported providers: Anthropic (claude-*)"
+            f"Supported providers: Anthropic (anthropic/*), OpenAI (openai/*)"
         )
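Hypothetical call shape, going by the error message above; the exact create_provider signature is outside this hunk, so the positional model argument is an assumption:

llm = create_provider("openai/gpt-4o")  # routed to OpenAIProvider; key falls back to settings.OPENAI_API_KEY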

View File

@@ -1,7 +1,8 @@
 """OpenAI LLM provider implementation."""

+import json
 import logging
-from typing import Any, AsyncIterator, Iterator, Optional
+from typing import Any, AsyncIterator, Iterator

 import openai
@@ -10,11 +11,8 @@ from memory.common.llms.base import (
     ImageContent,
     LLMSettings,
     Message,
-    MessageContent,
-    MessageRole,
     StreamEvent,
     TextContent,
-    ThinkingContent,
     ToolDefinition,
     ToolResultContent,
     ToolUseContent,
@@ -26,92 +24,132 @@ logger = logging.getLogger(__name__)
 class OpenAIProvider(BaseLLMProvider):
     """OpenAI LLM provider with streaming and tool support."""

+    # Models that use max_completion_tokens instead of max_tokens
+    # These are reasoning models with different parameter requirements
+    NON_REASONING_MODELS = {"gpt-4o"}
+
+    def __init__(self, api_key: str, model: str):
+        """
+        Initialize the OpenAI provider.
+
+        Args:
+            api_key: OpenAI API key
+            model: Model identifier
+        """
+        super().__init__(api_key, model)
+        self._async_client: openai.AsyncOpenAI | None = None
+
+    def _is_reasoning_model(self) -> bool:
+        """
+        Check if the current model is a reasoning model (o1 series).
+
+        Reasoning models have different parameter requirements:
+        - Use max_completion_tokens instead of max_tokens
+        - Don't support temperature (always uses temperature=1)
+        - Don't support top_p
+        - Don't support system messages via system parameter
+        """
+        return self.model.lower() not in self.NON_REASONING_MODELS
+
     def _initialize_client(self) -> openai.OpenAI:
         """Initialize the OpenAI client."""
         return openai.OpenAI(api_key=self.api_key)

-    def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
-        """
-        Convert our Message format to OpenAI format.
-
-        Args:
-            messages: List of messages in our format
-
-        Returns:
-            List of messages in OpenAI format
-        """
-        openai_messages = []
-        for msg in messages:
-            if isinstance(msg.content, str):
-                openai_messages.append({"role": msg.role.value, "content": msg.content})
-            else:
-                # Handle multi-part content
-                content_parts = []
-                for item in msg.content:
-                    if isinstance(item, TextContent):
-                        content_parts.append({"type": "text", "text": item.text})
-                    elif isinstance(item, ImageContent):
-                        encoded_image = self.encode_image(item.image)
-                        image_part: dict[str, Any] = {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/jpeg;base64,{encoded_image}"
-                            },
-                        }
-                        if item.detail:
-                            image_part["image_url"]["detail"] = item.detail
-                        content_parts.append(image_part)
-                    elif isinstance(item, ToolUseContent):
-                        # OpenAI doesn't have tool_use in content, it's a separate field
-                        # We'll handle this by adding a tool_calls field to the message
-                        pass
-                    elif isinstance(item, ToolResultContent):
-                        # OpenAI handles tool results as separate "tool" role messages
-                        openai_messages.append(
-                            {
-                                "role": "tool",
-                                "tool_call_id": item.tool_use_id,
-                                "content": item.content,
-                            }
-                        )
-                        continue
-                    elif isinstance(item, ThinkingContent):
-                        # OpenAI doesn't have native thinking support in most models
-                        # We can add it as text with a special marker
-                        content_parts.append(
-                            {
-                                "type": "text",
-                                "text": f"[Thinking: {item.thinking}]",
-                            }
-                        )
-
-                # Check if this message has tool calls
-                tool_calls = [
-                    item for item in msg.content if isinstance(item, ToolUseContent)
-                ]
-
-                message_dict: dict[str, Any] = {"role": msg.role.value}
-                if content_parts:
-                    message_dict["content"] = content_parts
-
-                if tool_calls:
-                    message_dict["tool_calls"] = [
-                        {
-                            "id": tc.id,
-                            "type": "function",
-                            "function": {"name": tc.name, "arguments": str(tc.input)},
-                        }
-                        for tc in tool_calls
-                    ]
-
-                openai_messages.append(message_dict)
-
-        return openai_messages
+    @property
+    def async_client(self) -> openai.AsyncOpenAI:
+        """Lazy-load the async client."""
+        if self._async_client is None:
+            self._async_client = openai.AsyncOpenAI(api_key=self.api_key)
+        return self._async_client
+
+    def _convert_text_content(self, content: TextContent) -> dict[str, Any]:
+        """Convert TextContent to OpenAI format."""
+        return {"type": "text", "text": content.text}
+
+    def _convert_image_content(self, content: ImageContent) -> dict[str, Any]:
+        """Convert ImageContent to OpenAI image_url format."""
+        encoded_image = self.encode_image(content.image)
+        image_part: dict[str, Any] = {
+            "type": "image_url",
+            "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
+        }
+        if content.detail:
+            image_part["image_url"]["detail"] = content.detail
+        return image_part
+
+    def _convert_tool_use_content(self, content: ToolUseContent) -> dict[str, Any]:
+        """Convert ToolUseContent to provider format. Override for custom format."""
+        return {
+            "id": content.id,
+            "type": "function",
+            "function": {
+                "name": content.name,
+                "arguments": json.dumps(content.input),
+            },
+        }
+
+    def _convert_tool_result_content(
+        self, content: ToolResultContent
+    ) -> dict[str, Any]:
+        """Convert ToolResultContent to provider format. Override for custom format."""
+        return {
+            "role": "tool",
+            "tool_call_id": content.tool_use_id,
+            "content": content.content,
+        }
+
+    def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
+        """
+        Convert messages to OpenAI format.
+
+        OpenAI has special requirements:
+        - ToolResultContent creates separate "tool" role messages
+        - ToolUseContent becomes tool_calls field on assistant messages
+        - One input Message can produce multiple output messages
+
+        Returns:
+            Flat list of OpenAI-formatted message dicts
+        """
+        openai_messages: list[dict[str, Any]] = []
+
+        for message in messages:
+            # Handle simple string content
+            if isinstance(message.content, str):
+                openai_messages.append(
+                    {"role": message.role.value, "content": message.content}
+                )
+                continue
+
+            # Handle multi-part content
+            content_parts: list[dict[str, Any]] = []
+            tool_calls_list: list[dict[str, Any]] = []
+
+            for item in message.content:
+                if isinstance(item, ToolResultContent):
+                    openai_messages.append(self._convert_tool_result_content(item))
+                elif isinstance(item, ToolUseContent):
+                    tool_calls_list.append(self._convert_tool_use_content(item))
+                else:
+                    content_parts.append(self._convert_message_content(item))
+
+            if content_parts or tool_calls_list:
+                msg_dict: dict[str, Any] = {"role": message.role.value}
+                if content_parts:
+                    msg_dict["content"] = content_parts
+                elif tool_calls_list:
+                    # Assistant messages with tool calls need content field (use empty string)
+                    msg_dict["content"] = ""
+
+                if tool_calls_list:
+                    msg_dict["tool_calls"] = tool_calls_list
+
+                openai_messages.append(msg_dict)
+
+        return openai_messages

     def _convert_tools(
-        self, tools: Optional[list[ToolDefinition]]
+        self, tools: list[ToolDefinition] | None
     ) -> Optional[list[dict[str, Any]]]:
         """
         Convert our tool definitions to OpenAI format.
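A sketch of the flattening behaviour (illustrative values; dataclass field defaults such as is_error are assumed): an assistant message that mixes text with a tool call becomes one dict carrying a tool_calls field, while a tool result becomes its own "tool" role message.

out = provider._convert_messages([
    Message(
        role=MessageRole.ASSISTANT,
        content=[
            TextContent(text="Checking."),
            ToolUseContent(id="call_1", name="search", input={"q": "weather"}),
        ],
    ),
    Message(
        role=MessageRole.USER,
        content=[ToolResultContent(tool_use_id="call_1", content="Sunny, 22C")],
    ),
])
# out[0]: {"role": "assistant", "content": [{"type": "text", ...}], "tool_calls": [...]}
# out[1]: {"role": "tool", "tool_call_id": "call_1", "content": "Sunny, 22C"}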
@@ -137,6 +175,153 @@ class OpenAIProvider(BaseLLMProvider):
             for tool in tools
         ]

+    def _build_request_kwargs(
+        self,
+        messages: list[Message],
+        system_prompt: str | None,
+        tools: Optional[list[ToolDefinition]],
+        settings: LLMSettings,
+        stream: bool = False,
+    ) -> dict[str, Any]:
+        """
+        Build common request kwargs for API calls.
+
+        Args:
+            messages: Conversation history
+            system_prompt: Optional system prompt
+            tools: Optional list of tools
+            settings: LLM settings
+            stream: Whether to enable streaming
+
+        Returns:
+            Dictionary of kwargs for OpenAI API call
+        """
+        openai_messages = self._convert_messages(messages)
+        is_reasoning = self._is_reasoning_model()
+
+        # Log info for reasoning models on first use
+        if is_reasoning:
+            logger.debug(
+                f"Using reasoning model {self.model}: "
+                "max_completion_tokens will be used, temperature/top_p ignored"
+            )
+
+        # Reasoning models (o1) don't support system parameter
+        # System message must be added as a developer message instead
+        if system_prompt:
+            if is_reasoning:
+                # For o1 models, add system prompt as a developer message
+                openai_messages.insert(
+                    0, {"role": "developer", "content": system_prompt}
+                )
+            else:
+                # For other models, add as system message
+                openai_messages.insert(0, {"role": "system", "content": system_prompt})
+
+        # Reasoning models use max_completion_tokens instead of max_tokens
+        max_tokens_key = "max_completion_tokens" if is_reasoning else "max_tokens"
+
+        kwargs: dict[str, Any] = {
+            "model": self.model,
+            "messages": openai_messages,
+            max_tokens_key: settings.max_tokens,
+        }
+
+        # Reasoning models don't support temperature or top_p
+        if not is_reasoning:
+            kwargs["temperature"] = settings.temperature
+            kwargs["top_p"] = settings.top_p
+
+        if stream:
+            kwargs["stream"] = True
+
+        if settings.stop_sequences:
+            kwargs["stop"] = settings.stop_sequences
+
+        if tools:
+            kwargs["tools"] = self._convert_tools(tools)
+            kwargs["tool_choice"] = "auto"
+
+        return kwargs
+
+    def _parse_and_finalize_tool_call(
+        self, tool_call: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Parse the accumulated tool call arguments and prepare for yielding.
+
+        Args:
+            tool_call: Tool call dict with 'arguments' field (JSON string)
+
+        Returns:
+            Tool call dict with parsed 'input' field (dict)
+        """
+        try:
+            tool_call["input"] = json.loads(tool_call["arguments"])
+        except json.JSONDecodeError as e:
+            logger.warning(
+                f"Failed to parse tool arguments '{tool_call['arguments']}': {e}"
+            )
+            tool_call["input"] = {}
+        del tool_call["arguments"]
+        return tool_call
+
+    def _handle_stream_chunk(
+        self,
+        chunk: Any,
+        current_tool_call: dict[str, Any] | None,
+    ) -> tuple[list[StreamEvent], Optional[dict[str, Any]]]:
+        """
+        Handle a single streaming chunk and return events and updated tool state.
+
+        Args:
+            chunk: Streaming chunk from OpenAI
+            current_tool_call: Current tool call being accumulated (or None)
+
+        Returns:
+            Tuple of (list of StreamEvents to yield, updated current_tool_call)
+        """
+        events: list[StreamEvent] = []
+
+        if not chunk.choices:
+            return events, current_tool_call
+
+        delta = chunk.choices[0].delta
+
+        # Handle text content
+        if delta.content:
+            events.append(StreamEvent(type="text", data=delta.content))
+
+        # Handle tool calls
+        if delta.tool_calls:
+            for tool_call in delta.tool_calls:
+                if tool_call.id:
+                    # New tool call starting
+                    if current_tool_call:
+                        # Yield the previous one with parsed input
+                        finalized = self._parse_and_finalize_tool_call(
+                            current_tool_call
+                        )
+                        events.append(StreamEvent(type="tool_use", data=finalized))
+
+                    current_tool_call = {
+                        "id": tool_call.id,
+                        "name": tool_call.function.name or "",
+                        "arguments": tool_call.function.arguments or "",
+                    }
+                elif current_tool_call and tool_call.function.arguments:
+                    # Continue building the current tool call
+                    current_tool_call["arguments"] += tool_call.function.arguments
+
+        # Check if stream is finished
+        if chunk.choices[0].finish_reason:
+            if current_tool_call:
+                # Parse the final tool call arguments
+                finalized = self._parse_and_finalize_tool_call(current_tool_call)
+                events.append(StreamEvent(type="tool_use", data=finalized))
+                current_tool_call = None
+
+        return events, current_tool_call
+
     def generate(
         self,
         messages: list[Message],
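Roughly what the kwargs end up looking like under the reasoning/non-reasoning split above (illustrative values only):

# Non-reasoning model (listed in NON_REASONING_MODELS, e.g. "gpt-4o"):
#   {"model": "gpt-4o", "messages": [{"role": "system", ...}, ...],
#    "max_tokens": ..., "temperature": ..., "top_p": ...}
# Any other model (treated as reasoning):
#   {"model": ..., "messages": [{"role": "developer", ...}, ...],
#    "max_completion_tokens": ...}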
@@ -146,30 +331,10 @@
     ) -> str:
         """Generate a non-streaming response."""
         settings = settings or LLMSettings()
-
-        openai_messages = self._convert_messages(messages)
-
-        # Add system prompt as first message if provided
-        if system_prompt:
-            openai_messages.insert(
-                0, {"role": "system", "content": system_prompt}
-            )
-
-        kwargs: dict[str, Any] = {
-            "model": self.model,
-            "messages": openai_messages,
-            "temperature": settings.temperature,
-            "max_tokens": settings.max_tokens,
-            "top_p": settings.top_p,
-        }
-
-        if settings.stop_sequences:
-            kwargs["stop"] = settings.stop_sequences
-
-        if tools:
-            kwargs["tools"] = self._convert_tools(tools)
-            kwargs["tool_choice"] = "auto"
+        kwargs = self._build_request_kwargs(
+            messages, system_prompt, tools, settings, stream=False
+        )

         try:
             response = self.client.chat.completions.create(**kwargs)
             return response.choices[0].message.content or ""
@@ -180,78 +345,25 @@
     def stream(
         self,
         messages: list[Message],
-        system_prompt: Optional[str] = None,
-        tools: Optional[list[ToolDefinition]] = None,
-        settings: Optional[LLMSettings] = None,
+        system_prompt: str | None = None,
+        tools: list[ToolDefinition] | None = None,
+        settings: LLMSettings | None = None,
     ) -> Iterator[StreamEvent]:
         """Generate a streaming response."""
         settings = settings or LLMSettings()
-
-        openai_messages = self._convert_messages(messages)
-
-        # Add system prompt as first message if provided
-        if system_prompt:
-            openai_messages.insert(
-                0, {"role": "system", "content": system_prompt}
-            )
-
-        kwargs: dict[str, Any] = {
-            "model": self.model,
-            "messages": openai_messages,
-            "temperature": settings.temperature,
-            "max_tokens": settings.max_tokens,
-            "top_p": settings.top_p,
-            "stream": True,
-        }
-
-        if settings.stop_sequences:
-            kwargs["stop"] = settings.stop_sequences
-
-        if tools:
-            kwargs["tools"] = self._convert_tools(tools)
-            kwargs["tool_choice"] = "auto"
+        kwargs = self._build_request_kwargs(
+            messages, system_prompt, tools, settings, stream=True
+        )

         try:
             stream = self.client.chat.completions.create(**kwargs)
-
-            current_tool_call: Optional[dict[str, Any]] = None
+            current_tool_call: dict[str, Any] | None = None

             for chunk in stream:
-                if not chunk.choices:
-                    continue
-
-                delta = chunk.choices[0].delta
-
-                # Handle text content
-                if delta.content:
-                    yield StreamEvent(type="text", data=delta.content)
-
-                # Handle tool calls
-                if delta.tool_calls:
-                    for tool_call in delta.tool_calls:
-                        if tool_call.id:
-                            # New tool call starting
-                            if current_tool_call:
-                                # Yield the previous one
-                                yield StreamEvent(
-                                    type="tool_use", data=current_tool_call
-                                )
-                            current_tool_call = {
-                                "id": tool_call.id,
-                                "name": tool_call.function.name or "",
-                                "arguments": tool_call.function.arguments or "",
-                            }
-                        elif current_tool_call and tool_call.function.arguments:
-                            # Continue building the current tool call
-                            current_tool_call["arguments"] += (
-                                tool_call.function.arguments
-                            )
-
-                # Check if stream is finished
-                if chunk.choices[0].finish_reason:
-                    if current_tool_call:
-                        yield StreamEvent(type="tool_use", data=current_tool_call)
-                        current_tool_call = None
+                events, current_tool_call = self._handle_stream_chunk(
+                    chunk, current_tool_call
+                )
+                yield from events

             yield StreamEvent(type="done")
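Sketch of how streamed tool-call fragments are stitched back together by the helpers above (assumed fragment values):

call = {"id": "call_1", "name": "search", "arguments": ""}
for fragment in ('{"q": ', '"weather"}'):  # arrives across several chunks
    call["arguments"] += fragment
call = provider._parse_and_finalize_tool_call(call)
# call == {"id": "call_1", "name": "search", "input": {"q": "weather"}}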
@@ -268,35 +380,12 @@
     ) -> str:
         """Generate a non-streaming response asynchronously."""
         settings = settings or LLMSettings()
-
-        # Use async client
-        async_client = openai.AsyncOpenAI(api_key=self.api_key)
-
-        openai_messages = self._convert_messages(messages)
-
-        # Add system prompt as first message if provided
-        if system_prompt:
-            openai_messages.insert(
-                0, {"role": "system", "content": system_prompt}
-            )
-
-        kwargs: dict[str, Any] = {
-            "model": self.model,
-            "messages": openai_messages,
-            "temperature": settings.temperature,
-            "max_tokens": settings.max_tokens,
-            "top_p": settings.top_p,
-        }
-
-        if settings.stop_sequences:
-            kwargs["stop"] = settings.stop_sequences
-
-        if tools:
-            kwargs["tools"] = self._convert_tools(tools)
-            kwargs["tool_choice"] = "auto"
+        kwargs = self._build_request_kwargs(
+            messages, system_prompt, tools, settings, stream=False
+        )

         try:
-            response = await async_client.chat.completions.create(**kwargs)
+            response = await self.async_client.chat.completions.create(**kwargs)
             return response.choices[0].message.content or ""
         except Exception as e:
             logger.error(f"OpenAI API error: {e}")
@@ -311,75 +400,20 @@
     ) -> AsyncIterator[StreamEvent]:
         """Generate a streaming response asynchronously."""
         settings = settings or LLMSettings()
-
-        # Use async client
-        async_client = openai.AsyncOpenAI(api_key=self.api_key)
-
-        openai_messages = self._convert_messages(messages)
-
-        # Add system prompt as first message if provided
-        if system_prompt:
-            openai_messages.insert(
-                0, {"role": "system", "content": system_prompt}
-            )
-
-        kwargs: dict[str, Any] = {
-            "model": self.model,
-            "messages": openai_messages,
-            "temperature": settings.temperature,
-            "max_tokens": settings.max_tokens,
-            "top_p": settings.top_p,
-            "stream": True,
-        }
-
-        if settings.stop_sequences:
-            kwargs["stop"] = settings.stop_sequences
-
-        if tools:
-            kwargs["tools"] = self._convert_tools(tools)
-            kwargs["tool_choice"] = "auto"
+        kwargs = self._build_request_kwargs(
+            messages, system_prompt, tools, settings, stream=True
+        )

         try:
-            stream = await async_client.chat.completions.create(**kwargs)
+            stream = await self.async_client.chat.completions.create(**kwargs)

             current_tool_call: Optional[dict[str, Any]] = None

             async for chunk in stream:
-                if not chunk.choices:
-                    continue
-
-                delta = chunk.choices[0].delta
-
-                # Handle text content
-                if delta.content:
-                    yield StreamEvent(type="text", data=delta.content)
-
-                # Handle tool calls
-                if delta.tool_calls:
-                    for tool_call in delta.tool_calls:
-                        if tool_call.id:
-                            # New tool call starting
-                            if current_tool_call:
-                                # Yield the previous one
-                                yield StreamEvent(
-                                    type="tool_use", data=current_tool_call
-                                )
-                            current_tool_call = {
-                                "id": tool_call.id,
-                                "name": tool_call.function.name or "",
-                                "arguments": tool_call.function.arguments or "",
-                            }
-                        elif current_tool_call and tool_call.function.arguments:
-                            # Continue building the current tool call
-                            current_tool_call["arguments"] += (
-                                tool_call.function.arguments
-                            )
-
-                # Check if stream is finished
-                if chunk.choices[0].finish_reason:
-                    if current_tool_call:
-                        yield StreamEvent(type="tool_use", data=current_tool_call)
-                        current_tool_call = None
+                events, current_tool_call = self._handle_stream_chunk(
+                    chunk, current_tool_call
+                )
+                for event in events:
+                    yield event

             yield StreamEvent(type="done")