Mirror of https://github.com/mruwnik/memory.git (synced 2025-10-22 22:56:38 +02:00)

Commit e68671deb4 ("handle openai"), parent 99d3843f47
@@ -6,7 +6,7 @@ dotenv==0.9.9
voyageai==0.3.2
qdrant-client==1.9.0
anthropic==0.69.0
openai==1.25.0
openai==2.3.0
# Pin the httpx version, as newer versions break the anthropic client
httpx==0.27.0
celery[sqs]==5.3.6
@@ -22,10 +22,14 @@ from memory.common.llms.base import (
    ToolUseContent,
    create_provider,
)
from memory.common.llms.anthropic_provider import AnthropicProvider
from memory.common.llms.openai_provider import OpenAIProvider
from memory.common import tokens

__all__ = [
    "BaseLLMProvider",
    "AnthropicProvider",
    "OpenAIProvider",
    "Message",
    "MessageRole",
    "MessageContent",
@@ -2,7 +2,7 @@

import json
import logging
from typing import Any, AsyncIterator, Iterator, Optional
from typing import Any, AsyncIterator, Iterator

import anthropic


@@ -14,9 +14,6 @@ from memory.common.llms.base import (
    MessageRole,
    StreamEvent,
    ToolDefinition,
    ToolUseContent,
    ThinkingContent,
    TextContent,
)

logger = logging.getLogger(__name__)

@@ -45,7 +42,7 @@ class AnthropicProvider(BaseLLMProvider):
        """
        super().__init__(api_key, model)
        self.enable_thinking = enable_thinking
        self._async_client: Optional[anthropic.AsyncAnthropic] = None
        self._async_client: anthropic.AsyncAnthropic | None = None

    def _initialize_client(self) -> anthropic.Anthropic:
        """Initialize the Anthropic client."""

@@ -336,6 +333,7 @@ class AnthropicProvider(BaseLLMProvider):
        settings = settings or LLMSettings()
        kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)

        print(kwargs)
        try:
            with self.client.messages.stream(**kwargs) as stream:
                current_tool_use: dict[str, Any] | None = None

@@ -396,56 +394,3 @@ class AnthropicProvider(BaseLLMProvider):
        except Exception as e:
            logger.error(f"Anthropic streaming error: {e}")
            yield StreamEvent(type="error", data=str(e))

    def stream_with_tools(
        self,
        messages: list[Message],
        tools: dict[str, ToolDefinition],
        settings: LLMSettings | None = None,
        system_prompt: str | None = None,
        max_iterations: int = 10,
    ) -> Iterator[StreamEvent]:
        if max_iterations <= 0:
            return

        response = TextContent(text="")
        thinking = ThinkingContent(thinking="", signature="")

        for event in self.stream(
            messages=messages,
            system_prompt=system_prompt,
            tools=list(tools.values()),
            settings=settings,
        ):
            if event.type == "text":
                response.text += event.data
                yield event
            elif event.type == "thinking" and event.signature:
                thinking.signature = event.signature
            elif event.type == "thinking":
                thinking.thinking += event.data
                yield event
            elif event.type == "tool_use":
                yield event
                tool_result = self.execute_tool(event.data, tools)
                yield StreamEvent(type="tool_result", data=tool_result.to_dict())
                messages.append(
                    Message.assistant(
                        response,
                        thinking,
                        ToolUseContent(
                            id=event.data["id"],
                            name=event.data["name"],
                            input=event.data["input"],
                        ),
                    )
                )
                messages.append(Message.user(tool_result=tool_result))
                yield from self.stream_with_tools(
                    messages, tools, settings, system_prompt, max_iterations - 1
                )
            elif event.type == "tool_result":
                yield event
            elif event.type == "error":
                logger.error(f"LLM error: {event.data}")
                raise RuntimeError(f"LLM error: {event.data}")
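In the removed Anthropic loop above, thinking deltas and the block signature arrive as separate "thinking" events; the handler folds both into one ThinkingContent so the follow-up assistant message can replay it. A minimal sketch of that folding (a hypothetical helper, using only the StreamEvent fields referenced above):

def collect_thinking(events):
    # Hypothetical helper: folds a sequence of StreamEvents into one ThinkingContent.
    thinking = ThinkingContent(thinking="", signature="")
    for event in events:
        if event.type == "thinking" and event.signature:
            # a signature event carries the block signature rather than text
            thinking.signature = event.signature
        elif event.type == "thinking":
            # plain thinking events carry incremental text
            thinking.thinking += event.data
    return thinking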
@@ -6,7 +6,7 @@ import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, AsyncIterator, Iterator, Literal, Optional, Union
from typing import Any, AsyncIterator, Iterator, Literal, Union, cast

from PIL import Image


@@ -36,6 +36,10 @@ class TextContent:
        """Convert to dictionary format."""
        return {"type": "text", "text": self.text}

    @property
    def valid(self):
        return self.text


@dataclass
class ImageContent:

@@ -43,13 +47,17 @@ class ImageContent:

    type: Literal["image"] = "image"
    image: Image.Image = None  # type: ignore
    detail: Optional[str] = None  # For OpenAI: "low", "high", "auto"
    detail: str | None = None  # For OpenAI: "low", "high", "auto"

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary format."""
        # Note: Image will be encoded by provider-specific implementation
        return {"type": "image", "image": self.image}

    @property
    def valid(self):
        return self.image


@dataclass
class ToolUseContent:

@@ -69,6 +77,10 @@ class ToolUseContent:
            "input": self.input,
        }

    @property
    def valid(self):
        return self.id and self.name


@dataclass
class ToolResultContent:

@@ -88,6 +100,10 @@ class ToolResultContent:
            "is_error": self.is_error,
        }

    @property
    def valid(self):
        return self.tool_use_id


@dataclass
class ThinkingContent:

@@ -105,6 +121,10 @@ class ThinkingContent:
            "signature": self.signature,
        }

    @property
    def valid(self):
        return self.thinking and self.signature


MessageContent = Union[
    TextContent, ImageContent, ToolUseContent, ToolResultContent, ThinkingContent

@@ -135,18 +155,8 @@ class Message:
        return {"role": self.role.value, "content": content_list}

    @staticmethod
    def assistant(
        text: TextContent | None = None,
        thinking: ThinkingContent | None = None,
        tool_use: ToolUseContent | None = None,
    ) -> "Message":
        parts = []
        if text:
            parts.append(text)
        if thinking:
            parts.append(thinking)
        if tool_use:
            parts.append(tool_use)
    def assistant(*content: MessageContent) -> "Message":
        parts = [c for c in content if c.valid]
        return Message(role=MessageRole.ASSISTANT, content=parts)

    @staticmethod
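The new `valid` properties let `Message.assistant` take any mix of content parts and drop the empty ones. A small sketch of the intended behaviour, using only the dataclasses defined above:

text = TextContent(text="Here is the answer")
thinking = ThinkingContent(thinking="", signature="")  # falsy `valid`: both fields empty
tool_use = ToolUseContent(id="call_1", name="search", input={"query": "cats"})

msg = Message.assistant(text, thinking, tool_use)
# msg.content == [text, tool_use] -- the empty ThinkingContent is filtered out.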
@@ -295,13 +305,22 @@ class BaseLLMProvider(ABC):
        """Convert ThinkingContent to provider format. Override for custom format."""
        return content.to_dict()

    def _convert_message_content(self, content: MessageContent) -> dict[str, Any]:
    def _convert_message_content(
        self, content: str | MessageContent | list[MessageContent]
    ) -> dict[str, Any] | list[dict[str, Any]]:
        """
        Convert a MessageContent item to provider format.

        Dispatches to type-specific converters that can be overridden.
        """
        if isinstance(content, TextContent):
        if isinstance(content, str):
            return self._convert_text_content(TextContent(text=content))
        elif isinstance(content, list):
            return [
                cast(dict[str, Any], self._convert_message_content(item))
                for item in content
            ]
        elif isinstance(content, TextContent):
            return self._convert_text_content(content)
        elif isinstance(content, ImageContent):
            return self._convert_image_content(content)

@@ -318,9 +337,16 @@ class BaseLLMProvider(ABC):
        """
        Convert a Message to provider format.

        Can be overridden for provider-specific handling (e.g., filtering system messages).
        Handles both string content and list[MessageContent], using provider-specific
        content converters for each content item.

        Can be overridden for provider-specific handling (e.g., OpenAI's tool results).
        """
        return message.to_dict()
        # Handle simple string content
        return {
            "role": message.role.value,
            "content": self._convert_message_content(message.content),
        }

    def _should_include_message(self, message: Message) -> bool:
        """

@@ -362,7 +388,7 @@ class BaseLLMProvider(ABC):

    def _convert_tools(
        self, tools: list[ToolDefinition] | None
    ) -> Optional[list[dict[str, Any]]]:
    ) -> list[dict[str, Any]] | None:
        """Convert tool definitions to provider format."""
        if not tools:
            return None

@@ -456,7 +482,6 @@ class BaseLLMProvider(ABC):
        """
        pass

    @abstractmethod
    def stream_with_tools(
        self,
        messages: list[Message],

@@ -465,7 +490,75 @@ class BaseLLMProvider(ABC):
        system_prompt: str | None = None,
        max_iterations: int = 10,
    ) -> Iterator[StreamEvent]:
        pass
        """
        Stream response with automatic tool execution.

        This method handles the tool call loop automatically, executing tools
        and sending results back to the LLM until it produces a final response
        or max_iterations is reached.

        Args:
            messages: Conversation history
            tools: Dict mapping tool names to ToolDefinition handlers
            settings: Optional settings for the generation
            system_prompt: Optional system prompt
            max_iterations: Maximum number of tool call iterations

        Yields:
            StreamEvent objects for text, tool calls, and tool results
        """
        if max_iterations <= 0:
            return

        response = TextContent(text="")
        thinking = ThinkingContent(thinking="")

        for event in self.stream(
            messages=messages,
            system_prompt=system_prompt,
            tools=list(tools.values()),
            settings=settings,
        ):
            if event.type == "text":
                response.text += event.data
                yield event
            elif event.type == "thinking":
                thinking.thinking += event.data
                yield event
            elif event.type == "tool_use":
                yield event
                # Execute the tool and yield the result
                tool_result = self.execute_tool(event.data, tools)
                yield StreamEvent(type="tool_result", data=tool_result.to_dict())

                # Add assistant message with tool call
                messages.append(
                    Message.assistant(
                        response,
                        thinking,
                        ToolUseContent(
                            id=event.data["id"],
                            name=event.data["name"],
                            input=event.data["input"],
                        ),
                    )
                )

                # Add user message with tool result
                messages.append(Message.user(tool_result=tool_result))

                # Recursively continue the conversation with reduced iterations
                yield from self.stream_with_tools(
                    messages, tools, settings, system_prompt, max_iterations - 1
                )
                return  # Exit after recursive call completes

            elif event.type == "error":
                logger.error(f"LLM error: {event.data}")
                raise RuntimeError(f"LLM error: {event.data}")
            elif event.type == "done":
                # Stream completed without tool calls
                yield event

    def run_with_tools(
        self,
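A hedged usage sketch for the shared tool loop above: the caller passes a dict of ToolDefinition handlers and consumes StreamEvents until the model stops calling tools. The ToolDefinition constructor and the exact MessageRole members are not shown in this diff, so they are assumed here:

def run_chat(provider: BaseLLMProvider, tools: dict, user_text: str) -> str:
    # Assumes MessageRole has a USER member and that plain-string content is accepted.
    messages = [Message(role=MessageRole.USER, content=user_text)]
    answer = ""
    for event in provider.stream_with_tools(messages, tools, max_iterations=5):
        if event.type == "text":
            answer += event.data  # incremental assistant text
        elif event.type == "tool_use":
            print("calling tool:", event.data["name"])
        elif event.type == "tool_result":
            print("tool result:", event.data)
    return answer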
@@ -550,12 +643,22 @@ def create_provider(
            api_key=api_key, model=model, enable_thinking=enable_thinking
        )

    # Could add OpenAI support here in the future
    # elif "gpt" in model_lower or model.startswith("openai"):
    # ...
    elif provider == "openai":
        # OpenAI models
        if api_key is None:
            api_key = settings.OPENAI_API_KEY

        if not api_key:
            raise ValueError(
                "OPENAI_API_KEY not found in settings. Please set it in your .env file."
            )

        from memory.common.llms.openai_provider import OpenAIProvider

        return OpenAIProvider(api_key=api_key, model=model)

    else:
        raise ValueError(
            f"Unknown provider for model: {model}. "
            f"Supported providers: Anthropic (claude-*)"
            f"Supported providers: Anthropic (anthropic/*), OpenAI (openai/*)"
        )
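A usage sketch for the factory. The full create_provider signature sits outside this hunk; the calls below assume a provider-prefixed model string, matching the `anthropic/*` and `openai/*` convention in the error message, and hypothetical model names:

# Hypothetical model identifiers; api_key falls back to settings.OPENAI_API_KEY
# (and presumably an Anthropic equivalent) when not passed explicitly.
claude = create_provider("anthropic/claude-sonnet-4-5", enable_thinking=True)
gpt = create_provider("openai/gpt-4o")

try:
    create_provider("mistral/mistral-large")  # unknown provider prefix
except ValueError as exc:
    print(exc)  # lists the supported providers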
@@ -1,7 +1,8 @@
"""OpenAI LLM provider implementation."""

import json
import logging
from typing import Any, AsyncIterator, Iterator, Optional
from typing import Any, AsyncIterator, Iterator

import openai


@@ -10,11 +11,8 @@ from memory.common.llms.base import (
    ImageContent,
    LLMSettings,
    Message,
    MessageContent,
    MessageRole,
    StreamEvent,
    TextContent,
    ThinkingContent,
    ToolDefinition,
    ToolResultContent,
    ToolUseContent,
@@ -26,92 +24,132 @@ logger = logging.getLogger(__name__)
class OpenAIProvider(BaseLLMProvider):
    """OpenAI LLM provider with streaming and tool support."""

    # Models that use max_completion_tokens instead of max_tokens
    # These are reasoning models with different parameter requirements
    NON_REASONING_MODELS = {"gpt-4o"}

    def __init__(self, api_key: str, model: str):
        """
        Initialize the OpenAI provider.

        Args:
            api_key: OpenAI API key
            model: Model identifier
        """
        super().__init__(api_key, model)
        self._async_client: openai.AsyncOpenAI | None = None

    def _is_reasoning_model(self) -> bool:
        """
        Check if the current model is a reasoning model (o1 series).

        Reasoning models have different parameter requirements:
        - Use max_completion_tokens instead of max_tokens
        - Don't support temperature (always uses temperature=1)
        - Don't support top_p
        - Don't support system messages via system parameter
        """
        return self.model.lower() not in self.NON_REASONING_MODELS

    def _initialize_client(self) -> openai.OpenAI:
        """Initialize the OpenAI client."""
        return openai.OpenAI(api_key=self.api_key)

    @property
    def async_client(self) -> openai.AsyncOpenAI:
        """Lazy-load the async client."""
        if self._async_client is None:
            self._async_client = openai.AsyncOpenAI(api_key=self.api_key)
        return self._async_client

    def _convert_text_content(self, content: TextContent) -> dict[str, Any]:
        """Convert TextContent to OpenAI format."""
        return {"type": "text", "text": content.text}

    def _convert_image_content(self, content: ImageContent) -> dict[str, Any]:
        """Convert ImageContent to OpenAI image_url format."""
        encoded_image = self.encode_image(content.image)
        image_part: dict[str, Any] = {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
        }
        if content.detail:
            image_part["image_url"]["detail"] = content.detail
        return image_part

    def _convert_tool_use_content(self, content: ToolUseContent) -> dict[str, Any]:
        """Convert ToolUseContent to provider format. Override for custom format."""
        return {
            "id": content.id,
            "type": "function",
            "function": {
                "name": content.name,
                "arguments": json.dumps(content.input),
            },
        }

    def _convert_tool_result_content(
        self, content: ToolResultContent
    ) -> dict[str, Any]:
        """Convert ToolResultContent to provider format. Override for custom format."""
        return {
            "role": "tool",
            "tool_call_id": content.tool_use_id,
            "content": content.content,
        }

    def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
        """
        Convert our Message format to OpenAI format.
        Convert messages to OpenAI format.

        Args:
            messages: List of messages in our format
        OpenAI has special requirements:
        - ToolResultContent creates separate "tool" role messages
        - ToolUseContent becomes tool_calls field on assistant messages
        - One input Message can produce multiple output messages

        Returns:
            List of messages in OpenAI format
            Flat list of OpenAI-formatted message dicts
        """
        openai_messages = []
        openai_messages: list[dict[str, Any]] = []

        for msg in messages:
            if isinstance(msg.content, str):
                openai_messages.append({"role": msg.role.value, "content": msg.content})
            else:
                # Handle multi-part content
                content_parts = []
                for item in msg.content:
                    if isinstance(item, TextContent):
                        content_parts.append({"type": "text", "text": item.text})
                    elif isinstance(item, ImageContent):
                        encoded_image = self.encode_image(item.image)
                        image_part: dict[str, Any] = {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{encoded_image}"
                            },
                        }
                        if item.detail:
                            image_part["image_url"]["detail"] = item.detail
                        content_parts.append(image_part)
                    elif isinstance(item, ToolUseContent):
                        # OpenAI doesn't have tool_use in content, it's a separate field
                        # We'll handle this by adding a tool_calls field to the message
                        pass
                    elif isinstance(item, ToolResultContent):
                        # OpenAI handles tool results as separate "tool" role messages
                        openai_messages.append(
                            {
                                "role": "tool",
                                "tool_call_id": item.tool_use_id,
                                "content": item.content,
                            }
                        )
                        continue
                    elif isinstance(item, ThinkingContent):
                        # OpenAI doesn't have native thinking support in most models
                        # We can add it as text with a special marker
                        content_parts.append(
                            {
                                "type": "text",
                                "text": f"[Thinking: {item.thinking}]",
                            }
                        )
        for message in messages:
            # Handle simple string content
            if isinstance(message.content, str):
                openai_messages.append(
                    {"role": message.role.value, "content": message.content}
                )
                continue

                # Check if this message has tool calls
                tool_calls = [
                    item for item in msg.content if isinstance(item, ToolUseContent)
                ]
            # Handle multi-part content
            content_parts: list[dict[str, Any]] = []
            tool_calls_list: list[dict[str, Any]] = []

                message_dict: dict[str, Any] = {"role": msg.role.value}
            for item in message.content:
                if isinstance(item, ToolResultContent):
                    openai_messages.append(self._convert_tool_result_content(item))
                elif isinstance(item, ToolUseContent):
                    tool_calls_list.append(self._convert_tool_use_content(item))
                else:
                    content_parts.append(self._convert_message_content(item))

            if content_parts or tool_calls_list:
                msg_dict: dict[str, Any] = {"role": message.role.value}

                if content_parts:
                    message_dict["content"] = content_parts
                    msg_dict["content"] = content_parts
                elif tool_calls_list:
                    # Assistant messages with tool calls need content field (use empty string)
                    msg_dict["content"] = ""

                if tool_calls:
                    message_dict["tool_calls"] = [
                        {
                            "id": tc.id,
                            "type": "function",
                            "function": {"name": tc.name, "arguments": str(tc.input)},
                        }
                        for tc in tool_calls
                    ]
                if tool_calls_list:
                    msg_dict["tool_calls"] = tool_calls_list

                openai_messages.append(message_dict)
                openai_messages.append(msg_dict)

        return openai_messages

    def _convert_tools(
        self, tools: Optional[list[ToolDefinition]]
        self, tools: list[ToolDefinition] | None
    ) -> Optional[list[dict[str, Any]]]:
        """
        Convert our tool definitions to OpenAI format.
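To make the conversion rules above concrete: an assistant message carrying a ToolUseContent becomes an assistant dict with a tool_calls entry, while a ToolResultContent is emitted as its own "tool" role message. A rough sketch of the expected mapping (abbreviated output, field values invented for illustration):

assistant_msg = Message.assistant(
    TextContent(text="Let me check."),
    ToolUseContent(id="call_1", name="get_weather", input={"city": "Oslo"}),
)
tool_reply = Message.user(
    tool_result=ToolResultContent(tool_use_id="call_1", content="12 degrees C")
)

# provider._convert_messages([assistant_msg, tool_reply]) should yield roughly:
# [
#   {"role": "assistant",
#    "content": [{"type": "text", "text": "Let me check."}],
#    "tool_calls": [{"id": "call_1", "type": "function",
#                    "function": {"name": "get_weather",
#                                 "arguments": "{\"city\": \"Oslo\"}"}}]},
#   {"role": "tool", "tool_call_id": "call_1", "content": "12 degrees C"},
# ]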
@@ -137,6 +175,153 @@ class OpenAIProvider(BaseLLMProvider):
            for tool in tools
        ]

    def _build_request_kwargs(
        self,
        messages: list[Message],
        system_prompt: str | None,
        tools: Optional[list[ToolDefinition]],
        settings: LLMSettings,
        stream: bool = False,
    ) -> dict[str, Any]:
        """
        Build common request kwargs for API calls.

        Args:
            messages: Conversation history
            system_prompt: Optional system prompt
            tools: Optional list of tools
            settings: LLM settings
            stream: Whether to enable streaming

        Returns:
            Dictionary of kwargs for OpenAI API call
        """
        openai_messages = self._convert_messages(messages)
        is_reasoning = self._is_reasoning_model()

        # Log info for reasoning models on first use
        if is_reasoning:
            logger.debug(
                f"Using reasoning model {self.model}: "
                "max_completion_tokens will be used, temperature/top_p ignored"
            )

        # Reasoning models (o1) don't support system parameter
        # System message must be added as a developer message instead
        if system_prompt:
            if is_reasoning:
                # For o1 models, add system prompt as a developer message
                openai_messages.insert(
                    0, {"role": "developer", "content": system_prompt}
                )
            else:
                # For other models, add as system message
                openai_messages.insert(0, {"role": "system", "content": system_prompt})

        # Reasoning models use max_completion_tokens instead of max_tokens
        max_tokens_key = "max_completion_tokens" if is_reasoning else "max_tokens"

        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": openai_messages,
            max_tokens_key: settings.max_tokens,
        }

        # Reasoning models don't support temperature or top_p
        if not is_reasoning:
            kwargs["temperature"] = settings.temperature
            kwargs["top_p"] = settings.top_p

        if stream:
            kwargs["stream"] = True

        if settings.stop_sequences:
            kwargs["stop"] = settings.stop_sequences

        if tools:
            kwargs["tools"] = self._convert_tools(tools)
            kwargs["tool_choice"] = "auto"

        return kwargs

    def _parse_and_finalize_tool_call(
        self, tool_call: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Parse the accumulated tool call arguments and prepare for yielding.

        Args:
            tool_call: Tool call dict with 'arguments' field (JSON string)

        Returns:
            Tool call dict with parsed 'input' field (dict)
        """
        try:
            tool_call["input"] = json.loads(tool_call["arguments"])
        except json.JSONDecodeError as e:
            logger.warning(
                f"Failed to parse tool arguments '{tool_call['arguments']}': {e}"
            )
            tool_call["input"] = {}
        del tool_call["arguments"]
        return tool_call

    def _handle_stream_chunk(
        self,
        chunk: Any,
        current_tool_call: dict[str, Any] | None,
    ) -> tuple[list[StreamEvent], Optional[dict[str, Any]]]:
        """
        Handle a single streaming chunk and return events and updated tool state.

        Args:
            chunk: Streaming chunk from OpenAI
            current_tool_call: Current tool call being accumulated (or None)

        Returns:
            Tuple of (list of StreamEvents to yield, updated current_tool_call)
        """
        events: list[StreamEvent] = []

        if not chunk.choices:
            return events, current_tool_call

        delta = chunk.choices[0].delta

        # Handle text content
        if delta.content:
            events.append(StreamEvent(type="text", data=delta.content))

        # Handle tool calls
        if delta.tool_calls:
            for tool_call in delta.tool_calls:
                if tool_call.id:
                    # New tool call starting
                    if current_tool_call:
                        # Yield the previous one with parsed input
                        finalized = self._parse_and_finalize_tool_call(
                            current_tool_call
                        )
                        events.append(StreamEvent(type="tool_use", data=finalized))
                    current_tool_call = {
                        "id": tool_call.id,
                        "name": tool_call.function.name or "",
                        "arguments": tool_call.function.arguments or "",
                    }
                elif current_tool_call and tool_call.function.arguments:
                    # Continue building the current tool call
                    current_tool_call["arguments"] += tool_call.function.arguments

        # Check if stream is finished
        if chunk.choices[0].finish_reason:
            if current_tool_call:
                # Parse the final tool call arguments
                finalized = self._parse_and_finalize_tool_call(current_tool_call)
                events.append(StreamEvent(type="tool_use", data=finalized))
                current_tool_call = None

        return events, current_tool_call

    def generate(
        self,
        messages: list[Message],
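The streaming helpers accumulate a function call's arguments as a raw string across deltas and only decode them once the call is complete. A toy illustration of the _parse_and_finalize_tool_call contract (provider instance assumed to exist):

call = {"id": "call_1", "name": "get_weather", "arguments": ""}
for fragment in ('{"ci', 'ty": ', '"Oslo"}'):  # argument JSON arrives in pieces
    call["arguments"] += fragment

# finalized = provider._parse_and_finalize_tool_call(call)
# finalized == {"id": "call_1", "name": "get_weather", "input": {"city": "Oslo"}}
# (on malformed JSON, "input" becomes {} and a warning is logged)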
@@ -146,29 +331,9 @@
    ) -> str:
        """Generate a non-streaming response."""
        settings = settings or LLMSettings()

        openai_messages = self._convert_messages(messages)

        # Add system prompt as first message if provided
        if system_prompt:
            openai_messages.insert(
                0, {"role": "system", "content": system_prompt}
            )

        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": openai_messages,
            "temperature": settings.temperature,
            "max_tokens": settings.max_tokens,
            "top_p": settings.top_p,
        }

        if settings.stop_sequences:
            kwargs["stop"] = settings.stop_sequences

        if tools:
            kwargs["tools"] = self._convert_tools(tools)
            kwargs["tool_choice"] = "auto"
        kwargs = self._build_request_kwargs(
            messages, system_prompt, tools, settings, stream=False
        )

        try:
            response = self.client.chat.completions.create(**kwargs)
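The practical effect of _build_request_kwargs, sketched for the two model families. Only "gpt-4o" is listed as non-reasoning in this diff; "o1" and the concrete settings values below are assumptions:

settings = LLMSettings(max_tokens=1024)  # constructor kwarg assumed from usage above

# OpenAIProvider(api_key="sk-...", model="gpt-4o")._build_request_kwargs(msgs, "Be brief.", None, settings)
# -> {"model": "gpt-4o",
#     "messages": [{"role": "system", "content": "Be brief."}, ...],
#     "max_tokens": 1024, "temperature": ..., "top_p": ...}

# OpenAIProvider(api_key="sk-...", model="o1")._build_request_kwargs(msgs, "Be brief.", None, settings)
# -> {"model": "o1",
#     "messages": [{"role": "developer", "content": "Be brief."}, ...],
#     "max_completion_tokens": 1024}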
@@ -180,78 +345,25 @@
    def stream(
        self,
        messages: list[Message],
        system_prompt: Optional[str] = None,
        tools: Optional[list[ToolDefinition]] = None,
        settings: Optional[LLMSettings] = None,
        system_prompt: str | None = None,
        tools: list[ToolDefinition] | None = None,
        settings: LLMSettings | None = None,
    ) -> Iterator[StreamEvent]:
        """Generate a streaming response."""
        settings = settings or LLMSettings()

        openai_messages = self._convert_messages(messages)

        # Add system prompt as first message if provided
        if system_prompt:
            openai_messages.insert(
                0, {"role": "system", "content": system_prompt}
            )

        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": openai_messages,
            "temperature": settings.temperature,
            "max_tokens": settings.max_tokens,
            "top_p": settings.top_p,
            "stream": True,
        }

        if settings.stop_sequences:
            kwargs["stop"] = settings.stop_sequences

        if tools:
            kwargs["tools"] = self._convert_tools(tools)
            kwargs["tool_choice"] = "auto"
        kwargs = self._build_request_kwargs(
            messages, system_prompt, tools, settings, stream=True
        )

        try:
            stream = self.client.chat.completions.create(**kwargs)

            current_tool_call: Optional[dict[str, Any]] = None
            current_tool_call: dict[str, Any] | None = None

            for chunk in stream:
                if not chunk.choices:
                    continue

                delta = chunk.choices[0].delta

                # Handle text content
                if delta.content:
                    yield StreamEvent(type="text", data=delta.content)

                # Handle tool calls
                if delta.tool_calls:
                    for tool_call in delta.tool_calls:
                        if tool_call.id:
                            # New tool call starting
                            if current_tool_call:
                                # Yield the previous one
                                yield StreamEvent(
                                    type="tool_use", data=current_tool_call
                                )
                            current_tool_call = {
                                "id": tool_call.id,
                                "name": tool_call.function.name or "",
                                "arguments": tool_call.function.arguments or "",
                            }
                        elif current_tool_call and tool_call.function.arguments:
                            # Continue building the current tool call
                            current_tool_call["arguments"] += (
                                tool_call.function.arguments
                            )

                # Check if stream is finished
                if chunk.choices[0].finish_reason:
                    if current_tool_call:
                        yield StreamEvent(type="tool_use", data=current_tool_call)
                        current_tool_call = None
                events, current_tool_call = self._handle_stream_chunk(
                    chunk, current_tool_call
                )
                yield from events

            yield StreamEvent(type="done")
@@ -268,35 +380,12 @@
    ) -> str:
        """Generate a non-streaming response asynchronously."""
        settings = settings or LLMSettings()

        # Use async client
        async_client = openai.AsyncOpenAI(api_key=self.api_key)

        openai_messages = self._convert_messages(messages)

        # Add system prompt as first message if provided
        if system_prompt:
            openai_messages.insert(
                0, {"role": "system", "content": system_prompt}
            )

        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": openai_messages,
            "temperature": settings.temperature,
            "max_tokens": settings.max_tokens,
            "top_p": settings.top_p,
        }

        if settings.stop_sequences:
            kwargs["stop"] = settings.stop_sequences

        if tools:
            kwargs["tools"] = self._convert_tools(tools)
            kwargs["tool_choice"] = "auto"
        kwargs = self._build_request_kwargs(
            messages, system_prompt, tools, settings, stream=False
        )

        try:
            response = await async_client.chat.completions.create(**kwargs)
            response = await self.async_client.chat.completions.create(**kwargs)
            return response.choices[0].message.content or ""
        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
@@ -311,75 +400,20 @@
    ) -> AsyncIterator[StreamEvent]:
        """Generate a streaming response asynchronously."""
        settings = settings or LLMSettings()

        # Use async client
        async_client = openai.AsyncOpenAI(api_key=self.api_key)

        openai_messages = self._convert_messages(messages)

        # Add system prompt as first message if provided
        if system_prompt:
            openai_messages.insert(
                0, {"role": "system", "content": system_prompt}
            )

        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": openai_messages,
            "temperature": settings.temperature,
            "max_tokens": settings.max_tokens,
            "top_p": settings.top_p,
            "stream": True,
        }

        if settings.stop_sequences:
            kwargs["stop"] = settings.stop_sequences

        if tools:
            kwargs["tools"] = self._convert_tools(tools)
            kwargs["tool_choice"] = "auto"
        kwargs = self._build_request_kwargs(
            messages, system_prompt, tools, settings, stream=True
        )

        try:
            stream = await async_client.chat.completions.create(**kwargs)

            stream = await self.async_client.chat.completions.create(**kwargs)
            current_tool_call: Optional[dict[str, Any]] = None

            async for chunk in stream:
                if not chunk.choices:
                    continue

                delta = chunk.choices[0].delta

                # Handle text content
                if delta.content:
                    yield StreamEvent(type="text", data=delta.content)

                # Handle tool calls
                if delta.tool_calls:
                    for tool_call in delta.tool_calls:
                        if tool_call.id:
                            # New tool call starting
                            if current_tool_call:
                                # Yield the previous one
                                yield StreamEvent(
                                    type="tool_use", data=current_tool_call
                                )
                            current_tool_call = {
                                "id": tool_call.id,
                                "name": tool_call.function.name or "",
                                "arguments": tool_call.function.arguments or "",
                            }
                        elif current_tool_call and tool_call.function.arguments:
                            # Continue building the current tool call
                            current_tool_call["arguments"] += (
                                tool_call.function.arguments
                            )

                # Check if stream is finished
                if chunk.choices[0].finish_reason:
                    if current_tool_call:
                        yield StreamEvent(type="tool_use", data=current_tool_call)
                        current_tool_call = None
                events, current_tool_call = self._handle_stream_chunk(
                    chunk, current_tool_call
                )
                for event in events:
                    yield event

            yield StreamEvent(type="done")
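Finally, a hedged sketch of consuming the async stream. The async method's name and full signature are not visible in this hunk, so "astream" and its parameters are assumed purely for illustration:

import asyncio

async def demo(provider: OpenAIProvider) -> None:
    # Assumes MessageRole.USER exists and plain-string content is accepted.
    messages = [Message(role=MessageRole.USER, content="Summarise the repo.")]
    async for event in provider.astream(messages, system_prompt="Be concise."):
        if event.type == "text":
            print(event.data, end="", flush=True)
        elif event.type == "done":
            break

# asyncio.run(demo(OpenAIProvider(api_key="sk-...", model="gpt-4o")))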