basic tools

This commit is contained in:
Daniel O'Connell 2025-11-02 16:34:38 +01:00
parent 798b4779da
commit 64e84b1c89
6 changed files with 59 additions and 22 deletions

View File

@ -113,6 +113,7 @@ class AnthropicProvider(BaseLLMProvider):
"messages": anthropic_messages,
"temperature": settings.temperature,
"max_tokens": settings.max_tokens,
"extra_headers": {"anthropic-beta": "web-fetch-2025-09-10"},
}
# Only include top_p if explicitly set
@ -152,7 +153,6 @@ class AnthropicProvider(BaseLLMProvider):
Tuple of (StreamEvent or None, updated current_tool_use or None)
"""
event_type = getattr(event, "type", None)
# Handle error events
if event_type == "error":
error = getattr(event, "error", None)

View File

@ -422,7 +422,11 @@ class BaseLLMProvider(ABC):
"""Convert tool definitions to provider format."""
if not tools:
return None
return [self._convert_tool(tool) for tool in tools]
converted = [
tool.provider_format(self.provider) or self._convert_tool(tool)
for tool in tools
]
return [c for c in converted if c is not None]
@abstractmethod
def generate(

View File

@ -151,32 +151,24 @@ class OpenAIProvider(BaseLLMProvider):
return openai_messages
def _convert_tools(
self, tools: list[ToolDefinition] | None
) -> list[dict[str, Any]] | None:
def _convert_tool(self, tool: ToolDefinition) -> dict[str, Any]:
"""
Convert our tool definitions to OpenAI format.
Args:
tools: List of tool definitions
tool: Tool definition
Returns:
List of tools in OpenAI format
Tool in OpenAI format
"""
if not tools:
return None
return [
{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.input_schema,
},
}
for tool in tools
]
return {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.input_schema,
},
}
def _build_request_kwargs(
self,

View File

@ -34,3 +34,6 @@ class ToolDefinition:
def __call__(self, input: ToolInput) -> str:
return self.function(input)
def provider_format(self, provider: str) -> dict[str, Any] | None:
return None

View File

@ -0,0 +1,36 @@
from typing import Any
from memory.common.llms.tools import ToolDefinition
class WebSearchTool(ToolDefinition):
def __init__(self, **kwargs: Any):
defaults = {
"name": "web_search",
"description": "Search the web for information",
"input_schema": {},
"function": lambda input: "result",
}
super().__init__(**(defaults | kwargs))
def provider_format(self, provider: str) -> dict[str, Any] | None:
if provider == "openai":
return {"type": "web_search"}
if provider == "anthropic":
return {"type": "web_search_20250305", "name": "web_search", "max_uses": 10}
return None
class WebFetchTool(ToolDefinition):
def __init__(self, **kwargs: Any):
defaults = {
"name": "web_fetch",
"description": "Fetch the contents of a web page",
"input_schema": {},
"function": lambda input: "result",
}
super().__init__(**(defaults | kwargs))
def provider_format(self, provider: str) -> dict[str, Any] | None:
if provider == "anthropic":
return {"type": "web_fetch_20250910", "name": "web_fetch", "max_uses": 10}
return None

View File

@ -242,7 +242,7 @@ def call_llm(
user_id = None
if from_user and not channel:
user_id = from_user.id
user_id = cast(int, from_user.id)
prev_messages = previous_messages(
session,
user_id,
@ -251,8 +251,10 @@ def call_llm(
)
from memory.common.llms.tools.discord import make_discord_tools
from memory.common.llms.tools.base import WebSearchTool
tools = make_discord_tools(bot_user, from_user, channel, model=model)
tools |= {"web_search": WebSearchTool()}
# Filter to allowed tools if specified
if allowed_tools is not None: