basic tools

This commit is contained in:
Daniel O'Connell 2025-11-02 16:34:38 +01:00
parent 798b4779da
commit 64e84b1c89
6 changed files with 59 additions and 22 deletions

View File

@ -113,6 +113,7 @@ class AnthropicProvider(BaseLLMProvider):
"messages": anthropic_messages, "messages": anthropic_messages,
"temperature": settings.temperature, "temperature": settings.temperature,
"max_tokens": settings.max_tokens, "max_tokens": settings.max_tokens,
"extra_headers": {"anthropic-beta": "web-fetch-2025-09-10"},
} }
# Only include top_p if explicitly set # Only include top_p if explicitly set
@ -152,7 +153,6 @@ class AnthropicProvider(BaseLLMProvider):
Tuple of (StreamEvent or None, updated current_tool_use or None) Tuple of (StreamEvent or None, updated current_tool_use or None)
""" """
event_type = getattr(event, "type", None) event_type = getattr(event, "type", None)
# Handle error events # Handle error events
if event_type == "error": if event_type == "error":
error = getattr(event, "error", None) error = getattr(event, "error", None)

View File

@ -422,7 +422,11 @@ class BaseLLMProvider(ABC):
"""Convert tool definitions to provider format.""" """Convert tool definitions to provider format."""
if not tools: if not tools:
return None return None
return [self._convert_tool(tool) for tool in tools] converted = [
tool.provider_format(self.provider) or self._convert_tool(tool)
for tool in tools
]
return [c for c in converted if c is not None]
@abstractmethod @abstractmethod
def generate( def generate(

View File

@ -151,32 +151,24 @@ class OpenAIProvider(BaseLLMProvider):
return openai_messages return openai_messages
def _convert_tools( def _convert_tool(self, tool: ToolDefinition) -> dict[str, Any]:
self, tools: list[ToolDefinition] | None
) -> list[dict[str, Any]] | None:
""" """
Convert our tool definitions to OpenAI format. Convert our tool definitions to OpenAI format.
Args: Args:
tools: List of tool definitions tool: Tool definition
Returns: Returns:
List of tools in OpenAI format Tool in OpenAI format
""" """
if not tools: return {
return None "type": "function",
"function": {
return [ "name": tool.name,
{ "description": tool.description,
"type": "function", "parameters": tool.input_schema,
"function": { },
"name": tool.name, }
"description": tool.description,
"parameters": tool.input_schema,
},
}
for tool in tools
]
def _build_request_kwargs( def _build_request_kwargs(
self, self,

View File

@ -34,3 +34,6 @@ class ToolDefinition:
def __call__(self, input: ToolInput) -> str: def __call__(self, input: ToolInput) -> str:
return self.function(input) return self.function(input)
def provider_format(self, provider: str) -> dict[str, Any] | None:
return None

View File

@ -0,0 +1,36 @@
from typing import Any
from memory.common.llms.tools import ToolDefinition
class WebSearchTool(ToolDefinition):
def __init__(self, **kwargs: Any):
defaults = {
"name": "web_search",
"description": "Search the web for information",
"input_schema": {},
"function": lambda input: "result",
}
super().__init__(**(defaults | kwargs))
def provider_format(self, provider: str) -> dict[str, Any] | None:
if provider == "openai":
return {"type": "web_search"}
if provider == "anthropic":
return {"type": "web_search_20250305", "name": "web_search", "max_uses": 10}
return None
class WebFetchTool(ToolDefinition):
def __init__(self, **kwargs: Any):
defaults = {
"name": "web_fetch",
"description": "Fetch the contents of a web page",
"input_schema": {},
"function": lambda input: "result",
}
super().__init__(**(defaults | kwargs))
def provider_format(self, provider: str) -> dict[str, Any] | None:
if provider == "anthropic":
return {"type": "web_fetch_20250910", "name": "web_fetch", "max_uses": 10}
return None

View File

@ -242,7 +242,7 @@ def call_llm(
user_id = None user_id = None
if from_user and not channel: if from_user and not channel:
user_id = from_user.id user_id = cast(int, from_user.id)
prev_messages = previous_messages( prev_messages = previous_messages(
session, session,
user_id, user_id,
@ -251,8 +251,10 @@ def call_llm(
) )
from memory.common.llms.tools.discord import make_discord_tools from memory.common.llms.tools.discord import make_discord_tools
from memory.common.llms.tools.base import WebSearchTool
tools = make_discord_tools(bot_user, from_user, channel, model=model) tools = make_discord_tools(bot_user, from_user, channel, model=model)
tools |= {"web_search": WebSearchTool()}
# Filter to allowed tools if specified # Filter to allowed tools if specified
if allowed_tools is not None: if allowed_tools is not None: