basic tools
commit 64e84b1c89 (parent 798b4779da)
@@ -113,6 +113,7 @@ class AnthropicProvider(BaseLLMProvider):
             "messages": anthropic_messages,
             "temperature": settings.temperature,
             "max_tokens": settings.max_tokens,
+            "extra_headers": {"anthropic-beta": "web-fetch-2025-09-10"},
         }

         # Only include top_p if explicitly set
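The added "extra_headers" entry opts the provider's requests into Anthropic's web-fetch beta. A minimal sketch of where such kwargs typically end up — the client setup and model id below are illustrative assumptions, not code from this commit:

    import anthropic

    client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment
    response = client.messages.create(
        model="claude-sonnet-4-5",  # placeholder model id
        max_tokens=1024,
        temperature=0.7,
        messages=[{"role": "user", "content": "Summarise https://example.com"}],
        # per-request header enabling the web-fetch beta, mirroring the diff above
        extra_headers={"anthropic-beta": "web-fetch-2025-09-10"},
    )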
@@ -152,7 +153,6 @@ class AnthropicProvider(BaseLLMProvider):
             Tuple of (StreamEvent or None, updated current_tool_use or None)
         """
         event_type = getattr(event, "type", None)

         # Handle error events
         if event_type == "error":
             error = getattr(event, "error", None)
@@ -422,7 +422,11 @@ class BaseLLMProvider(ABC):
         """Convert tool definitions to provider format."""
         if not tools:
             return None
-        return [self._convert_tool(tool) for tool in tools]
+        converted = [
+            tool.provider_format(self.provider) or self._convert_tool(tool)
+            for tool in tools
+        ]
+        return [c for c in converted if c is not None]

     @abstractmethod
     def generate(
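With this change, a tool that knows its provider-native format (via provider_format) skips the generic _convert_tool path, and anything that converts to None is filtered out. A rough sketch of both paths, assuming a concrete provider instance and that ToolDefinition accepts the keyword arguments its subclasses pass (as the new base.py below suggests):

    # WebSearchTool returns a native Anthropic block; the plain ToolDefinition has
    # no provider_format override (it returns None), so _convert_tool handles it.
    plain = ToolDefinition(
        name="echo",
        description="Echo the input back",
        input_schema={"type": "object", "properties": {"text": {"type": "string"}}},
        function=lambda input: str(input),
    )
    converted = provider._convert_tools([WebSearchTool(), plain])
    # -> [{"type": "web_search_20250305", ...}, <provider-specific block for "echo">]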
@@ -151,32 +151,24 @@ class OpenAIProvider(BaseLLMProvider):

         return openai_messages

-    def _convert_tools(
-        self, tools: list[ToolDefinition] | None
-    ) -> list[dict[str, Any]] | None:
+    def _convert_tool(self, tool: ToolDefinition) -> dict[str, Any]:
         """
         Convert our tool definitions to OpenAI format.

         Args:
-            tools: List of tool definitions
+            tool: Tool definition

         Returns:
-            List of tools in OpenAI format
+            Tool in OpenAI format
         """
-        if not tools:
-            return None
-
-        return [
-            {
-                "type": "function",
-                "function": {
-                    "name": tool.name,
-                    "description": tool.description,
-                    "parameters": tool.input_schema,
-                },
-            }
-            for tool in tools
-        ]
+        return {
+            "type": "function",
+            "function": {
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.input_schema,
+            },
+        }

     def _build_request_kwargs(
         self,
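The dict built here is the standard OpenAI function-tool shape, one entry per item in the request's tools list. A minimal sketch of where such a block lands — the client call and model id are assumptions for illustration, not code from this repo:

    from openai import OpenAI

    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    response = client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model id
        messages=[{"role": "user", "content": "Echo back: hello"}],
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "echo",
                    "description": "Echo the input back",
                    "parameters": {
                        "type": "object",
                        "properties": {"text": {"type": "string"}},
                        "required": ["text"],
                    },
                },
            }
        ],
    )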
@@ -34,3 +34,6 @@ class ToolDefinition:

     def __call__(self, input: ToolInput) -> str:
         return self.function(input)
+
+    def provider_format(self, provider: str) -> dict[str, Any] | None:
+        return None
src/memory/common/llms/tools/base.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+from typing import Any
+from memory.common.llms.tools import ToolDefinition
+
+
+class WebSearchTool(ToolDefinition):
+    def __init__(self, **kwargs: Any):
+        defaults = {
+            "name": "web_search",
+            "description": "Search the web for information",
+            "input_schema": {},
+            "function": lambda input: "result",
+        }
+        super().__init__(**(defaults | kwargs))
+
+    def provider_format(self, provider: str) -> dict[str, Any] | None:
+        if provider == "openai":
+            return {"type": "web_search"}
+        if provider == "anthropic":
+            return {"type": "web_search_20250305", "name": "web_search", "max_uses": 10}
+        return None
+
+
+class WebFetchTool(ToolDefinition):
+    def __init__(self, **kwargs: Any):
+        defaults = {
+            "name": "web_fetch",
+            "description": "Fetch the contents of a web page",
+            "input_schema": {},
+            "function": lambda input: "result",
+        }
+        super().__init__(**(defaults | kwargs))
+
+    def provider_format(self, provider: str) -> dict[str, Any] | None:
+        if provider == "anthropic":
+            return {"type": "web_fetch_20250910", "name": "web_fetch", "max_uses": 10}
+        return None
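A quick usage sketch of the two tools — the calls below are illustrative, not part of the commit; because __init__ merges defaults | kwargs, callers can override any field:

    search = WebSearchTool(description="Search the web, preferring recent sources")
    search.provider_format("anthropic")
    # -> {"type": "web_search_20250305", "name": "web_search", "max_uses": 10}
    search.provider_format("openai")
    # -> {"type": "web_search"}

    fetch = WebFetchTool()
    fetch.provider_format("openai")
    # -> None, so the generic _convert_tool fallback applies for OpenAI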
@@ -242,7 +242,7 @@ def call_llm(

     user_id = None
     if from_user and not channel:
-        user_id = from_user.id
+        user_id = cast(int, from_user.id)
     prev_messages = previous_messages(
         session,
         user_id,
@@ -251,8 +251,10 @@ def call_llm(
     )

     from memory.common.llms.tools.discord import make_discord_tools
+    from memory.common.llms.tools.base import WebSearchTool

     tools = make_discord_tools(bot_user, from_user, channel, model=model)
+    tools |= {"web_search": WebSearchTool()}

     # Filter to allowed tools if specified
     if allowed_tools is not None:
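Since make_discord_tools returns a name-to-tool mapping, the |= merge simply registers the search tool under "web_search". A rough sketch of the resulting shape — the allowed_tools filter is assumed from the surrounding comment and is not shown in the diff:

    tools = make_discord_tools(bot_user, from_user, channel, model=model)
    tools |= {"web_search": WebSearchTool()}
    # e.g. {"<discord tool name>": <ToolDefinition>, ..., "web_search": <WebSearchTool>}
    if allowed_tools is not None:
        tools = {name: tool for name, tool in tools.items() if name in allowed_tools}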