diff --git a/src/memory/common/llms/anthropic_provider.py b/src/memory/common/llms/anthropic_provider.py
index c36e2bf..a55a997 100644
--- a/src/memory/common/llms/anthropic_provider.py
+++ b/src/memory/common/llms/anthropic_provider.py
@@ -113,6 +113,7 @@ class AnthropicProvider(BaseLLMProvider):
             "messages": anthropic_messages,
             "temperature": settings.temperature,
             "max_tokens": settings.max_tokens,
+            "extra_headers": {"anthropic-beta": "web-fetch-2025-09-10"},
         }
 
         # Only include top_p if explicitly set
@@ -152,7 +153,6 @@ class AnthropicProvider(BaseLLMProvider):
             Tuple of (StreamEvent or None, updated current_tool_use or None)
         """
         event_type = getattr(event, "type", None)
-
         # Handle error events
         if event_type == "error":
             error = getattr(event, "error", None)
diff --git a/src/memory/common/llms/base.py b/src/memory/common/llms/base.py
index a7276a2..e950287 100644
--- a/src/memory/common/llms/base.py
+++ b/src/memory/common/llms/base.py
@@ -422,7 +422,11 @@ class BaseLLMProvider(ABC):
         """Convert tool definitions to provider format."""
         if not tools:
             return None
-        return [self._convert_tool(tool) for tool in tools]
+        converted = [
+            tool.provider_format(self.provider) or self._convert_tool(tool)
+            for tool in tools
+        ]
+        return [c for c in converted if c is not None]
 
     @abstractmethod
     def generate(
diff --git a/src/memory/common/llms/openai_provider.py b/src/memory/common/llms/openai_provider.py
index 3548beb..8a342cd 100644
--- a/src/memory/common/llms/openai_provider.py
+++ b/src/memory/common/llms/openai_provider.py
@@ -151,32 +151,24 @@ class OpenAIProvider(BaseLLMProvider):
 
         return openai_messages
 
-    def _convert_tools(
-        self, tools: list[ToolDefinition] | None
-    ) -> list[dict[str, Any]] | None:
+    def _convert_tool(self, tool: ToolDefinition) -> dict[str, Any]:
         """
         Convert our tool definitions to OpenAI format.
 
         Args:
-            tools: List of tool definitions
+            tool: Tool definition
 
         Returns:
-            List of tools in OpenAI format
+            Tool in OpenAI format
         """
-        if not tools:
-            return None
-
-        return [
-            {
-                "type": "function",
-                "function": {
-                    "name": tool.name,
-                    "description": tool.description,
-                    "parameters": tool.input_schema,
-                },
-            }
-            for tool in tools
-        ]
+        return {
+            "type": "function",
+            "function": {
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.input_schema,
+            },
+        }
 
     def _build_request_kwargs(
         self,
diff --git a/src/memory/common/llms/tools/__init__.py b/src/memory/common/llms/tools/__init__.py
index 7a4e1c0..f716e20 100644
--- a/src/memory/common/llms/tools/__init__.py
+++ b/src/memory/common/llms/tools/__init__.py
@@ -34,3 +34,6 @@ class ToolDefinition:
 
     def __call__(self, input: ToolInput) -> str:
         return self.function(input)
+
+    def provider_format(self, provider: str) -> dict[str, Any] | None:
+        return None
diff --git a/src/memory/common/llms/tools/base.py b/src/memory/common/llms/tools/base.py
new file mode 100644
index 0000000..d8f0339
--- /dev/null
+++ b/src/memory/common/llms/tools/base.py
@@ -0,0 +1,36 @@
+from typing import Any
+from memory.common.llms.tools import ToolDefinition
+
+
+class WebSearchTool(ToolDefinition):
+    def __init__(self, **kwargs: Any):
+        defaults = {
+            "name": "web_search",
+            "description": "Search the web for information",
+            "input_schema": {},
+            "function": lambda input: "result",
+        }
+        super().__init__(**(defaults | kwargs))
+
+    def provider_format(self, provider: str) -> dict[str, Any] | None:
+        if provider == "openai":
+            return {"type": "web_search"}
+        if provider == "anthropic":
+            return {"type": "web_search_20250305", "name": "web_search", "max_uses": 10}
+        return None
+
+
+class WebFetchTool(ToolDefinition):
+    def __init__(self, **kwargs: Any):
+        defaults = {
+            "name": "web_fetch",
+            "description": "Fetch the contents of a web page",
+            "input_schema": {},
+            "function": lambda input: "result",
+        }
+        super().__init__(**(defaults | kwargs))
+
+    def provider_format(self, provider: str) -> dict[str, Any] | None:
+        if provider == "anthropic":
+            return {"type": "web_fetch_20250910", "name": "web_fetch", "max_uses": 10}
+        return None
diff --git a/src/memory/discord/messages.py b/src/memory/discord/messages.py
index 3f4d254..ae061a3 100644
--- a/src/memory/discord/messages.py
+++ b/src/memory/discord/messages.py
@@ -242,7 +242,7 @@ def call_llm(
 
     user_id = None
     if from_user and not channel:
-        user_id = from_user.id
+        user_id = cast(int, from_user.id)
     prev_messages = previous_messages(
         session,
         user_id,
@@ -251,8 +251,10 @@ def call_llm(
     )
 
     from memory.common.llms.tools.discord import make_discord_tools
+    from memory.common.llms.tools.base import WebSearchTool
 
     tools = make_discord_tools(bot_user, from_user, channel, model=model)
+    tools |= {"web_search": WebSearchTool()}
 
     # Filter to allowed tools if specified
     if allowed_tools is not None: