mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
basic tools
This commit is contained in:
parent
798b4779da
commit
64e84b1c89
@ -113,6 +113,7 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
"messages": anthropic_messages,
|
||||
"temperature": settings.temperature,
|
||||
"max_tokens": settings.max_tokens,
|
||||
"extra_headers": {"anthropic-beta": "web-fetch-2025-09-10"},
|
||||
}
|
||||
|
||||
# Only include top_p if explicitly set
|
||||
@ -152,7 +153,6 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
Tuple of (StreamEvent or None, updated current_tool_use or None)
|
||||
"""
|
||||
event_type = getattr(event, "type", None)
|
||||
|
||||
# Handle error events
|
||||
if event_type == "error":
|
||||
error = getattr(event, "error", None)
|
||||
|
||||
@ -422,7 +422,11 @@ class BaseLLMProvider(ABC):
|
||||
"""Convert tool definitions to provider format."""
|
||||
if not tools:
|
||||
return None
|
||||
return [self._convert_tool(tool) for tool in tools]
|
||||
converted = [
|
||||
tool.provider_format(self.provider) or self._convert_tool(tool)
|
||||
for tool in tools
|
||||
]
|
||||
return [c for c in converted if c is not None]
|
||||
|
||||
@abstractmethod
|
||||
def generate(
|
||||
|
||||
@ -151,32 +151,24 @@ class OpenAIProvider(BaseLLMProvider):
|
||||
|
||||
return openai_messages
|
||||
|
||||
def _convert_tools(
|
||||
self, tools: list[ToolDefinition] | None
|
||||
) -> list[dict[str, Any]] | None:
|
||||
def _convert_tool(self, tool: ToolDefinition) -> dict[str, Any]:
|
||||
"""
|
||||
Convert our tool definitions to OpenAI format.
|
||||
|
||||
Args:
|
||||
tools: List of tool definitions
|
||||
tool: Tool definition
|
||||
|
||||
Returns:
|
||||
List of tools in OpenAI format
|
||||
Tool in OpenAI format
|
||||
"""
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.input_schema,
|
||||
},
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.input_schema,
|
||||
},
|
||||
}
|
||||
|
||||
def _build_request_kwargs(
|
||||
self,
|
||||
|
||||
@ -34,3 +34,6 @@ class ToolDefinition:
|
||||
|
||||
def __call__(self, input: ToolInput) -> str:
|
||||
return self.function(input)
|
||||
|
||||
def provider_format(self, provider: str) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
36
src/memory/common/llms/tools/base.py
Normal file
36
src/memory/common/llms/tools/base.py
Normal file
@ -0,0 +1,36 @@
|
||||
from typing import Any
|
||||
from memory.common.llms.tools import ToolDefinition
|
||||
|
||||
|
||||
class WebSearchTool(ToolDefinition):
|
||||
def __init__(self, **kwargs: Any):
|
||||
defaults = {
|
||||
"name": "web_search",
|
||||
"description": "Search the web for information",
|
||||
"input_schema": {},
|
||||
"function": lambda input: "result",
|
||||
}
|
||||
super().__init__(**(defaults | kwargs))
|
||||
|
||||
def provider_format(self, provider: str) -> dict[str, Any] | None:
|
||||
if provider == "openai":
|
||||
return {"type": "web_search"}
|
||||
if provider == "anthropic":
|
||||
return {"type": "web_search_20250305", "name": "web_search", "max_uses": 10}
|
||||
return None
|
||||
|
||||
|
||||
class WebFetchTool(ToolDefinition):
|
||||
def __init__(self, **kwargs: Any):
|
||||
defaults = {
|
||||
"name": "web_fetch",
|
||||
"description": "Fetch the contents of a web page",
|
||||
"input_schema": {},
|
||||
"function": lambda input: "result",
|
||||
}
|
||||
super().__init__(**(defaults | kwargs))
|
||||
|
||||
def provider_format(self, provider: str) -> dict[str, Any] | None:
|
||||
if provider == "anthropic":
|
||||
return {"type": "web_fetch_20250910", "name": "web_fetch", "max_uses": 10}
|
||||
return None
|
||||
@ -242,7 +242,7 @@ def call_llm(
|
||||
|
||||
user_id = None
|
||||
if from_user and not channel:
|
||||
user_id = from_user.id
|
||||
user_id = cast(int, from_user.id)
|
||||
prev_messages = previous_messages(
|
||||
session,
|
||||
user_id,
|
||||
@ -251,8 +251,10 @@ def call_llm(
|
||||
)
|
||||
|
||||
from memory.common.llms.tools.discord import make_discord_tools
|
||||
from memory.common.llms.tools.base import WebSearchTool
|
||||
|
||||
tools = make_discord_tools(bot_user, from_user, channel, model=model)
|
||||
tools |= {"web_search": WebSearchTool()}
|
||||
|
||||
# Filter to allowed tools if specified
|
||||
if allowed_tools is not None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user