From e68671deb4a80a411c709007325b3dbeba94ccff Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Mon, 13 Oct 2025 11:59:23 +0200 Subject: [PATCH] handle openai --- requirements/requirements-common.txt | 2 +- src/memory/common/llms/__init__.py | 4 + src/memory/common/llms/anthropic_provider.py | 61 +-- src/memory/common/llms/base.py | 153 +++++- src/memory/common/llms/openai_provider.py | 528 ++++++++++--------- 5 files changed, 417 insertions(+), 331 deletions(-) diff --git a/requirements/requirements-common.txt b/requirements/requirements-common.txt index f88aec2..a00f38e 100644 --- a/requirements/requirements-common.txt +++ b/requirements/requirements-common.txt @@ -6,7 +6,7 @@ dotenv==0.9.9 voyageai==0.3.2 qdrant-client==1.9.0 anthropic==0.69.0 -openai==1.25.0 +openai==2.3.0 # Pin the httpx version, as newer versions break the anthropic client httpx==0.27.0 celery[sqs]==5.3.6 diff --git a/src/memory/common/llms/__init__.py b/src/memory/common/llms/__init__.py index 68b6bd9..b2d37cc 100644 --- a/src/memory/common/llms/__init__.py +++ b/src/memory/common/llms/__init__.py @@ -22,10 +22,14 @@ from memory.common.llms.base import ( ToolUseContent, create_provider, ) +from memory.common.llms.anthropic_provider import AnthropicProvider +from memory.common.llms.openai_provider import OpenAIProvider from memory.common import tokens __all__ = [ "BaseLLMProvider", + "AnthropicProvider", + "OpenAIProvider", "Message", "MessageRole", "MessageContent", diff --git a/src/memory/common/llms/anthropic_provider.py b/src/memory/common/llms/anthropic_provider.py index a46236f..1969bfb 100644 --- a/src/memory/common/llms/anthropic_provider.py +++ b/src/memory/common/llms/anthropic_provider.py @@ -2,7 +2,7 @@ import json import logging -from typing import Any, AsyncIterator, Iterator, Optional +from typing import Any, AsyncIterator, Iterator import anthropic @@ -14,9 +14,6 @@ from memory.common.llms.base import ( MessageRole, StreamEvent, ToolDefinition, - ToolUseContent, - ThinkingContent, - TextContent, ) logger = logging.getLogger(__name__) @@ -45,7 +42,7 @@ class AnthropicProvider(BaseLLMProvider): """ super().__init__(api_key, model) self.enable_thinking = enable_thinking - self._async_client: Optional[anthropic.AsyncAnthropic] = None + self._async_client: anthropic.AsyncAnthropic | None = None def _initialize_client(self) -> anthropic.Anthropic: """Initialize the Anthropic client.""" @@ -336,6 +333,7 @@ class AnthropicProvider(BaseLLMProvider): settings = settings or LLMSettings() kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings) + print(kwargs) try: with self.client.messages.stream(**kwargs) as stream: current_tool_use: dict[str, Any] | None = None @@ -396,56 +394,3 @@ class AnthropicProvider(BaseLLMProvider): except Exception as e: logger.error(f"Anthropic streaming error: {e}") yield StreamEvent(type="error", data=str(e)) - - def stream_with_tools( - self, - messages: list[Message], - tools: dict[str, ToolDefinition], - settings: LLMSettings | None = None, - system_prompt: str | None = None, - max_iterations: int = 10, - ) -> Iterator[StreamEvent]: - if max_iterations <= 0: - return - - response = TextContent(text="") - thinking = ThinkingContent(thinking="", signature="") - - for event in self.stream( - messages=messages, - system_prompt=system_prompt, - tools=list(tools.values()), - settings=settings, - ): - if event.type == "text": - response.text += event.data - yield event - elif event.type == "thinking" and event.signature: - thinking.signature = event.signature - 
elif event.type == "thinking": - thinking.thinking += event.data - yield event - elif event.type == "tool_use": - yield event - tool_result = self.execute_tool(event.data, tools) - yield StreamEvent(type="tool_result", data=tool_result.to_dict()) - messages.append( - Message.assistant( - response, - thinking, - ToolUseContent( - id=event.data["id"], - name=event.data["name"], - input=event.data["input"], - ), - ) - ) - messages.append(Message.user(tool_result=tool_result)) - yield from self.stream_with_tools( - messages, tools, settings, system_prompt, max_iterations - 1 - ) - elif event.type == "tool_result": - yield event - elif event.type == "error": - logger.error(f"LLM error: {event.data}") - raise RuntimeError(f"LLM error: {event.data}") diff --git a/src/memory/common/llms/base.py b/src/memory/common/llms/base.py index aceb7f8..113e71f 100644 --- a/src/memory/common/llms/base.py +++ b/src/memory/common/llms/base.py @@ -6,7 +6,7 @@ import logging from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from typing import Any, AsyncIterator, Iterator, Literal, Optional, Union +from typing import Any, AsyncIterator, Iterator, Literal, Union, cast from PIL import Image @@ -36,6 +36,10 @@ class TextContent: """Convert to dictionary format.""" return {"type": "text", "text": self.text} + @property + def valid(self): + return self.text + @dataclass class ImageContent: @@ -43,13 +47,17 @@ class ImageContent: type: Literal["image"] = "image" image: Image.Image = None # type: ignore - detail: Optional[str] = None # For OpenAI: "low", "high", "auto" + detail: str | None = None # For OpenAI: "low", "high", "auto" def to_dict(self) -> dict[str, Any]: """Convert to dictionary format.""" # Note: Image will be encoded by provider-specific implementation return {"type": "image", "image": self.image} + @property + def valid(self): + return self.image + @dataclass class ToolUseContent: @@ -69,6 +77,10 @@ class ToolUseContent: "input": self.input, } + @property + def valid(self): + return self.id and self.name + @dataclass class ToolResultContent: @@ -88,6 +100,10 @@ class ToolResultContent: "is_error": self.is_error, } + @property + def valid(self): + return self.tool_use_id + @dataclass class ThinkingContent: @@ -105,6 +121,10 @@ class ThinkingContent: "signature": self.signature, } + @property + def valid(self): + return self.thinking and self.signature + MessageContent = Union[ TextContent, ImageContent, ToolUseContent, ToolResultContent, ThinkingContent @@ -135,18 +155,8 @@ class Message: return {"role": self.role.value, "content": content_list} @staticmethod - def assistant( - text: TextContent | None = None, - thinking: ThinkingContent | None = None, - tool_use: ToolUseContent | None = None, - ) -> "Message": - parts = [] - if text: - parts.append(text) - if thinking: - parts.append(thinking) - if tool_use: - parts.append(tool_use) + def assistant(*content: MessageContent) -> "Message": + parts = [c for c in content if c.valid] return Message(role=MessageRole.ASSISTANT, content=parts) @staticmethod @@ -295,13 +305,22 @@ class BaseLLMProvider(ABC): """Convert ThinkingContent to provider format. Override for custom format.""" return content.to_dict() - def _convert_message_content(self, content: MessageContent) -> dict[str, Any]: + def _convert_message_content( + self, content: str | MessageContent | list[MessageContent] + ) -> dict[str, Any] | list[dict[str, Any]]: """ Convert a MessageContent item to provider format. 
Dispatches to type-specific converters that can be overridden. """ - if isinstance(content, TextContent): + if isinstance(content, str): + return self._convert_text_content(TextContent(text=content)) + elif isinstance(content, list): + return [ + cast(dict[str, Any], self._convert_message_content(item)) + for item in content + ] + elif isinstance(content, TextContent): return self._convert_text_content(content) elif isinstance(content, ImageContent): return self._convert_image_content(content) @@ -318,9 +337,16 @@ class BaseLLMProvider(ABC): """ Convert a Message to provider format. - Can be overridden for provider-specific handling (e.g., filtering system messages). + Handles both string content and list[MessageContent], using provider-specific + content converters for each content item. + + Can be overridden for provider-specific handling (e.g., OpenAI's tool results). """ - return message.to_dict() + # Handle simple string content + return { + "role": message.role.value, + "content": self._convert_message_content(message.content), + } def _should_include_message(self, message: Message) -> bool: """ @@ -362,7 +388,7 @@ class BaseLLMProvider(ABC): def _convert_tools( self, tools: list[ToolDefinition] | None - ) -> Optional[list[dict[str, Any]]]: + ) -> list[dict[str, Any]] | None: """Convert tool definitions to provider format.""" if not tools: return None @@ -456,7 +482,6 @@ class BaseLLMProvider(ABC): """ pass - @abstractmethod def stream_with_tools( self, messages: list[Message], @@ -465,7 +490,75 @@ class BaseLLMProvider(ABC): system_prompt: str | None = None, max_iterations: int = 10, ) -> Iterator[StreamEvent]: - pass + """ + Stream response with automatic tool execution. + + This method handles the tool call loop automatically, executing tools + and sending results back to the LLM until it produces a final response + or max_iterations is reached. 
+ + Args: + messages: Conversation history + tools: Dict mapping tool names to ToolDefinition handlers + settings: Optional settings for the generation + system_prompt: Optional system prompt + max_iterations: Maximum number of tool call iterations + + Yields: + StreamEvent objects for text, tool calls, and tool results + """ + if max_iterations <= 0: + return + + response = TextContent(text="") + thinking = ThinkingContent(thinking="") + + for event in self.stream( + messages=messages, + system_prompt=system_prompt, + tools=list(tools.values()), + settings=settings, + ): + if event.type == "text": + response.text += event.data + yield event + elif event.type == "thinking": + thinking.thinking += event.data + yield event + elif event.type == "tool_use": + yield event + # Execute the tool and yield the result + tool_result = self.execute_tool(event.data, tools) + yield StreamEvent(type="tool_result", data=tool_result.to_dict()) + + # Add assistant message with tool call + messages.append( + Message.assistant( + response, + thinking, + ToolUseContent( + id=event.data["id"], + name=event.data["name"], + input=event.data["input"], + ), + ) + ) + + # Add user message with tool result + messages.append(Message.user(tool_result=tool_result)) + + # Recursively continue the conversation with reduced iterations + yield from self.stream_with_tools( + messages, tools, settings, system_prompt, max_iterations - 1 + ) + return # Exit after recursive call completes + + elif event.type == "error": + logger.error(f"LLM error: {event.data}") + raise RuntimeError(f"LLM error: {event.data}") + elif event.type == "done": + # Stream completed without tool calls + yield event def run_with_tools( self, @@ -550,12 +643,22 @@ def create_provider( api_key=api_key, model=model, enable_thinking=enable_thinking ) - # Could add OpenAI support here in the future - # elif "gpt" in model_lower or model.startswith("openai"): - # ... + elif provider == "openai": + # OpenAI models + if api_key is None: + api_key = settings.OPENAI_API_KEY + + if not api_key: + raise ValueError( + "OPENAI_API_KEY not found in settings. Please set it in your .env file." + ) + + from memory.common.llms.openai_provider import OpenAIProvider + + return OpenAIProvider(api_key=api_key, model=model) else: raise ValueError( f"Unknown provider for model: {model}. " - f"Supported providers: Anthropic (claude-*)" + f"Supported providers: Anthropic (anthropic/*), OpenAI (openai/*)" ) diff --git a/src/memory/common/llms/openai_provider.py b/src/memory/common/llms/openai_provider.py index 2f230fb..aa5beb3 100644 --- a/src/memory/common/llms/openai_provider.py +++ b/src/memory/common/llms/openai_provider.py @@ -1,7 +1,8 @@ """OpenAI LLM provider implementation.""" +import json import logging -from typing import Any, AsyncIterator, Iterator, Optional +from typing import Any, AsyncIterator, Iterator import openai @@ -10,11 +11,8 @@ from memory.common.llms.base import ( ImageContent, LLMSettings, Message, - MessageContent, - MessageRole, StreamEvent, TextContent, - ThinkingContent, ToolDefinition, ToolResultContent, ToolUseContent, @@ -26,92 +24,132 @@ logger = logging.getLogger(__name__) class OpenAIProvider(BaseLLMProvider): """OpenAI LLM provider with streaming and tool support.""" + # Models that use max_completion_tokens instead of max_tokens + # These are reasoning models with different parameter requirements + NON_REASONING_MODELS = {"gpt-4o"} + + def __init__(self, api_key: str, model: str): + """ + Initialize the OpenAI provider. 
+ + Args: + api_key: OpenAI API key + model: Model identifier + """ + super().__init__(api_key, model) + self._async_client: openai.AsyncOpenAI | None = None + + def _is_reasoning_model(self) -> bool: + """ + Check if the current model is a reasoning model (o1 series). + + Reasoning models have different parameter requirements: + - Use max_completion_tokens instead of max_tokens + - Don't support temperature (always uses temperature=1) + - Don't support top_p + - Don't support system messages via system parameter + """ + return self.model.lower() not in self.NON_REASONING_MODELS + def _initialize_client(self) -> openai.OpenAI: """Initialize the OpenAI client.""" return openai.OpenAI(api_key=self.api_key) + @property + def async_client(self) -> openai.AsyncOpenAI: + """Lazy-load the async client.""" + if self._async_client is None: + self._async_client = openai.AsyncOpenAI(api_key=self.api_key) + return self._async_client + + def _convert_text_content(self, content: TextContent) -> dict[str, Any]: + """Convert TextContent to OpenAI format.""" + return {"type": "text", "text": content.text} + + def _convert_image_content(self, content: ImageContent) -> dict[str, Any]: + """Convert ImageContent to OpenAI image_url format.""" + encoded_image = self.encode_image(content.image) + image_part: dict[str, Any] = { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}, + } + if content.detail: + image_part["image_url"]["detail"] = content.detail + return image_part + + def _convert_tool_use_content(self, content: ToolUseContent) -> dict[str, Any]: + """Convert ToolUseContent to provider format. Override for custom format.""" + return { + "id": content.id, + "type": "function", + "function": { + "name": content.name, + "arguments": json.dumps(content.input), + }, + } + + def _convert_tool_result_content( + self, content: ToolResultContent + ) -> dict[str, Any]: + """Convert ToolResultContent to provider format. Override for custom format.""" + return { + "role": "tool", + "tool_call_id": content.tool_use_id, + "content": content.content, + } + def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]: """ - Convert our Message format to OpenAI format. + Convert messages to OpenAI format. 
- Args: - messages: List of messages in our format + OpenAI has special requirements: + - ToolResultContent creates separate "tool" role messages + - ToolUseContent becomes tool_calls field on assistant messages + - One input Message can produce multiple output messages Returns: - List of messages in OpenAI format + Flat list of OpenAI-formatted message dicts """ - openai_messages = [] + openai_messages: list[dict[str, Any]] = [] - for msg in messages: - if isinstance(msg.content, str): - openai_messages.append({"role": msg.role.value, "content": msg.content}) - else: - # Handle multi-part content - content_parts = [] - for item in msg.content: - if isinstance(item, TextContent): - content_parts.append({"type": "text", "text": item.text}) - elif isinstance(item, ImageContent): - encoded_image = self.encode_image(item.image) - image_part: dict[str, Any] = { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{encoded_image}" - }, - } - if item.detail: - image_part["image_url"]["detail"] = item.detail - content_parts.append(image_part) - elif isinstance(item, ToolUseContent): - # OpenAI doesn't have tool_use in content, it's a separate field - # We'll handle this by adding a tool_calls field to the message - pass - elif isinstance(item, ToolResultContent): - # OpenAI handles tool results as separate "tool" role messages - openai_messages.append( - { - "role": "tool", - "tool_call_id": item.tool_use_id, - "content": item.content, - } - ) - continue - elif isinstance(item, ThinkingContent): - # OpenAI doesn't have native thinking support in most models - # We can add it as text with a special marker - content_parts.append( - { - "type": "text", - "text": f"[Thinking: {item.thinking}]", - } - ) + for message in messages: + # Handle simple string content + if isinstance(message.content, str): + openai_messages.append( + {"role": message.role.value, "content": message.content} + ) + continue - # Check if this message has tool calls - tool_calls = [ - item for item in msg.content if isinstance(item, ToolUseContent) - ] + # Handle multi-part content + content_parts: list[dict[str, Any]] = [] + tool_calls_list: list[dict[str, Any]] = [] - message_dict: dict[str, Any] = {"role": msg.role.value} + for item in message.content: + if isinstance(item, ToolResultContent): + openai_messages.append(self._convert_tool_result_content(item)) + elif isinstance(item, ToolUseContent): + tool_calls_list.append(self._convert_tool_use_content(item)) + else: + content_parts.append(self._convert_message_content(item)) + + if content_parts or tool_calls_list: + msg_dict: dict[str, Any] = {"role": message.role.value} if content_parts: - message_dict["content"] = content_parts + msg_dict["content"] = content_parts + elif tool_calls_list: + # Assistant messages with tool calls need content field (use empty string) + msg_dict["content"] = "" - if tool_calls: - message_dict["tool_calls"] = [ - { - "id": tc.id, - "type": "function", - "function": {"name": tc.name, "arguments": str(tc.input)}, - } - for tc in tool_calls - ] + if tool_calls_list: + msg_dict["tool_calls"] = tool_calls_list - openai_messages.append(message_dict) + openai_messages.append(msg_dict) return openai_messages def _convert_tools( - self, tools: Optional[list[ToolDefinition]] + self, tools: list[ToolDefinition] | None ) -> Optional[list[dict[str, Any]]]: """ Convert our tool definitions to OpenAI format. 
@@ -137,6 +175,153 @@ class OpenAIProvider(BaseLLMProvider): for tool in tools ] + def _build_request_kwargs( + self, + messages: list[Message], + system_prompt: str | None, + tools: Optional[list[ToolDefinition]], + settings: LLMSettings, + stream: bool = False, + ) -> dict[str, Any]: + """ + Build common request kwargs for API calls. + + Args: + messages: Conversation history + system_prompt: Optional system prompt + tools: Optional list of tools + settings: LLM settings + stream: Whether to enable streaming + + Returns: + Dictionary of kwargs for OpenAI API call + """ + openai_messages = self._convert_messages(messages) + is_reasoning = self._is_reasoning_model() + + # Log info for reasoning models on first use + if is_reasoning: + logger.debug( + f"Using reasoning model {self.model}: " + "max_completion_tokens will be used, temperature/top_p ignored" + ) + + # Reasoning models (o1) don't support system parameter + # System message must be added as a developer message instead + if system_prompt: + if is_reasoning: + # For o1 models, add system prompt as a developer message + openai_messages.insert( + 0, {"role": "developer", "content": system_prompt} + ) + else: + # For other models, add as system message + openai_messages.insert(0, {"role": "system", "content": system_prompt}) + + # Reasoning models use max_completion_tokens instead of max_tokens + max_tokens_key = "max_completion_tokens" if is_reasoning else "max_tokens" + + kwargs: dict[str, Any] = { + "model": self.model, + "messages": openai_messages, + max_tokens_key: settings.max_tokens, + } + + # Reasoning models don't support temperature or top_p + if not is_reasoning: + kwargs["temperature"] = settings.temperature + kwargs["top_p"] = settings.top_p + + if stream: + kwargs["stream"] = True + + if settings.stop_sequences: + kwargs["stop"] = settings.stop_sequences + + if tools: + kwargs["tools"] = self._convert_tools(tools) + kwargs["tool_choice"] = "auto" + + return kwargs + + def _parse_and_finalize_tool_call( + self, tool_call: dict[str, Any] + ) -> dict[str, Any]: + """ + Parse the accumulated tool call arguments and prepare for yielding. + + Args: + tool_call: Tool call dict with 'arguments' field (JSON string) + + Returns: + Tool call dict with parsed 'input' field (dict) + """ + try: + tool_call["input"] = json.loads(tool_call["arguments"]) + except json.JSONDecodeError as e: + logger.warning( + f"Failed to parse tool arguments '{tool_call['arguments']}': {e}" + ) + tool_call["input"] = {} + del tool_call["arguments"] + return tool_call + + def _handle_stream_chunk( + self, + chunk: Any, + current_tool_call: dict[str, Any] | None, + ) -> tuple[list[StreamEvent], Optional[dict[str, Any]]]: + """ + Handle a single streaming chunk and return events and updated tool state. 
+ + Args: + chunk: Streaming chunk from OpenAI + current_tool_call: Current tool call being accumulated (or None) + + Returns: + Tuple of (list of StreamEvents to yield, updated current_tool_call) + """ + events: list[StreamEvent] = [] + + if not chunk.choices: + return events, current_tool_call + + delta = chunk.choices[0].delta + + # Handle text content + if delta.content: + events.append(StreamEvent(type="text", data=delta.content)) + + # Handle tool calls + if delta.tool_calls: + for tool_call in delta.tool_calls: + if tool_call.id: + # New tool call starting + if current_tool_call: + # Yield the previous one with parsed input + finalized = self._parse_and_finalize_tool_call( + current_tool_call + ) + events.append(StreamEvent(type="tool_use", data=finalized)) + current_tool_call = { + "id": tool_call.id, + "name": tool_call.function.name or "", + "arguments": tool_call.function.arguments or "", + } + elif current_tool_call and tool_call.function.arguments: + # Continue building the current tool call + current_tool_call["arguments"] += tool_call.function.arguments + + # Check if stream is finished + if chunk.choices[0].finish_reason: + if current_tool_call: + # Parse the final tool call arguments + finalized = self._parse_and_finalize_tool_call(current_tool_call) + events.append(StreamEvent(type="tool_use", data=finalized)) + current_tool_call = None + + return events, current_tool_call + def generate( self, messages: list[Message], @@ -146,29 +331,9 @@ class OpenAIProvider(BaseLLMProvider): ) -> str: """Generate a non-streaming response.""" settings = settings or LLMSettings() - - openai_messages = self._convert_messages(messages) - - # Add system prompt as first message if provided - if system_prompt: - openai_messages.insert( - 0, {"role": "system", "content": system_prompt} - ) - - kwargs: dict[str, Any] = { - "model": self.model, - "messages": openai_messages, - "temperature": settings.temperature, - "max_tokens": settings.max_tokens, - "top_p": settings.top_p, - } - - if settings.stop_sequences: - kwargs["stop"] = settings.stop_sequences - - if tools: - kwargs["tools"] = self._convert_tools(tools) - kwargs["tool_choice"] = "auto" + kwargs = self._build_request_kwargs( + messages, system_prompt, tools, settings, stream=False + ) try: response = self.client.chat.completions.create(**kwargs) @@ -180,78 +345,25 @@ class OpenAIProvider(BaseLLMProvider): def stream( self, messages: list[Message], - system_prompt: Optional[str] = None, - tools: Optional[list[ToolDefinition]] = None, - settings: Optional[LLMSettings] = None, + system_prompt: str | None = None, + tools: list[ToolDefinition] | None = None, + settings: LLMSettings | None = None, ) -> Iterator[StreamEvent]: """Generate a streaming response.""" settings = settings or LLMSettings() - - openai_messages = self._convert_messages(messages) - - # Add system prompt as first message if provided - if system_prompt: - openai_messages.insert( - 0, {"role": "system", "content": system_prompt} - ) - - kwargs: dict[str, Any] = { - "model": self.model, - "messages": openai_messages, - "temperature": settings.temperature, - "max_tokens": settings.max_tokens, - "top_p": settings.top_p, - "stream": True, - } - - if settings.stop_sequences: - kwargs["stop"] = settings.stop_sequences - - if tools: - kwargs["tools"] = self._convert_tools(tools) - kwargs["tool_choice"] = "auto" + kwargs = self._build_request_kwargs( + messages, system_prompt, tools, settings, stream=True + ) try: stream = self.client.chat.completions.create(**kwargs) - - 
current_tool_call: Optional[dict[str, Any]] = None + current_tool_call: dict[str, Any] | None = None for chunk in stream: - if not chunk.choices: - continue - - delta = chunk.choices[0].delta - - # Handle text content - if delta.content: - yield StreamEvent(type="text", data=delta.content) - - # Handle tool calls - if delta.tool_calls: - for tool_call in delta.tool_calls: - if tool_call.id: - # New tool call starting - if current_tool_call: - # Yield the previous one - yield StreamEvent( - type="tool_use", data=current_tool_call - ) - current_tool_call = { - "id": tool_call.id, - "name": tool_call.function.name or "", - "arguments": tool_call.function.arguments or "", - } - elif current_tool_call and tool_call.function.arguments: - # Continue building the current tool call - current_tool_call["arguments"] += ( - tool_call.function.arguments - ) - - # Check if stream is finished - if chunk.choices[0].finish_reason: - if current_tool_call: - yield StreamEvent(type="tool_use", data=current_tool_call) - current_tool_call = None + events, current_tool_call = self._handle_stream_chunk( + chunk, current_tool_call + ) + yield from events yield StreamEvent(type="done") @@ -268,35 +380,12 @@ class OpenAIProvider(BaseLLMProvider): ) -> str: """Generate a non-streaming response asynchronously.""" settings = settings or LLMSettings() - - # Use async client - async_client = openai.AsyncOpenAI(api_key=self.api_key) - - openai_messages = self._convert_messages(messages) - - # Add system prompt as first message if provided - if system_prompt: - openai_messages.insert( - 0, {"role": "system", "content": system_prompt} - ) - - kwargs: dict[str, Any] = { - "model": self.model, - "messages": openai_messages, - "temperature": settings.temperature, - "max_tokens": settings.max_tokens, - "top_p": settings.top_p, - } - - if settings.stop_sequences: - kwargs["stop"] = settings.stop_sequences - - if tools: - kwargs["tools"] = self._convert_tools(tools) - kwargs["tool_choice"] = "auto" + kwargs = self._build_request_kwargs( + messages, system_prompt, tools, settings, stream=False + ) try: - response = await async_client.chat.completions.create(**kwargs) + response = await self.async_client.chat.completions.create(**kwargs) return response.choices[0].message.content or "" except Exception as e: logger.error(f"OpenAI API error: {e}") @@ -311,75 +400,20 @@ class OpenAIProvider(BaseLLMProvider): ) -> AsyncIterator[StreamEvent]: """Generate a streaming response asynchronously.""" settings = settings or LLMSettings() - - # Use async client - async_client = openai.AsyncOpenAI(api_key=self.api_key) - - openai_messages = self._convert_messages(messages) - - # Add system prompt as first message if provided - if system_prompt: - openai_messages.insert( - 0, {"role": "system", "content": system_prompt} - ) - - kwargs: dict[str, Any] = { - "model": self.model, - "messages": openai_messages, - "temperature": settings.temperature, - "max_tokens": settings.max_tokens, - "top_p": settings.top_p, - "stream": True, - } - - if settings.stop_sequences: - kwargs["stop"] = settings.stop_sequences - - if tools: - kwargs["tools"] = self._convert_tools(tools) - kwargs["tool_choice"] = "auto" + kwargs = self._build_request_kwargs( + messages, system_prompt, tools, settings, stream=True + ) try: - stream = await async_client.chat.completions.create(**kwargs) - + stream = await self.async_client.chat.completions.create(**kwargs) current_tool_call: Optional[dict[str, Any]] = None async for chunk in stream: - if not chunk.choices: - continue - - 
delta = chunk.choices[0].delta - - # Handle text content - if delta.content: - yield StreamEvent(type="text", data=delta.content) - - # Handle tool calls - if delta.tool_calls: - for tool_call in delta.tool_calls: - if tool_call.id: - # New tool call starting - if current_tool_call: - # Yield the previous one - yield StreamEvent( - type="tool_use", data=current_tool_call - ) - current_tool_call = { - "id": tool_call.id, - "name": tool_call.function.name or "", - "arguments": tool_call.function.arguments or "", - } - elif current_tool_call and tool_call.function.arguments: - # Continue building the current tool call - current_tool_call["arguments"] += ( - tool_call.function.arguments - ) - - # Check if stream is finished - if chunk.choices[0].finish_reason: - if current_tool_call: - yield StreamEvent(type="tool_use", data=current_tool_call) - current_tool_call = None + events, current_tool_call = self._handle_stream_chunk( + chunk, current_tool_call + ) + for event in events: + yield event yield StreamEvent(type="done")
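
Usage sketch for the new provider, driving it through the shared streaming interface from base.py. This is a minimal illustration, not part of the patch: it assumes an OPENAI_API_KEY environment variable and uses "gpt-4o", the one model name the diff itself references. create_provider("openai/...") should yield the same object, but the exact model-string prefix is only implied by the error message above, so the provider is constructed directly here.

    import os

    from memory.common.llms.base import LLMSettings, Message, MessageRole
    from memory.common.llms.openai_provider import OpenAIProvider

    provider = OpenAIProvider(api_key=os.environ["OPENAI_API_KEY"], model="gpt-4o")

    messages = [Message(role=MessageRole.USER, content="Summarise this diff in one sentence.")]

    # stream() yields StreamEvent objects; only text and error events are handled here.
    for event in provider.stream(
        messages,
        system_prompt="You are a terse assistant.",
        settings=LLMSettings(),
    ):
        if event.type == "text":
            print(event.data, end="", flush=True)
        elif event.type == "error":
            raise RuntimeError(event.data)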
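
The reasoning-model handling in _build_request_kwargs can be checked offline, since it only assembles request kwargs and never calls the API. The sketch below uses a dummy key and the private helper directly, so it is illustrative rather than a supported entry point; "o1-mini" simply stands in for any model outside NON_REASONING_MODELS.

    from memory.common.llms.base import LLMSettings, Message, MessageRole
    from memory.common.llms.openai_provider import OpenAIProvider

    messages = [Message(role=MessageRole.USER, content="hi")]
    settings = LLMSettings()

    # gpt-4o takes the non-reasoning path: max_tokens, system role, sampling params kept.
    chat = OpenAIProvider(api_key="sk-dummy", model="gpt-4o")
    kwargs = chat._build_request_kwargs(messages, "be brief", None, settings)
    assert "max_tokens" in kwargs
    assert kwargs["messages"][0]["role"] == "system"

    # Anything else takes the reasoning path: max_completion_tokens, developer role,
    # temperature/top_p omitted.
    reasoning = OpenAIProvider(api_key="sk-dummy", model="o1-mini")
    kwargs = reasoning._build_request_kwargs(messages, "be brief", None, settings)
    assert "max_completion_tokens" in kwargs
    assert kwargs["messages"][0]["role"] == "developer"
    assert "temperature" not in kwargs and "top_p" not in kwargs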
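
Likewise, the message flattening documented in _convert_messages (tool calls folded into a tool_calls field on the assistant message, tool results emitted as separate "tool" role messages) can be exercised without network access. A small sketch under the same dummy-key assumption:

    from memory.common.llms.base import Message, MessageRole, ToolResultContent, ToolUseContent
    from memory.common.llms.openai_provider import OpenAIProvider

    provider = OpenAIProvider(api_key="sk-dummy", model="gpt-4o")

    history = [
        Message(role=MessageRole.USER, content="What is 2 + 2?"),
        Message.assistant(ToolUseContent(id="call_1", name="calculator", input={"expression": "2 + 2"})),
        Message.user(tool_result=ToolResultContent(tool_use_id="call_1", content="4")),
    ]

    converted = provider._convert_messages(history)

    # The tool call rides on the assistant message; the result becomes its own message.
    assert converted[1]["tool_calls"][0]["function"]["name"] == "calculator"
    assert converted[2]["role"] == "tool" and converted[2]["tool_call_id"] == "call_1"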