Mirror of https://github.com/mruwnik/memory.git (synced 2025-10-23 07:06:36 +02:00)

Commit e68671deb4 (parent 99d3843f47): handle openai

@@ -6,7 +6,7 @@ dotenv==0.9.9
 voyageai==0.3.2
 qdrant-client==1.9.0
 anthropic==0.69.0
-openai==1.25.0
+openai==2.3.0
 # Pin the httpx version, as newer versions break the anthropic client
 httpx==0.27.0
 celery[sqs]==5.3.6

@@ -22,10 +22,14 @@ from memory.common.llms.base import (
     ToolUseContent,
     create_provider,
 )
+from memory.common.llms.anthropic_provider import AnthropicProvider
+from memory.common.llms.openai_provider import OpenAIProvider
 from memory.common import tokens
 
 __all__ = [
     "BaseLLMProvider",
+    "AnthropicProvider",
+    "OpenAIProvider",
     "Message",
     "MessageRole",
     "MessageContent",

@@ -2,7 +2,7 @@
 
 import json
 import logging
-from typing import Any, AsyncIterator, Iterator, Optional
+from typing import Any, AsyncIterator, Iterator
 
 import anthropic

@@ -14,9 +14,6 @@ from memory.common.llms.base import (
     MessageRole,
     StreamEvent,
     ToolDefinition,
-    ToolUseContent,
-    ThinkingContent,
-    TextContent,
 )
 
 logger = logging.getLogger(__name__)

@@ -45,7 +42,7 @@ class AnthropicProvider(BaseLLMProvider):
         """
         super().__init__(api_key, model)
         self.enable_thinking = enable_thinking
-        self._async_client: Optional[anthropic.AsyncAnthropic] = None
+        self._async_client: anthropic.AsyncAnthropic | None = None
 
     def _initialize_client(self) -> anthropic.Anthropic:
         """Initialize the Anthropic client."""

@@ -336,6 +333,7 @@ class AnthropicProvider(BaseLLMProvider):
         settings = settings or LLMSettings()
         kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
 
+        print(kwargs)
         try:
             with self.client.messages.stream(**kwargs) as stream:
                 current_tool_use: dict[str, Any] | None = None

@@ -396,56 +394,3 @@ class AnthropicProvider(BaseLLMProvider):
         except Exception as e:
             logger.error(f"Anthropic streaming error: {e}")
             yield StreamEvent(type="error", data=str(e))
-
-    def stream_with_tools(
-        self,
-        messages: list[Message],
-        tools: dict[str, ToolDefinition],
-        settings: LLMSettings | None = None,
-        system_prompt: str | None = None,
-        max_iterations: int = 10,
-    ) -> Iterator[StreamEvent]:
-        if max_iterations <= 0:
-            return
-
-        response = TextContent(text="")
-        thinking = ThinkingContent(thinking="", signature="")
-
-        for event in self.stream(
-            messages=messages,
-            system_prompt=system_prompt,
-            tools=list(tools.values()),
-            settings=settings,
-        ):
-            if event.type == "text":
-                response.text += event.data
-                yield event
-            elif event.type == "thinking" and event.signature:
-                thinking.signature = event.signature
-            elif event.type == "thinking":
-                thinking.thinking += event.data
-                yield event
-            elif event.type == "tool_use":
-                yield event
-                tool_result = self.execute_tool(event.data, tools)
-                yield StreamEvent(type="tool_result", data=tool_result.to_dict())
-                messages.append(
-                    Message.assistant(
-                        response,
-                        thinking,
-                        ToolUseContent(
-                            id=event.data["id"],
-                            name=event.data["name"],
-                            input=event.data["input"],
-                        ),
-                    )
-                )
-                messages.append(Message.user(tool_result=tool_result))
-                yield from self.stream_with_tools(
-                    messages, tools, settings, system_prompt, max_iterations - 1
-                )
-            elif event.type == "tool_result":
-                yield event
-            elif event.type == "error":
-                logger.error(f"LLM error: {event.data}")
-                raise RuntimeError(f"LLM error: {event.data}")

@@ -6,7 +6,7 @@ import logging
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, AsyncIterator, Iterator, Literal, Optional, Union
+from typing import Any, AsyncIterator, Iterator, Literal, Union, cast
 
 from PIL import Image

@@ -36,6 +36,10 @@ class TextContent:
         """Convert to dictionary format."""
         return {"type": "text", "text": self.text}
 
+    @property
+    def valid(self):
+        return self.text
+
 
 @dataclass
 class ImageContent:

@@ -43,13 +47,17 @@ class ImageContent:
 
     type: Literal["image"] = "image"
     image: Image.Image = None  # type: ignore
-    detail: Optional[str] = None  # For OpenAI: "low", "high", "auto"
+    detail: str | None = None  # For OpenAI: "low", "high", "auto"
 
     def to_dict(self) -> dict[str, Any]:
         """Convert to dictionary format."""
         # Note: Image will be encoded by provider-specific implementation
         return {"type": "image", "image": self.image}
 
+    @property
+    def valid(self):
+        return self.image
+
 
 @dataclass
 class ToolUseContent:

@@ -69,6 +77,10 @@ class ToolUseContent:
             "input": self.input,
         }
 
+    @property
+    def valid(self):
+        return self.id and self.name
+
 
 @dataclass
 class ToolResultContent:

@@ -88,6 +100,10 @@ class ToolResultContent:
             "is_error": self.is_error,
         }
 
+    @property
+    def valid(self):
+        return self.tool_use_id
+
 
 @dataclass
 class ThinkingContent:

@@ -105,6 +121,10 @@ class ThinkingContent:
             "signature": self.signature,
         }
 
+    @property
+    def valid(self):
+        return self.thinking and self.signature
+
 
 MessageContent = Union[
     TextContent, ImageContent, ToolUseContent, ToolResultContent, ThinkingContent

@@ -135,18 +155,8 @@ class Message:
         return {"role": self.role.value, "content": content_list}
 
     @staticmethod
-    def assistant(
-        text: TextContent | None = None,
-        thinking: ThinkingContent | None = None,
-        tool_use: ToolUseContent | None = None,
-    ) -> "Message":
-        parts = []
-        if text:
-            parts.append(text)
-        if thinking:
-            parts.append(thinking)
-        if tool_use:
-            parts.append(tool_use)
+    def assistant(*content: MessageContent) -> "Message":
+        parts = [c for c in content if c.valid]
         return Message(role=MessageRole.ASSISTANT, content=parts)
 
     @staticmethod
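
With the variadic signature, Message.assistant simply keeps whichever parts report themselves as valid via the new valid properties, so callers no longer branch on which pieces are present. A minimal sketch of the resulting behaviour (keyword construction assumed, based on the dataclass fields shown above):

    from memory.common.llms.base import Message, TextContent, ThinkingContent

    msg = Message.assistant(
        TextContent(text="Hello"),
        ThinkingContent(thinking="", signature=""),  # invalid: needs thinking and signature
    )
    # Only the text part survives the `c.valid` filter.
    assert [type(c) for c in msg.content] == [TextContent]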

@@ -295,13 +305,22 @@ class BaseLLMProvider(ABC):
         """Convert ThinkingContent to provider format. Override for custom format."""
         return content.to_dict()
 
-    def _convert_message_content(self, content: MessageContent) -> dict[str, Any]:
+    def _convert_message_content(
+        self, content: str | MessageContent | list[MessageContent]
+    ) -> dict[str, Any] | list[dict[str, Any]]:
         """
         Convert a MessageContent item to provider format.
 
         Dispatches to type-specific converters that can be overridden.
         """
-        if isinstance(content, TextContent):
+        if isinstance(content, str):
+            return self._convert_text_content(TextContent(text=content))
+        elif isinstance(content, list):
+            return [
+                cast(dict[str, Any], self._convert_message_content(item))
+                for item in content
+            ]
+        elif isinstance(content, TextContent):
             return self._convert_text_content(content)
         elif isinstance(content, ImageContent):
             return self._convert_image_content(content)

@@ -318,9 +337,16 @@ class BaseLLMProvider(ABC):
         """
         Convert a Message to provider format.
 
-        Can be overridden for provider-specific handling (e.g., filtering system messages).
+        Handles both string content and list[MessageContent], using provider-specific
+        content converters for each content item.
+
+        Can be overridden for provider-specific handling (e.g., OpenAI's tool results).
         """
-        return message.to_dict()
+        # Handle simple string content
+        return {
+            "role": message.role.value,
+            "content": self._convert_message_content(message.content),
+        }
 
     def _should_include_message(self, message: Message) -> bool:
         """
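
With this dispatch in place, plain strings, single content items, and lists all flow through the same per-type converters, and _convert_message reduces to a role/content dict. A rough sketch of the expected shapes for some provider instance, assuming the default text converter simply returns TextContent.to_dict() (as the thinking converter above does) and that MessageRole has the usual USER member:

    provider._convert_message_content("hi")
    # -> {"type": "text", "text": "hi"}            (string wrapped in a TextContent)

    provider._convert_message_content([TextContent(text="hi")])
    # -> [{"type": "text", "text": "hi"}]          (lists are converted item by item)

    provider._convert_message(Message(role=MessageRole.USER, content="hi"))
    # -> {"role": "user", "content": {"type": "text", "text": "hi"}}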

@@ -362,7 +388,7 @@ class BaseLLMProvider(ABC):
 
     def _convert_tools(
         self, tools: list[ToolDefinition] | None
-    ) -> Optional[list[dict[str, Any]]]:
+    ) -> list[dict[str, Any]] | None:
         """Convert tool definitions to provider format."""
         if not tools:
             return None

@@ -456,7 +482,6 @@ class BaseLLMProvider(ABC):
         """
         pass
 
-    @abstractmethod
     def stream_with_tools(
         self,
         messages: list[Message],

@@ -465,7 +490,75 @@ class BaseLLMProvider(ABC):
         system_prompt: str | None = None,
         max_iterations: int = 10,
     ) -> Iterator[StreamEvent]:
-        pass
+        """
+        Stream response with automatic tool execution.
+
+        This method handles the tool call loop automatically, executing tools
+        and sending results back to the LLM until it produces a final response
+        or max_iterations is reached.
+
+        Args:
+            messages: Conversation history
+            tools: Dict mapping tool names to ToolDefinition handlers
+            settings: Optional settings for the generation
+            system_prompt: Optional system prompt
+            max_iterations: Maximum number of tool call iterations
+
+        Yields:
+            StreamEvent objects for text, tool calls, and tool results
+        """
+        if max_iterations <= 0:
+            return
+
+        response = TextContent(text="")
+        thinking = ThinkingContent(thinking="")
+
+        for event in self.stream(
+            messages=messages,
+            system_prompt=system_prompt,
+            tools=list(tools.values()),
+            settings=settings,
+        ):
+            if event.type == "text":
+                response.text += event.data
+                yield event
+            elif event.type == "thinking":
+                thinking.thinking += event.data
+                yield event
+            elif event.type == "tool_use":
+                yield event
+                # Execute the tool and yield the result
+                tool_result = self.execute_tool(event.data, tools)
+                yield StreamEvent(type="tool_result", data=tool_result.to_dict())
+
+                # Add assistant message with tool call
+                messages.append(
+                    Message.assistant(
+                        response,
+                        thinking,
+                        ToolUseContent(
+                            id=event.data["id"],
+                            name=event.data["name"],
+                            input=event.data["input"],
+                        ),
+                    )
+                )
+
+                # Add user message with tool result
+                messages.append(Message.user(tool_result=tool_result))
+
+                # Recursively continue the conversation with reduced iterations
+                yield from self.stream_with_tools(
+                    messages, tools, settings, system_prompt, max_iterations - 1
+                )
+                return  # Exit after recursive call completes
+
+            elif event.type == "error":
+                logger.error(f"LLM error: {event.data}")
+                raise RuntimeError(f"LLM error: {event.data}")
+            elif event.type == "done":
+                # Stream completed without tool calls
+                yield event
 
     def run_with_tools(
         self,
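
Because the tool loop now lives in the base class, every provider gets the same behaviour: stream, execute the requested tool, append the assistant and tool-result messages, then recurse with one fewer iteration. A minimal consumption sketch, assuming a provider instance and a tools dict of ToolDefinition handlers defined elsewhere, and the usual USER member on MessageRole:

    user_msg = Message(role=MessageRole.USER, content="What is the weather in Warsaw?")

    for event in provider.stream_with_tools(
        messages=[user_msg],
        tools=tools,  # dict[str, ToolDefinition], e.g. {"get_weather": weather_tool}
        system_prompt="Answer briefly.",
        max_iterations=3,
    ):
        if event.type == "text":
            print(event.data, end="")
        elif event.type == "tool_result":
            print(f"\n[tool returned: {event.data}]")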

@@ -550,12 +643,22 @@ def create_provider(
             api_key=api_key, model=model, enable_thinking=enable_thinking
         )
 
-    # Could add OpenAI support here in the future
-    # elif "gpt" in model_lower or model.startswith("openai"):
-    #     ...
+    elif provider == "openai":
+        # OpenAI models
+        if api_key is None:
+            api_key = settings.OPENAI_API_KEY
+
+        if not api_key:
+            raise ValueError(
+                "OPENAI_API_KEY not found in settings. Please set it in your .env file."
+            )
+
+        from memory.common.llms.openai_provider import OpenAIProvider
+
+        return OpenAIProvider(api_key=api_key, model=model)
 
     else:
         raise ValueError(
             f"Unknown provider for model: {model}. "
-            f"Supported providers: Anthropic (claude-*)"
+            f"Supported providers: Anthropic (anthropic/*), OpenAI (openai/*)"
        )
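
The factory now resolves the OpenAI key the same way as the Anthropic one, falling back to settings.OPENAI_API_KEY and importing OpenAIProvider lazily. A hypothetical call, assuming the first two arguments are the provider name and the model (the full signature is outside this hunk):

    from memory.common.llms.base import create_provider

    provider = create_provider("openai", "gpt-4o")  # key taken from settings.OPENAI_API_KEY
    reply = provider.generate([Message(role=MessageRole.USER, content="ping")])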

@@ -1,7 +1,8 @@
 """OpenAI LLM provider implementation."""
 
+import json
 import logging
-from typing import Any, AsyncIterator, Iterator, Optional
+from typing import Any, AsyncIterator, Iterator
 
 import openai

@@ -10,11 +11,8 @@ from memory.common.llms.base import (
     ImageContent,
     LLMSettings,
     Message,
-    MessageContent,
-    MessageRole,
     StreamEvent,
     TextContent,
-    ThinkingContent,
     ToolDefinition,
     ToolResultContent,
     ToolUseContent,

@@ -26,92 +24,132 @@ logger = logging.getLogger(__name__)
 class OpenAIProvider(BaseLLMProvider):
     """OpenAI LLM provider with streaming and tool support."""
 
+    # Models that use max_completion_tokens instead of max_tokens
+    # These are reasoning models with different parameter requirements
+    NON_REASONING_MODELS = {"gpt-4o"}
+
+    def __init__(self, api_key: str, model: str):
+        """
+        Initialize the OpenAI provider.
+
+        Args:
+            api_key: OpenAI API key
+            model: Model identifier
+        """
+        super().__init__(api_key, model)
+        self._async_client: openai.AsyncOpenAI | None = None
+
+    def _is_reasoning_model(self) -> bool:
+        """
+        Check if the current model is a reasoning model (o1 series).
+
+        Reasoning models have different parameter requirements:
+        - Use max_completion_tokens instead of max_tokens
+        - Don't support temperature (always uses temperature=1)
+        - Don't support top_p
+        - Don't support system messages via system parameter
+        """
+        return self.model.lower() not in self.NON_REASONING_MODELS
+
     def _initialize_client(self) -> openai.OpenAI:
         """Initialize the OpenAI client."""
         return openai.OpenAI(api_key=self.api_key)
 
+    @property
+    def async_client(self) -> openai.AsyncOpenAI:
+        """Lazy-load the async client."""
+        if self._async_client is None:
+            self._async_client = openai.AsyncOpenAI(api_key=self.api_key)
+        return self._async_client
+
+    def _convert_text_content(self, content: TextContent) -> dict[str, Any]:
+        """Convert TextContent to OpenAI format."""
+        return {"type": "text", "text": content.text}
+
+    def _convert_image_content(self, content: ImageContent) -> dict[str, Any]:
+        """Convert ImageContent to OpenAI image_url format."""
+        encoded_image = self.encode_image(content.image)
+        image_part: dict[str, Any] = {
+            "type": "image_url",
+            "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
+        }
+        if content.detail:
+            image_part["image_url"]["detail"] = content.detail
+        return image_part
+
+    def _convert_tool_use_content(self, content: ToolUseContent) -> dict[str, Any]:
+        """Convert ToolUseContent to provider format. Override for custom format."""
+        return {
+            "id": content.id,
+            "type": "function",
+            "function": {
+                "name": content.name,
+                "arguments": json.dumps(content.input),
+            },
+        }
+
+    def _convert_tool_result_content(
+        self, content: ToolResultContent
+    ) -> dict[str, Any]:
+        """Convert ToolResultContent to provider format. Override for custom format."""
+        return {
+            "role": "tool",
+            "tool_call_id": content.tool_use_id,
+            "content": content.content,
+        }
+
     def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
         """
-        Convert our Message format to OpenAI format.
+        Convert messages to OpenAI format.
 
-        Args:
-            messages: List of messages in our format
+        OpenAI has special requirements:
+        - ToolResultContent creates separate "tool" role messages
+        - ToolUseContent becomes tool_calls field on assistant messages
+        - One input Message can produce multiple output messages
 
         Returns:
-            List of messages in OpenAI format
+            Flat list of OpenAI-formatted message dicts
         """
-        openai_messages = []
+        openai_messages: list[dict[str, Any]] = []
 
-        for msg in messages:
-            if isinstance(msg.content, str):
-                openai_messages.append({"role": msg.role.value, "content": msg.content})
-            else:
-                # Handle multi-part content
-                content_parts = []
-                for item in msg.content:
-                    if isinstance(item, TextContent):
-                        content_parts.append({"type": "text", "text": item.text})
-                    elif isinstance(item, ImageContent):
-                        encoded_image = self.encode_image(item.image)
-                        image_part: dict[str, Any] = {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/jpeg;base64,{encoded_image}"
-                            },
-                        }
-                        if item.detail:
-                            image_part["image_url"]["detail"] = item.detail
-                        content_parts.append(image_part)
-                    elif isinstance(item, ToolUseContent):
-                        # OpenAI doesn't have tool_use in content, it's a separate field
-                        # We'll handle this by adding a tool_calls field to the message
-                        pass
-                    elif isinstance(item, ToolResultContent):
-                        # OpenAI handles tool results as separate "tool" role messages
-                        openai_messages.append(
-                            {
-                                "role": "tool",
-                                "tool_call_id": item.tool_use_id,
-                                "content": item.content,
-                            }
-                        )
-                        continue
-                    elif isinstance(item, ThinkingContent):
-                        # OpenAI doesn't have native thinking support in most models
-                        # We can add it as text with a special marker
-                        content_parts.append(
-                            {
-                                "type": "text",
-                                "text": f"[Thinking: {item.thinking}]",
-                            }
-                        )
+        for message in messages:
+            # Handle simple string content
+            if isinstance(message.content, str):
+                openai_messages.append(
+                    {"role": message.role.value, "content": message.content}
+                )
+                continue
 
-                # Check if this message has tool calls
-                tool_calls = [
-                    item for item in msg.content if isinstance(item, ToolUseContent)
-                ]
+            # Handle multi-part content
+            content_parts: list[dict[str, Any]] = []
+            tool_calls_list: list[dict[str, Any]] = []
 
-                message_dict: dict[str, Any] = {"role": msg.role.value}
+            for item in message.content:
+                if isinstance(item, ToolResultContent):
+                    openai_messages.append(self._convert_tool_result_content(item))
+                elif isinstance(item, ToolUseContent):
+                    tool_calls_list.append(self._convert_tool_use_content(item))
+                else:
+                    content_parts.append(self._convert_message_content(item))
+
+            if content_parts or tool_calls_list:
+                msg_dict: dict[str, Any] = {"role": message.role.value}
 
                 if content_parts:
-                    message_dict["content"] = content_parts
+                    msg_dict["content"] = content_parts
+                elif tool_calls_list:
+                    # Assistant messages with tool calls need content field (use empty string)
+                    msg_dict["content"] = ""
 
-                if tool_calls:
-                    message_dict["tool_calls"] = [
-                        {
-                            "id": tc.id,
-                            "type": "function",
-                            "function": {"name": tc.name, "arguments": str(tc.input)},
-                        }
-                        for tc in tool_calls
-                    ]
+                if tool_calls_list:
+                    msg_dict["tool_calls"] = tool_calls_list
 
-                openai_messages.append(message_dict)
+                openai_messages.append(msg_dict)
 
         return openai_messages
 
     def _convert_tools(
-        self, tools: Optional[list[ToolDefinition]]
+        self, tools: list[ToolDefinition] | None
     ) -> Optional[list[dict[str, Any]]]:
         """
         Convert our tool definitions to OpenAI format.
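
The effect of the new _convert_messages is easiest to see on a message that mixes content kinds: tool results are split out into their own "tool" messages, tool uses move into tool_calls, and everything else stays in content. A sketch of the expected shape, following the converters above:

    msg = Message(
        role=MessageRole.ASSISTANT,
        content=[
            TextContent(text="Let me check."),
            ToolUseContent(id="call_1", name="search", input={"q": "weather"}),
        ],
    )
    OpenAIProvider(api_key="sk-test", model="gpt-4o")._convert_messages([msg])
    # -> [{
    #        "role": "assistant",
    #        "content": [{"type": "text", "text": "Let me check."}],
    #        "tool_calls": [{
    #            "id": "call_1",
    #            "type": "function",
    #            "function": {"name": "search", "arguments": '{"q": "weather"}'},
    #        }],
    #    }]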

@@ -137,6 +175,153 @@ class OpenAIProvider(BaseLLMProvider):
             for tool in tools
         ]
 
+    def _build_request_kwargs(
+        self,
+        messages: list[Message],
+        system_prompt: str | None,
+        tools: Optional[list[ToolDefinition]],
+        settings: LLMSettings,
+        stream: bool = False,
+    ) -> dict[str, Any]:
+        """
+        Build common request kwargs for API calls.
+
+        Args:
+            messages: Conversation history
+            system_prompt: Optional system prompt
+            tools: Optional list of tools
+            settings: LLM settings
+            stream: Whether to enable streaming
+
+        Returns:
+            Dictionary of kwargs for OpenAI API call
+        """
+        openai_messages = self._convert_messages(messages)
+        is_reasoning = self._is_reasoning_model()
+
+        # Log info for reasoning models on first use
+        if is_reasoning:
+            logger.debug(
+                f"Using reasoning model {self.model}: "
+                "max_completion_tokens will be used, temperature/top_p ignored"
+            )
+
+        # Reasoning models (o1) don't support system parameter
+        # System message must be added as a developer message instead
+        if system_prompt:
+            if is_reasoning:
+                # For o1 models, add system prompt as a developer message
+                openai_messages.insert(
+                    0, {"role": "developer", "content": system_prompt}
+                )
+            else:
+                # For other models, add as system message
+                openai_messages.insert(0, {"role": "system", "content": system_prompt})
+
+        # Reasoning models use max_completion_tokens instead of max_tokens
+        max_tokens_key = "max_completion_tokens" if is_reasoning else "max_tokens"
+
+        kwargs: dict[str, Any] = {
+            "model": self.model,
+            "messages": openai_messages,
+            max_tokens_key: settings.max_tokens,
+        }
+
+        # Reasoning models don't support temperature or top_p
+        if not is_reasoning:
+            kwargs["temperature"] = settings.temperature
+            kwargs["top_p"] = settings.top_p
+
+        if stream:
+            kwargs["stream"] = True
+
+        if settings.stop_sequences:
+            kwargs["stop"] = settings.stop_sequences
+
+        if tools:
+            kwargs["tools"] = self._convert_tools(tools)
+            kwargs["tool_choice"] = "auto"
+
+        return kwargs
+
+    def _parse_and_finalize_tool_call(
+        self, tool_call: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Parse the accumulated tool call arguments and prepare for yielding.
+
+        Args:
+            tool_call: Tool call dict with 'arguments' field (JSON string)
+
+        Returns:
+            Tool call dict with parsed 'input' field (dict)
+        """
+        try:
+            tool_call["input"] = json.loads(tool_call["arguments"])
+        except json.JSONDecodeError as e:
+            logger.warning(
+                f"Failed to parse tool arguments '{tool_call['arguments']}': {e}"
+            )
+            tool_call["input"] = {}
+        del tool_call["arguments"]
+        return tool_call
+
+    def _handle_stream_chunk(
+        self,
+        chunk: Any,
+        current_tool_call: dict[str, Any] | None,
+    ) -> tuple[list[StreamEvent], Optional[dict[str, Any]]]:
+        """
+        Handle a single streaming chunk and return events and updated tool state.
+
+        Args:
+            chunk: Streaming chunk from OpenAI
+            current_tool_call: Current tool call being accumulated (or None)
+
+        Returns:
+            Tuple of (list of StreamEvents to yield, updated current_tool_call)
+        """
+        events: list[StreamEvent] = []
+
+        if not chunk.choices:
+            return events, current_tool_call
+
+        delta = chunk.choices[0].delta
+
+        # Handle text content
+        if delta.content:
+            events.append(StreamEvent(type="text", data=delta.content))
+
+        # Handle tool calls
+        if delta.tool_calls:
+            for tool_call in delta.tool_calls:
+                if tool_call.id:
+                    # New tool call starting
+                    if current_tool_call:
+                        # Yield the previous one with parsed input
+                        finalized = self._parse_and_finalize_tool_call(
+                            current_tool_call
+                        )
+                        events.append(StreamEvent(type="tool_use", data=finalized))
+                    current_tool_call = {
+                        "id": tool_call.id,
+                        "name": tool_call.function.name or "",
+                        "arguments": tool_call.function.arguments or "",
+                    }
+                elif current_tool_call and tool_call.function.arguments:
+                    # Continue building the current tool call
+                    current_tool_call["arguments"] += tool_call.function.arguments
+
+        # Check if stream is finished
+        if chunk.choices[0].finish_reason:
+            if current_tool_call:
+                # Parse the final tool call arguments
+                finalized = self._parse_and_finalize_tool_call(current_tool_call)
+                events.append(StreamEvent(type="tool_use", data=finalized))
+                current_tool_call = None
+
+        return events, current_tool_call
+
     def generate(
         self,
         messages: list[Message],
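
Since everything except gpt-4o is treated as a reasoning model here, the request kwargs change shape with the model. A sketch of the two cases, assuming LLMSettings accepts these fields as constructor arguments:

    settings = LLMSettings(max_tokens=512, temperature=0.7, top_p=1.0)
    msgs = [Message(role=MessageRole.USER, content="hi")]

    # gpt-4o is in NON_REASONING_MODELS: classic parameters plus a "system" message.
    OpenAIProvider(api_key="sk-test", model="gpt-4o")._build_request_kwargs(
        msgs, "Be brief.", None, settings
    )
    # -> {"model": "gpt-4o", "messages": [{"role": "system", ...}, ...],
    #     "max_tokens": 512, "temperature": 0.7, "top_p": 1.0}

    # Anything else is treated as reasoning: max_completion_tokens, no temperature/top_p,
    # and the system prompt becomes a "developer" message.
    OpenAIProvider(api_key="sk-test", model="o1-mini")._build_request_kwargs(
        msgs, "Be brief.", None, settings
    )
    # -> {"model": "o1-mini", "messages": [{"role": "developer", ...}, ...],
    #     "max_completion_tokens": 512}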

@@ -146,29 +331,9 @@
     ) -> str:
         """Generate a non-streaming response."""
         settings = settings or LLMSettings()
-
-        openai_messages = self._convert_messages(messages)
-
-        # Add system prompt as first message if provided
-        if system_prompt:
-            openai_messages.insert(
-                0, {"role": "system", "content": system_prompt}
-            )
-
-        kwargs: dict[str, Any] = {
-            "model": self.model,
-            "messages": openai_messages,
-            "temperature": settings.temperature,
-            "max_tokens": settings.max_tokens,
-            "top_p": settings.top_p,
-        }
-
-        if settings.stop_sequences:
-            kwargs["stop"] = settings.stop_sequences
-
-        if tools:
-            kwargs["tools"] = self._convert_tools(tools)
-            kwargs["tool_choice"] = "auto"
+        kwargs = self._build_request_kwargs(
+            messages, system_prompt, tools, settings, stream=False
+        )
 
         try:
             response = self.client.chat.completions.create(**kwargs)

@@ -180,78 +345,25 @@
     def stream(
         self,
         messages: list[Message],
-        system_prompt: Optional[str] = None,
-        tools: Optional[list[ToolDefinition]] = None,
-        settings: Optional[LLMSettings] = None,
+        system_prompt: str | None = None,
+        tools: list[ToolDefinition] | None = None,
+        settings: LLMSettings | None = None,
     ) -> Iterator[StreamEvent]:
         """Generate a streaming response."""
         settings = settings or LLMSettings()
-
-        openai_messages = self._convert_messages(messages)
-
-        # Add system prompt as first message if provided
-        if system_prompt:
-            openai_messages.insert(
-                0, {"role": "system", "content": system_prompt}
-            )
-
-        kwargs: dict[str, Any] = {
-            "model": self.model,
-            "messages": openai_messages,
-            "temperature": settings.temperature,
-            "max_tokens": settings.max_tokens,
-            "top_p": settings.top_p,
-            "stream": True,
-        }
-
-        if settings.stop_sequences:
-            kwargs["stop"] = settings.stop_sequences
-
-        if tools:
-            kwargs["tools"] = self._convert_tools(tools)
-            kwargs["tool_choice"] = "auto"
+        kwargs = self._build_request_kwargs(
+            messages, system_prompt, tools, settings, stream=True
+        )
 
         try:
             stream = self.client.chat.completions.create(**kwargs)
-
-            current_tool_call: Optional[dict[str, Any]] = None
+            current_tool_call: dict[str, Any] | None = None
 
             for chunk in stream:
-                if not chunk.choices:
-                    continue
-
-                delta = chunk.choices[0].delta
-
-                # Handle text content
-                if delta.content:
-                    yield StreamEvent(type="text", data=delta.content)
-
-                # Handle tool calls
-                if delta.tool_calls:
-                    for tool_call in delta.tool_calls:
-                        if tool_call.id:
-                            # New tool call starting
-                            if current_tool_call:
-                                # Yield the previous one
-                                yield StreamEvent(
-                                    type="tool_use", data=current_tool_call
-                                )
-                            current_tool_call = {
-                                "id": tool_call.id,
-                                "name": tool_call.function.name or "",
-                                "arguments": tool_call.function.arguments or "",
-                            }
-                        elif current_tool_call and tool_call.function.arguments:
-                            # Continue building the current tool call
-                            current_tool_call["arguments"] += (
-                                tool_call.function.arguments
-                            )
-
-                # Check if stream is finished
-                if chunk.choices[0].finish_reason:
-                    if current_tool_call:
-                        yield StreamEvent(type="tool_use", data=current_tool_call)
-                        current_tool_call = None
+                events, current_tool_call = self._handle_stream_chunk(
+                    chunk, current_tool_call
+                )
+                yield from events
 
             yield StreamEvent(type="done")
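
OpenAI streams a tool call as an id/name first and the JSON arguments in fragments; _handle_stream_chunk accumulates the fragments and _parse_and_finalize_tool_call converts them into the input dict the rest of the stack expects. A sketch of that final step:

    # Fragments accumulated across chunks end up as one JSON string:
    call = {"id": "call_1", "name": "search", "arguments": '{"q": "weather"}'}
    provider._parse_and_finalize_tool_call(call)
    # -> {"id": "call_1", "name": "search", "input": {"q": "weather"}}
    # "arguments" is deleted; malformed JSON falls back to input == {}.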

@@ -268,35 +380,12 @@
     ) -> str:
         """Generate a non-streaming response asynchronously."""
         settings = settings or LLMSettings()
-
-        # Use async client
-        async_client = openai.AsyncOpenAI(api_key=self.api_key)
-
-        openai_messages = self._convert_messages(messages)
-
-        # Add system prompt as first message if provided
-        if system_prompt:
-            openai_messages.insert(
-                0, {"role": "system", "content": system_prompt}
-            )
-
-        kwargs: dict[str, Any] = {
-            "model": self.model,
-            "messages": openai_messages,
-            "temperature": settings.temperature,
-            "max_tokens": settings.max_tokens,
-            "top_p": settings.top_p,
-        }
-
-        if settings.stop_sequences:
-            kwargs["stop"] = settings.stop_sequences
-
-        if tools:
-            kwargs["tools"] = self._convert_tools(tools)
-            kwargs["tool_choice"] = "auto"
+        kwargs = self._build_request_kwargs(
+            messages, system_prompt, tools, settings, stream=False
+        )
 
         try:
-            response = await async_client.chat.completions.create(**kwargs)
+            response = await self.async_client.chat.completions.create(**kwargs)
             return response.choices[0].message.content or ""
         except Exception as e:
             logger.error(f"OpenAI API error: {e}")

@@ -311,75 +400,20 @@
     ) -> AsyncIterator[StreamEvent]:
         """Generate a streaming response asynchronously."""
         settings = settings or LLMSettings()
-
-        # Use async client
-        async_client = openai.AsyncOpenAI(api_key=self.api_key)
-
-        openai_messages = self._convert_messages(messages)
-
-        # Add system prompt as first message if provided
-        if system_prompt:
-            openai_messages.insert(
-                0, {"role": "system", "content": system_prompt}
-            )
-
-        kwargs: dict[str, Any] = {
-            "model": self.model,
-            "messages": openai_messages,
-            "temperature": settings.temperature,
-            "max_tokens": settings.max_tokens,
-            "top_p": settings.top_p,
-            "stream": True,
-        }
-
-        if settings.stop_sequences:
-            kwargs["stop"] = settings.stop_sequences
-
-        if tools:
-            kwargs["tools"] = self._convert_tools(tools)
-            kwargs["tool_choice"] = "auto"
+        kwargs = self._build_request_kwargs(
+            messages, system_prompt, tools, settings, stream=True
+        )
 
         try:
-            stream = await async_client.chat.completions.create(**kwargs)
-
+            stream = await self.async_client.chat.completions.create(**kwargs)
             current_tool_call: Optional[dict[str, Any]] = None
 
             async for chunk in stream:
-                if not chunk.choices:
-                    continue
-
-                delta = chunk.choices[0].delta
-
-                # Handle text content
-                if delta.content:
-                    yield StreamEvent(type="text", data=delta.content)
-
-                # Handle tool calls
-                if delta.tool_calls:
-                    for tool_call in delta.tool_calls:
-                        if tool_call.id:
-                            # New tool call starting
-                            if current_tool_call:
-                                # Yield the previous one
-                                yield StreamEvent(
-                                    type="tool_use", data=current_tool_call
-                                )
-                            current_tool_call = {
-                                "id": tool_call.id,
-                                "name": tool_call.function.name or "",
-                                "arguments": tool_call.function.arguments or "",
-                            }
-                        elif current_tool_call and tool_call.function.arguments:
-                            # Continue building the current tool call
-                            current_tool_call["arguments"] += (
-                                tool_call.function.arguments
-                            )
-
-                # Check if stream is finished
-                if chunk.choices[0].finish_reason:
-                    if current_tool_call:
-                        yield StreamEvent(type="tool_use", data=current_tool_call)
-                        current_tool_call = None
+                events, current_tool_call = self._handle_stream_chunk(
+                    chunk, current_tool_call
+                )
+                for event in events:
+                    yield event
 
             yield StreamEvent(type="done")