mirror of
https://github.com/mruwnik/memory.git
synced 2025-10-22 22:56:38 +02:00
move to general LLM providers
This commit is contained in:
parent
08d17c28dd
commit
99d3843f47
@ -26,9 +26,12 @@ def upgrade() -> None:
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("member_count", sa.Integer(), nullable=True),
|
||||
sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
|
||||
sa.Column(
|
||||
"track_messages", sa.Boolean(), server_default="true", nullable=False
|
||||
"ignore_messages", sa.Boolean(), server_default="false", nullable=True
|
||||
),
|
||||
sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
@ -56,7 +59,12 @@ def upgrade() -> None:
|
||||
sa.Column("server_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("channel_type", sa.Text(), nullable=False),
|
||||
sa.Column("track_messages", sa.Boolean(), nullable=True),
|
||||
sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
|
||||
sa.Column(
|
||||
"ignore_messages", sa.Boolean(), server_default="false", nullable=True
|
||||
),
|
||||
sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
@ -84,9 +92,12 @@ def upgrade() -> None:
|
||||
sa.Column("username", sa.Text(), nullable=False),
|
||||
sa.Column("display_name", sa.Text(), nullable=True),
|
||||
sa.Column("system_user_id", sa.Integer(), nullable=True),
|
||||
sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
|
||||
sa.Column(
|
||||
"allow_dm_tracking", sa.Boolean(), server_default="true", nullable=False
|
||||
"ignore_messages", sa.Boolean(), server_default="false", nullable=True
|
||||
),
|
||||
sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
|
@ -5,7 +5,7 @@ alembic==1.13.1
|
||||
dotenv==0.9.9
|
||||
voyageai==0.3.2
|
||||
qdrant-client==1.9.0
|
||||
anthropic==0.18.1
|
||||
anthropic==0.69.0
|
||||
openai==1.25.0
|
||||
# Pin the httpx version, as newer versions break the anthropic client
|
||||
httpx==0.27.0
|
||||
|
@ -40,7 +40,7 @@ async def score_chunk(query: str, chunk: Chunk) -> Chunk:
|
||||
prompt = SCORE_CHUNK_PROMPT.format(query=query, chunk=chunk_text)
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
llms.call,
|
||||
llms.summarize,
|
||||
prompt,
|
||||
settings.RANKER_MODEL,
|
||||
images=images,
|
||||
|
@ -15,6 +15,7 @@ SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
|
||||
DISCORD_ROOT = "memory.workers.tasks.discord"
|
||||
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
|
||||
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
|
||||
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
|
||||
|
||||
SYNC_NOTES = f"{NOTES_ROOT}.sync_notes"
|
||||
SYNC_NOTE = f"{NOTES_ROOT}.sync_note"
|
||||
|
@ -127,7 +127,15 @@ class EmailAccount(Base):
|
||||
)
|
||||
|
||||
|
||||
class DiscordServer(Base):
|
||||
class MessageProcessor:
|
||||
track_messages = Column(Boolean, nullable=False, server_default="true")
|
||||
ignore_messages = Column(Boolean, nullable=True, default=False)
|
||||
|
||||
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
|
||||
|
||||
class DiscordServer(Base, MessageProcessor):
|
||||
"""Discord server configuration and metadata"""
|
||||
|
||||
__tablename__ = "discord_servers"
|
||||
@ -138,7 +146,6 @@ class DiscordServer(Base):
|
||||
member_count = Column(Integer)
|
||||
|
||||
# Collection settings
|
||||
track_messages = Column(Boolean, nullable=False, server_default="true")
|
||||
last_sync_at = Column(DateTime(timezone=True))
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
@ -152,7 +159,7 @@ class DiscordServer(Base):
|
||||
)
|
||||
|
||||
|
||||
class DiscordChannel(Base):
|
||||
class DiscordChannel(Base, MessageProcessor):
|
||||
"""Discord channel metadata and configuration"""
|
||||
|
||||
__tablename__ = "discord_channels"
|
||||
@ -163,7 +170,6 @@ class DiscordChannel(Base):
|
||||
channel_type = Column(Text, nullable=False) # "text", "voice", "dm", "group_dm"
|
||||
|
||||
# Collection settings (null = inherit from server)
|
||||
track_messages = Column(Boolean, nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
@ -171,7 +177,7 @@ class DiscordChannel(Base):
|
||||
__table_args__ = (Index("discord_channels_server_idx", "server_id"),)
|
||||
|
||||
|
||||
class DiscordUser(Base):
|
||||
class DiscordUser(Base, MessageProcessor):
|
||||
"""Discord user metadata and preferences"""
|
||||
|
||||
__tablename__ = "discord_users"
|
||||
@ -184,7 +190,6 @@ class DiscordUser(Base):
|
||||
system_user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
|
||||
|
||||
# Basic DM settings
|
||||
allow_dm_tracking = Column(Boolean, nullable=False, server_default="true")
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
|
@ -1,122 +0,0 @@
|
||||
import logging
|
||||
import base64
|
||||
import io
|
||||
from typing import Any
|
||||
from PIL import Image
|
||||
|
||||
from memory.common import settings, tokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
You are a helpful assistant that creates concise summaries and identifies key topics.
|
||||
"""
|
||||
|
||||
|
||||
def encode_image(image: Image.Image) -> str:
|
||||
"""Encode PIL Image to base64 string."""
|
||||
buffer = io.BytesIO()
|
||||
# Convert to RGB if necessary (for RGBA, etc.)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format="JPEG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
def call_openai(
|
||||
prompt: str,
|
||||
model: str,
|
||||
images: list[Image.Image] = [],
|
||||
system_prompt: str = SYSTEM_PROMPT,
|
||||
) -> str:
|
||||
"""Call OpenAI API for summarization."""
|
||||
import openai
|
||||
|
||||
client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
|
||||
try:
|
||||
user_content: Any = [{"type": "text", "text": prompt}]
|
||||
if images:
|
||||
for image in images:
|
||||
encoded_image = encode_image(image)
|
||||
user_content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
|
||||
}
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model.split("/")[1],
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{"role": "user", "content": user_content},
|
||||
],
|
||||
temperature=0.3,
|
||||
max_tokens=2048,
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def call_anthropic(
|
||||
prompt: str,
|
||||
model: str,
|
||||
images: list[Image.Image] = [],
|
||||
system_prompt: str = SYSTEM_PROMPT,
|
||||
) -> str:
|
||||
"""Call Anthropic API for summarization."""
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
|
||||
try:
|
||||
# Prepare the message content
|
||||
content: Any = [{"type": "text", "text": prompt}]
|
||||
if images:
|
||||
# Add images if provided
|
||||
for image in images:
|
||||
encoded_image = encode_image(image)
|
||||
content.append(
|
||||
{ # type: ignore
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": encoded_image,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
response = client.messages.create(
|
||||
model=model.split("/")[1],
|
||||
messages=[{"role": "user", "content": content}], # type: ignore
|
||||
system=system_prompt,
|
||||
temperature=0.3,
|
||||
max_tokens=2048,
|
||||
)
|
||||
return response.content[0].text
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def call(
|
||||
prompt: str,
|
||||
model: str,
|
||||
images: list[Image.Image] = [],
|
||||
system_prompt: str = SYSTEM_PROMPT,
|
||||
) -> str:
|
||||
if model.startswith("anthropic"):
|
||||
return call_anthropic(prompt, model, images, system_prompt)
|
||||
return call_openai(prompt, model, images, system_prompt)
|
||||
|
||||
|
||||
def truncate(content: str, target_tokens: int) -> str:
|
||||
target_chars = target_tokens * tokens.CHARS_PER_TOKEN
|
||||
if len(content) > target_chars:
|
||||
return content[:target_chars].rsplit(" ", 1)[0] + "..."
|
||||
return content
|
79
src/memory/common/llms/__init__.py
Normal file
79
src/memory/common/llms/__init__.py
Normal file
@ -0,0 +1,79 @@
|
||||
"""LLM provider module for unified LLM access."""
|
||||
|
||||
# Legacy imports for backwards compatibility
|
||||
import logging
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# New provider system
|
||||
from memory.common.llms.base import (
|
||||
BaseLLMProvider,
|
||||
ImageContent,
|
||||
LLMSettings,
|
||||
Message,
|
||||
MessageContent,
|
||||
MessageRole,
|
||||
StreamEvent,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
ToolDefinition,
|
||||
ToolResultContent,
|
||||
ToolUseContent,
|
||||
create_provider,
|
||||
)
|
||||
from memory.common import tokens
|
||||
|
||||
__all__ = [
|
||||
"BaseLLMProvider",
|
||||
"Message",
|
||||
"MessageRole",
|
||||
"MessageContent",
|
||||
"TextContent",
|
||||
"ImageContent",
|
||||
"ToolUseContent",
|
||||
"ToolResultContent",
|
||||
"ThinkingContent",
|
||||
"ToolDefinition",
|
||||
"StreamEvent",
|
||||
"LLMSettings",
|
||||
"create_provider",
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def summarize(
|
||||
prompt: str,
|
||||
model: str,
|
||||
images: list[Image.Image] = [],
|
||||
system_prompt: str = "",
|
||||
) -> str:
|
||||
provider = create_provider(model=model)
|
||||
try:
|
||||
# Build message content
|
||||
content: list[MessageContent] = [TextContent(text=prompt)]
|
||||
for image in images:
|
||||
content.append(ImageContent(image=image))
|
||||
|
||||
messages = [Message(role=MessageRole.USER, content=content)]
|
||||
settings_obj = LLMSettings(temperature=0.3, max_tokens=2048)
|
||||
|
||||
res = provider.run_with_tools(
|
||||
messages=messages,
|
||||
system_prompt=system_prompt
|
||||
or "You are a helpful assistant that creates concise summaries and identifies key topics.",
|
||||
settings=settings_obj,
|
||||
tools={},
|
||||
)
|
||||
return res.response or ""
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def truncate(content: str, target_tokens: int) -> str:
|
||||
target_chars = target_tokens * tokens.CHARS_PER_TOKEN
|
||||
if len(content) > target_chars:
|
||||
return content[:target_chars].rsplit(" ", 1)[0] + "..."
|
||||
return content
|
451
src/memory/common/llms/anthropic_provider.py
Normal file
451
src/memory/common/llms/anthropic_provider.py
Normal file
@ -0,0 +1,451 @@
|
||||
"""Anthropic LLM provider implementation."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncIterator, Iterator, Optional
|
||||
|
||||
import anthropic
|
||||
|
||||
from memory.common.llms.base import (
|
||||
BaseLLMProvider,
|
||||
ImageContent,
|
||||
LLMSettings,
|
||||
Message,
|
||||
MessageRole,
|
||||
StreamEvent,
|
||||
ToolDefinition,
|
||||
ToolUseContent,
|
||||
ThinkingContent,
|
||||
TextContent,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicProvider(BaseLLMProvider):
|
||||
"""Anthropic LLM provider with streaming, tool support, and extended thinking."""
|
||||
|
||||
# Models that support extended thinking
|
||||
THINKING_MODELS = {
|
||||
"claude-opus-4",
|
||||
"claude-opus-4-1",
|
||||
"claude-sonnet-4-0",
|
||||
"claude-sonnet-3-7",
|
||||
"claude-sonnet-4-5",
|
||||
}
|
||||
|
||||
def __init__(self, api_key: str, model: str, enable_thinking: bool = False):
|
||||
"""
|
||||
Initialize the Anthropic provider.
|
||||
|
||||
Args:
|
||||
api_key: Anthropic API key
|
||||
model: Model identifier
|
||||
enable_thinking: Enable extended thinking for supported models
|
||||
"""
|
||||
super().__init__(api_key, model)
|
||||
self.enable_thinking = enable_thinking
|
||||
self._async_client: Optional[anthropic.AsyncAnthropic] = None
|
||||
|
||||
def _initialize_client(self) -> anthropic.Anthropic:
|
||||
"""Initialize the Anthropic client."""
|
||||
return anthropic.Anthropic(api_key=self.api_key)
|
||||
|
||||
@property
|
||||
def async_client(self) -> anthropic.AsyncAnthropic:
|
||||
"""Lazy-load the async client."""
|
||||
if self._async_client is None:
|
||||
self._async_client = anthropic.AsyncAnthropic(api_key=self.api_key)
|
||||
return self._async_client
|
||||
|
||||
def _convert_image_content(self, content: ImageContent) -> dict[str, Any]:
|
||||
"""Convert ImageContent to Anthropic's base64 source format."""
|
||||
encoded_image = self.encode_image(content.image)
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": encoded_image,
|
||||
},
|
||||
}
|
||||
|
||||
def _convert_message(self, message: Message) -> dict[str, Any]:
|
||||
converted = message.to_dict()
|
||||
if converted["role"] == MessageRole.ASSISTANT and isinstance(
|
||||
converted["content"], list
|
||||
):
|
||||
content = sorted(
|
||||
converted["content"], key=lambda x: x["type"] != "thinking"
|
||||
)
|
||||
return converted | {"content": content}
|
||||
return converted
|
||||
|
||||
def _should_include_message(self, message: Message) -> bool:
|
||||
"""Filter out system messages (handled separately in Anthropic)."""
|
||||
return message.role != MessageRole.SYSTEM
|
||||
|
||||
def _supports_thinking(self) -> bool:
|
||||
"""Check if the current model supports extended thinking."""
|
||||
model_lower = self.model.lower()
|
||||
return any(supported in model_lower for supported in self.THINKING_MODELS)
|
||||
|
||||
def _build_request_kwargs(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None,
|
||||
tools: list[ToolDefinition] | None,
|
||||
settings: LLMSettings,
|
||||
) -> dict[str, Any]:
|
||||
"""Build common request kwargs for API calls."""
|
||||
anthropic_messages = self._convert_messages(messages)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": anthropic_messages,
|
||||
"temperature": settings.temperature,
|
||||
"max_tokens": settings.max_tokens,
|
||||
}
|
||||
|
||||
# Only include top_p if explicitly set
|
||||
if settings.top_p is not None:
|
||||
kwargs["top_p"] = settings.top_p
|
||||
|
||||
if system_prompt:
|
||||
kwargs["system"] = system_prompt
|
||||
|
||||
if settings.stop_sequences:
|
||||
kwargs["stop_sequences"] = settings.stop_sequences
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._convert_tools(tools)
|
||||
|
||||
# Enable extended thinking if requested and model supports it
|
||||
if self.enable_thinking and self._supports_thinking():
|
||||
thinking_budget = min(10000, settings.max_tokens - 1024)
|
||||
if thinking_budget >= 1024:
|
||||
kwargs["thinking"] = {
|
||||
"type": "enabled",
|
||||
"budget_tokens": thinking_budget,
|
||||
}
|
||||
# When thinking is enabled: temperature must be 1, can't use top_p
|
||||
kwargs["temperature"] = 1.0
|
||||
kwargs.pop("top_p", None)
|
||||
|
||||
return kwargs
|
||||
|
||||
def _handle_stream_event(
|
||||
self, event: Any, current_tool_use: dict[str, Any] | None
|
||||
) -> tuple[StreamEvent | None, dict[str, Any] | None]:
|
||||
"""
|
||||
Handle a streaming event and return StreamEvent and updated tool state.
|
||||
|
||||
Returns:
|
||||
Tuple of (StreamEvent or None, updated current_tool_use or None)
|
||||
"""
|
||||
event_type = getattr(event, "type", None)
|
||||
|
||||
# Handle error events
|
||||
if event_type == "error":
|
||||
error = getattr(event, "error", None)
|
||||
error_msg = str(error) if error else "Unknown error"
|
||||
return StreamEvent(type="error", data=error_msg), current_tool_use
|
||||
|
||||
if event_type == "content_block_start":
|
||||
block = getattr(event, "content_block", None)
|
||||
if not block:
|
||||
return None, current_tool_use
|
||||
|
||||
block_type = getattr(block, "type", None)
|
||||
|
||||
# Handle various tool types (tool_use, mcp_tool_use, server_tool_use)
|
||||
if block_type in ("tool_use", "mcp_tool_use", "server_tool_use"):
|
||||
# In content_block_start, input may already be present (empty dict)
|
||||
block_input = getattr(block, "input", None)
|
||||
current_tool_use = {
|
||||
"id": getattr(block, "id", ""),
|
||||
"name": getattr(block, "name", ""),
|
||||
"input": block_input if block_input is not None else "",
|
||||
"server_name": getattr(block, "server_name", None),
|
||||
"is_server_call": block_type != "tool_use",
|
||||
}
|
||||
|
||||
# Handle tool result blocks
|
||||
elif hasattr(block, "tool_use_id"):
|
||||
tool_result = {
|
||||
"id": getattr(block, "tool_use_id", ""),
|
||||
"result": getattr(block, "content", ""),
|
||||
}
|
||||
return StreamEvent(
|
||||
type="tool_result", data=tool_result
|
||||
), current_tool_use
|
||||
|
||||
# For non-tool blocks (text, thinking), we don't need to track state
|
||||
return None, current_tool_use
|
||||
|
||||
elif event_type == "content_block_delta":
|
||||
delta = getattr(event, "delta", None)
|
||||
if not delta:
|
||||
return None, current_tool_use
|
||||
|
||||
delta_type = getattr(delta, "type", None)
|
||||
|
||||
if delta_type == "text_delta":
|
||||
text = getattr(delta, "text", "")
|
||||
return StreamEvent(type="text", data=text), current_tool_use
|
||||
|
||||
elif delta_type == "thinking_delta":
|
||||
thinking = getattr(delta, "thinking", "")
|
||||
return StreamEvent(type="thinking", data=thinking), current_tool_use
|
||||
|
||||
elif delta_type == "signature_delta":
|
||||
# Handle thinking signature for extended thinking
|
||||
signature = getattr(delta, "signature", "")
|
||||
return StreamEvent(
|
||||
type="thinking", signature=signature
|
||||
), current_tool_use
|
||||
|
||||
elif delta_type == "input_json_delta":
|
||||
if current_tool_use is None:
|
||||
# Edge case: received input_json_delta without tool_use start
|
||||
logger.warning("Received input_json_delta without tool_use context")
|
||||
return None, None
|
||||
|
||||
# Only accumulate if input is still a string (being built up)
|
||||
if isinstance(current_tool_use.get("input"), str):
|
||||
partial_json = getattr(delta, "partial_json", "")
|
||||
current_tool_use["input"] += partial_json
|
||||
# else: input was already set as a dict in content_block_start
|
||||
|
||||
return None, current_tool_use
|
||||
|
||||
elif event_type == "content_block_stop":
|
||||
if current_tool_use:
|
||||
# Use the parsed input from the content block if available
|
||||
# This handles empty inputs {} more reliably than parsing
|
||||
content_block = getattr(event, "content_block", None)
|
||||
if content_block and hasattr(content_block, "input"):
|
||||
current_tool_use["input"] = content_block.input
|
||||
else:
|
||||
# Fallback: parse accumulated JSON string
|
||||
input_str = current_tool_use.get("input", "")
|
||||
if isinstance(input_str, str):
|
||||
# Need to parse the accumulated string
|
||||
if not input_str or input_str.isspace():
|
||||
# Empty or whitespace-only input
|
||||
current_tool_use["input"] = {}
|
||||
else:
|
||||
try:
|
||||
current_tool_use["input"] = json.loads(input_str)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(
|
||||
f"Failed to parse tool input '{input_str}': {e}"
|
||||
)
|
||||
current_tool_use["input"] = {}
|
||||
# else: input is already parsed
|
||||
|
||||
tool_data = {
|
||||
"id": current_tool_use.get("id", ""),
|
||||
"name": current_tool_use.get("name", ""),
|
||||
"input": current_tool_use.get("input", {}),
|
||||
}
|
||||
# Include server info if present
|
||||
if current_tool_use.get("server_name"):
|
||||
tool_data["server_name"] = current_tool_use["server_name"]
|
||||
if current_tool_use.get("is_server_call"):
|
||||
tool_data["is_server_call"] = current_tool_use["is_server_call"]
|
||||
|
||||
return StreamEvent(type="tool_use", data=tool_data), None
|
||||
|
||||
elif event_type == "message_delta":
|
||||
delta = getattr(event, "delta", None)
|
||||
if delta:
|
||||
stop_reason = getattr(delta, "stop_reason", None)
|
||||
if stop_reason == "max_tokens":
|
||||
return StreamEvent(
|
||||
type="error", data="Max tokens reached"
|
||||
), current_tool_use
|
||||
|
||||
# Handle token usage information
|
||||
usage = getattr(event, "usage", None)
|
||||
if usage:
|
||||
usage_data = {
|
||||
"input_tokens": getattr(usage, "input_tokens", 0),
|
||||
"output_tokens": getattr(usage, "output_tokens", 0),
|
||||
"cache_creation_input_tokens": getattr(
|
||||
usage, "cache_creation_input_tokens", None
|
||||
),
|
||||
"cache_read_input_tokens": getattr(
|
||||
usage, "cache_read_input_tokens", None
|
||||
),
|
||||
}
|
||||
# Could emit this as a separate event type if needed
|
||||
logger.debug(f"Token usage: {usage_data}")
|
||||
|
||||
return None, current_tool_use
|
||||
|
||||
elif event_type == "message_stop":
|
||||
# Final event - clean up any pending state
|
||||
if current_tool_use:
|
||||
logger.warning(
|
||||
f"Message ended with incomplete tool use: {current_tool_use}"
|
||||
)
|
||||
return StreamEvent(type="done"), None
|
||||
|
||||
# Unknown event type - log but don't fail
|
||||
if event_type and event_type not in (
|
||||
"message_start",
|
||||
"message_delta",
|
||||
"content_block_start",
|
||||
"content_block_delta",
|
||||
"content_block_stop",
|
||||
"message_stop",
|
||||
):
|
||||
logger.debug(f"Unknown event type: {event_type}")
|
||||
|
||||
return None, current_tool_use
|
||||
|
||||
def generate(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> str:
|
||||
"""Generate a non-streaming response."""
|
||||
settings = settings or LLMSettings()
|
||||
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
|
||||
|
||||
try:
|
||||
response = self.client.messages.create(**kwargs)
|
||||
return "".join(
|
||||
block.text for block in response.content if block.type == "text"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}")
|
||||
raise
|
||||
|
||||
def stream(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> Iterator[StreamEvent]:
|
||||
"""Generate a streaming response."""
|
||||
settings = settings or LLMSettings()
|
||||
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
|
||||
|
||||
try:
|
||||
with self.client.messages.stream(**kwargs) as stream:
|
||||
current_tool_use: dict[str, Any] | None = None
|
||||
|
||||
for event in stream:
|
||||
stream_event, current_tool_use = self._handle_stream_event(
|
||||
event, current_tool_use
|
||||
)
|
||||
if stream_event:
|
||||
yield stream_event
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic streaming error: {e}")
|
||||
yield StreamEvent(type="error", data=str(e))
|
||||
|
||||
async def agenerate(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> str:
|
||||
"""Generate a non-streaming response asynchronously."""
|
||||
settings = settings or LLMSettings()
|
||||
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
|
||||
|
||||
try:
|
||||
response = await self.async_client.messages.create(**kwargs)
|
||||
return "".join(
|
||||
block.text for block in response.content if block.type == "text"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}")
|
||||
raise
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Generate a streaming response asynchronously."""
|
||||
settings = settings or LLMSettings()
|
||||
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
|
||||
|
||||
try:
|
||||
async with self.async_client.messages.stream(**kwargs) as stream:
|
||||
current_tool_use: dict[str, Any] | None = None
|
||||
|
||||
async for event in stream:
|
||||
stream_event, current_tool_use = self._handle_stream_event(
|
||||
event, current_tool_use
|
||||
)
|
||||
if stream_event:
|
||||
yield stream_event
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic streaming error: {e}")
|
||||
yield StreamEvent(type="error", data=str(e))
|
||||
|
||||
def stream_with_tools(
|
||||
self,
|
||||
messages: list[Message],
|
||||
tools: dict[str, ToolDefinition],
|
||||
settings: LLMSettings | None = None,
|
||||
system_prompt: str | None = None,
|
||||
max_iterations: int = 10,
|
||||
) -> Iterator[StreamEvent]:
|
||||
if max_iterations <= 0:
|
||||
return
|
||||
|
||||
response = TextContent(text="")
|
||||
thinking = ThinkingContent(thinking="", signature="")
|
||||
|
||||
for event in self.stream(
|
||||
messages=messages,
|
||||
system_prompt=system_prompt,
|
||||
tools=list(tools.values()),
|
||||
settings=settings,
|
||||
):
|
||||
if event.type == "text":
|
||||
response.text += event.data
|
||||
yield event
|
||||
elif event.type == "thinking" and event.signature:
|
||||
thinking.signature = event.signature
|
||||
elif event.type == "thinking":
|
||||
thinking.thinking += event.data
|
||||
yield event
|
||||
elif event.type == "tool_use":
|
||||
yield event
|
||||
tool_result = self.execute_tool(event.data, tools)
|
||||
yield StreamEvent(type="tool_result", data=tool_result.to_dict())
|
||||
messages.append(
|
||||
Message.assistant(
|
||||
response,
|
||||
thinking,
|
||||
ToolUseContent(
|
||||
id=event.data["id"],
|
||||
name=event.data["name"],
|
||||
input=event.data["input"],
|
||||
),
|
||||
)
|
||||
)
|
||||
messages.append(Message.user(tool_result=tool_result))
|
||||
yield from self.stream_with_tools(
|
||||
messages, tools, settings, system_prompt, max_iterations - 1
|
||||
)
|
||||
elif event.type == "tool_result":
|
||||
yield event
|
||||
elif event.type == "error":
|
||||
logger.error(f"LLM error: {event.data}")
|
||||
raise RuntimeError(f"LLM error: {event.data}")
|
561
src/memory/common/llms/base.py
Normal file
561
src/memory/common/llms/base.py
Normal file
@ -0,0 +1,561 @@
|
||||
"""Base classes and types for LLM providers."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, AsyncIterator, Iterator, Literal, Optional, Union
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
"""Message roles for chat history."""
|
||||
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextContent:
|
||||
"""Text content in a message."""
|
||||
|
||||
type: Literal["text"] = "text"
|
||||
text: str = ""
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {"type": "text", "text": self.text}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageContent:
|
||||
"""Image content in a message."""
|
||||
|
||||
type: Literal["image"] = "image"
|
||||
image: Image.Image = None # type: ignore
|
||||
detail: Optional[str] = None # For OpenAI: "low", "high", "auto"
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
# Note: Image will be encoded by provider-specific implementation
|
||||
return {"type": "image", "image": self.image}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolUseContent:
|
||||
"""Tool use request from the assistant."""
|
||||
|
||||
type: Literal["tool_use"] = "tool_use"
|
||||
id: str = ""
|
||||
name: str = ""
|
||||
input: dict[str, Any] = None # type: ignore
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"type": "tool_use",
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"input": self.input,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResultContent:
|
||||
"""Tool result from tool execution."""
|
||||
|
||||
type: Literal["tool_result"] = "tool_result"
|
||||
tool_use_id: str = ""
|
||||
content: str = ""
|
||||
is_error: bool = False
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": self.tool_use_id,
|
||||
"content": self.content,
|
||||
"is_error": self.is_error,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThinkingContent:
|
||||
"""Thinking/reasoning content from the assistant (extended thinking)."""
|
||||
|
||||
type: Literal["thinking"] = "thinking"
|
||||
thinking: str = ""
|
||||
signature: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"type": "thinking",
|
||||
"thinking": self.thinking,
|
||||
"signature": self.signature,
|
||||
}
|
||||
|
||||
|
||||
MessageContent = Union[
|
||||
TextContent, ImageContent, ToolUseContent, ToolResultContent, ThinkingContent
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Turn:
|
||||
"""A turn in the conversation."""
|
||||
|
||||
response: str | None
|
||||
thinking: str | None
|
||||
tool_calls: dict[str, ToolResult] | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""A message in the conversation history."""
|
||||
|
||||
role: MessageRole
|
||||
content: Union[str, list[MessageContent]]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert message to dictionary format."""
|
||||
if isinstance(self.content, str):
|
||||
return {"role": self.role.value, "content": self.content}
|
||||
content_list = [item.to_dict() for item in self.content]
|
||||
return {"role": self.role.value, "content": content_list}
|
||||
|
||||
@staticmethod
|
||||
def assistant(
|
||||
text: TextContent | None = None,
|
||||
thinking: ThinkingContent | None = None,
|
||||
tool_use: ToolUseContent | None = None,
|
||||
) -> "Message":
|
||||
parts = []
|
||||
if text:
|
||||
parts.append(text)
|
||||
if thinking:
|
||||
parts.append(thinking)
|
||||
if tool_use:
|
||||
parts.append(tool_use)
|
||||
return Message(role=MessageRole.ASSISTANT, content=parts)
|
||||
|
||||
@staticmethod
|
||||
def user(
|
||||
text: str | None = None, tool_result: ToolResultContent | None = None
|
||||
) -> "Message":
|
||||
parts = []
|
||||
if text:
|
||||
parts.append(TextContent(text=text))
|
||||
if tool_result:
|
||||
parts.append(tool_result)
|
||||
return Message(role=MessageRole.USER, content=parts)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamEvent:
|
||||
"""An event from the streaming response."""
|
||||
|
||||
type: Literal["text", "tool_use", "tool_result", "thinking", "error", "done"]
|
||||
data: Any = None
|
||||
signature: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMSettings:
|
||||
"""Settings for LLM API calls."""
|
||||
|
||||
temperature: float = 0.7
|
||||
max_tokens: int = 2048
|
||||
# Don't set by default - some models don't allow both temp and top_p
|
||||
top_p: float | None = None
|
||||
stop_sequences: list[str] | None = None
|
||||
stream: bool = False
|
||||
|
||||
|
||||
class BaseLLMProvider(ABC):
|
||||
"""Base class for LLM providers."""
|
||||
|
||||
def __init__(self, api_key: str, model: str):
|
||||
"""
|
||||
Initialize the LLM provider.
|
||||
|
||||
Args:
|
||||
api_key: API key for the provider
|
||||
model: Model identifier
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self._client: Any = None
|
||||
|
||||
@abstractmethod
|
||||
def _initialize_client(self) -> Any:
|
||||
"""Initialize the provider-specific client."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def client(self) -> Any:
|
||||
"""Lazy-load the client."""
|
||||
if self._client is None:
|
||||
self._client = self._initialize_client()
|
||||
return self._client
|
||||
|
||||
def execute_tool(
|
||||
self,
|
||||
tool_call: ToolCall,
|
||||
tool_handlers: dict[str, ToolDefinition],
|
||||
) -> ToolResultContent:
|
||||
"""
|
||||
Execute a tool call.
|
||||
|
||||
Args:
|
||||
tool_call: Tool call
|
||||
tool_handlers: Dict mapping tool names to handler functions
|
||||
|
||||
Returns:
|
||||
ToolResultContent with result or error
|
||||
"""
|
||||
name = tool_call.get("name")
|
||||
tool_use_id = tool_call.get("id")
|
||||
input = tool_call.get("input")
|
||||
|
||||
if not name:
|
||||
return ToolResultContent(
|
||||
tool_use_id=tool_use_id,
|
||||
content="Tool name missing",
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
if not (tool := tool_handlers.get(name)):
|
||||
return ToolResultContent(
|
||||
tool_use_id=tool_use_id,
|
||||
content=f"Tool '{name}' not found",
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
try:
|
||||
return ToolResultContent(
|
||||
tool_use_id=tool_use_id,
|
||||
content=tool(input),
|
||||
is_error=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Tool '{name}' failed: {e}", exc_info=True)
|
||||
return ToolResultContent(
|
||||
tool_use_id=tool_use_id,
|
||||
content=str(e),
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
def encode_image(self, image: Image.Image) -> str:
|
||||
"""
|
||||
Encode PIL Image to base64 string.
|
||||
|
||||
Args:
|
||||
image: PIL Image to encode
|
||||
|
||||
Returns:
|
||||
Base64 encoded string
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
# Convert to RGB if necessary (for RGBA, etc.)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format="JPEG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
def _convert_text_content(self, content: TextContent) -> dict[str, Any]:
|
||||
"""Convert TextContent to provider format. Override for custom format."""
|
||||
return content.to_dict()
|
||||
|
||||
def _convert_image_content(self, content: ImageContent) -> dict[str, Any]:
|
||||
"""Convert ImageContent to provider format. Override for custom format."""
|
||||
return content.to_dict()
|
||||
|
||||
def _convert_tool_use_content(self, content: ToolUseContent) -> dict[str, Any]:
|
||||
"""Convert ToolUseContent to provider format. Override for custom format."""
|
||||
return content.to_dict()
|
||||
|
||||
def _convert_tool_result_content(
|
||||
self, content: ToolResultContent
|
||||
) -> dict[str, Any]:
|
||||
"""Convert ToolResultContent to provider format. Override for custom format."""
|
||||
return content.to_dict()
|
||||
|
||||
def _convert_thinking_content(self, content: ThinkingContent) -> dict[str, Any]:
|
||||
"""Convert ThinkingContent to provider format. Override for custom format."""
|
||||
return content.to_dict()
|
||||
|
||||
def _convert_message_content(self, content: MessageContent) -> dict[str, Any]:
|
||||
"""
|
||||
Convert a MessageContent item to provider format.
|
||||
|
||||
Dispatches to type-specific converters that can be overridden.
|
||||
"""
|
||||
if isinstance(content, TextContent):
|
||||
return self._convert_text_content(content)
|
||||
elif isinstance(content, ImageContent):
|
||||
return self._convert_image_content(content)
|
||||
elif isinstance(content, ToolUseContent):
|
||||
return self._convert_tool_use_content(content)
|
||||
elif isinstance(content, ToolResultContent):
|
||||
return self._convert_tool_result_content(content)
|
||||
elif isinstance(content, ThinkingContent):
|
||||
return self._convert_thinking_content(content)
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {type(content)}")
|
||||
|
||||
def _convert_message(self, message: Message) -> dict[str, Any]:
|
||||
"""
|
||||
Convert a Message to provider format.
|
||||
|
||||
Can be overridden for provider-specific handling (e.g., filtering system messages).
|
||||
"""
|
||||
return message.to_dict()
|
||||
|
||||
def _should_include_message(self, message: Message) -> bool:
|
||||
"""
|
||||
Determine if a message should be included in the request.
|
||||
|
||||
Override to filter messages (e.g., Anthropic filters SYSTEM messages).
|
||||
|
||||
Args:
|
||||
message: Message to check
|
||||
|
||||
Returns:
|
||||
True if message should be included
|
||||
"""
|
||||
return True
|
||||
|
||||
def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert a list of messages to provider format.
|
||||
|
||||
Uses _should_include_message for filtering and _convert_message for conversion.
|
||||
"""
|
||||
return [
|
||||
self._convert_message(msg)
|
||||
for msg in messages
|
||||
if self._should_include_message(msg)
|
||||
]
|
||||
|
||||
def _convert_tool(self, tool: ToolDefinition) -> dict[str, Any]:
|
||||
"""
|
||||
Convert a single ToolDefinition to provider format.
|
||||
|
||||
Default format matches Anthropic. Override for other providers (e.g., OpenAI uses functions).
|
||||
"""
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.input_schema,
|
||||
}
|
||||
|
||||
def _convert_tools(
|
||||
self, tools: list[ToolDefinition] | None
|
||||
) -> Optional[list[dict[str, Any]]]:
|
||||
"""Convert tool definitions to provider format."""
|
||||
if not tools:
|
||||
return None
|
||||
return [self._convert_tool(tool) for tool in tools]
|
||||
|
||||
@abstractmethod
|
||||
def generate(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a non-streaming response.
|
||||
|
||||
Args:
|
||||
messages: Conversation history
|
||||
system_prompt: Optional system prompt
|
||||
tools: Optional list of tools the LLM can use
|
||||
settings: Optional settings for the generation
|
||||
|
||||
Returns:
|
||||
Generated text response
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stream(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> Iterator[StreamEvent]:
|
||||
"""
|
||||
Generate a streaming response.
|
||||
|
||||
Args:
|
||||
messages: Conversation history
|
||||
system_prompt: Optional system prompt
|
||||
tools: Optional list of tools the LLM can use
|
||||
settings: Optional settings for the generation
|
||||
|
||||
Yields:
|
||||
StreamEvent objects containing text chunks, tool uses, or errors
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def agenerate(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a non-streaming response asynchronously.
|
||||
|
||||
Args:
|
||||
messages: Conversation history
|
||||
system_prompt: Optional system prompt
|
||||
tools: Optional list of tools the LLM can use
|
||||
settings: Optional settings for the generation
|
||||
|
||||
Returns:
|
||||
Generated text response
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def astream(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""
|
||||
Generate a streaming response asynchronously.
|
||||
|
||||
Args:
|
||||
messages: Conversation history
|
||||
system_prompt: Optional system prompt
|
||||
tools: Optional list of tools the LLM can use
|
||||
settings: Optional settings for the generation
|
||||
|
||||
Yields:
|
||||
StreamEvent objects containing text chunks, tool uses, or errors
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stream_with_tools(
|
||||
self,
|
||||
messages: list[Message],
|
||||
tools: dict[str, ToolDefinition],
|
||||
settings: LLMSettings | None = None,
|
||||
system_prompt: str | None = None,
|
||||
max_iterations: int = 10,
|
||||
) -> Iterator[StreamEvent]:
|
||||
pass
|
||||
|
||||
def run_with_tools(
|
||||
self,
|
||||
messages: list[Message],
|
||||
tools: dict[str, ToolDefinition],
|
||||
settings: LLMSettings | None = None,
|
||||
system_prompt: str | None = None,
|
||||
max_iterations: int = 10,
|
||||
) -> Turn:
|
||||
thinking, response, tool_calls = "", "", {}
|
||||
for event in self.stream_with_tools(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
settings=settings,
|
||||
system_prompt=system_prompt,
|
||||
max_iterations=max_iterations,
|
||||
):
|
||||
if event.type == "thinking":
|
||||
thinking += event.data
|
||||
elif event.type == "tool_use":
|
||||
tool_calls[event.data["id"]] = {
|
||||
"name": event.data["name"],
|
||||
"input": event.data["input"],
|
||||
"output": "",
|
||||
}
|
||||
elif event.type == "text":
|
||||
response += event.data
|
||||
elif event.type == "tool_result":
|
||||
current = tool_calls.get(event.data["tool_use_id"]) or {}
|
||||
tool_calls[event.data["tool_use_id"]] = {
|
||||
"name": event.data.get("name") or current.get("name"),
|
||||
"input": event.data.get("input") or current.get("input"),
|
||||
"output": event.data.get("content"),
|
||||
}
|
||||
return Turn(
|
||||
thinking=thinking or None,
|
||||
response=response or None,
|
||||
tool_calls=tool_calls or None,
|
||||
)
|
||||
|
||||
|
||||
def create_provider(
|
||||
model: str | None = None,
|
||||
api_key: str | None = None,
|
||||
enable_thinking: bool = False,
|
||||
) -> BaseLLMProvider:
|
||||
"""
|
||||
Create an LLM provider based on the model name.
|
||||
|
||||
Args:
|
||||
model: Model identifier (e.g., "claude-3-opus-20240229", "gpt-4").
|
||||
If not provided, uses SUMMARIZER_MODEL from settings.
|
||||
api_key: Optional API key. If not provided, uses keys from settings.
|
||||
enable_thinking: Enable extended thinking for supported models (Claude Opus 4+, Sonnet 4+, Sonnet 3.7)
|
||||
|
||||
Returns:
|
||||
An initialized LLM provider
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider cannot be determined from the model name
|
||||
"""
|
||||
# Use default model from settings if not provided
|
||||
if model is None:
|
||||
model = settings.SUMMARIZER_MODEL
|
||||
|
||||
provider, model = model.split("/", 1)
|
||||
|
||||
if provider == "anthropic":
|
||||
# Anthropic models
|
||||
if api_key is None:
|
||||
api_key = settings.ANTHROPIC_API_KEY
|
||||
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"ANTHROPIC_API_KEY not found in settings. "
|
||||
"Please set it in your .env file."
|
||||
)
|
||||
|
||||
from memory.common.llms.anthropic_provider import AnthropicProvider
|
||||
|
||||
return AnthropicProvider(
|
||||
api_key=api_key, model=model, enable_thinking=enable_thinking
|
||||
)
|
||||
|
||||
# Could add OpenAI support here in the future
|
||||
# elif "gpt" in model_lower or model.startswith("openai"):
|
||||
# ...
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown provider for model: {model}. "
|
||||
f"Supported providers: Anthropic (claude-*)"
|
||||
)
|
388
src/memory/common/llms/openai_provider.py
Normal file
388
src/memory/common/llms/openai_provider.py
Normal file
@ -0,0 +1,388 @@
|
||||
"""OpenAI LLM provider implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any, AsyncIterator, Iterator, Optional
|
||||
|
||||
import openai
|
||||
|
||||
from memory.common.llms.base import (
|
||||
BaseLLMProvider,
|
||||
ImageContent,
|
||||
LLMSettings,
|
||||
Message,
|
||||
MessageContent,
|
||||
MessageRole,
|
||||
StreamEvent,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
ToolDefinition,
|
||||
ToolResultContent,
|
||||
ToolUseContent,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIProvider(BaseLLMProvider):
|
||||
"""OpenAI LLM provider with streaming and tool support."""
|
||||
|
||||
def _initialize_client(self) -> openai.OpenAI:
|
||||
"""Initialize the OpenAI client."""
|
||||
return openai.OpenAI(api_key=self.api_key)
|
||||
|
||||
def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert our Message format to OpenAI format.
|
||||
|
||||
Args:
|
||||
messages: List of messages in our format
|
||||
|
||||
Returns:
|
||||
List of messages in OpenAI format
|
||||
"""
|
||||
openai_messages = []
|
||||
|
||||
for msg in messages:
|
||||
if isinstance(msg.content, str):
|
||||
openai_messages.append({"role": msg.role.value, "content": msg.content})
|
||||
else:
|
||||
# Handle multi-part content
|
||||
content_parts = []
|
||||
for item in msg.content:
|
||||
if isinstance(item, TextContent):
|
||||
content_parts.append({"type": "text", "text": item.text})
|
||||
elif isinstance(item, ImageContent):
|
||||
encoded_image = self.encode_image(item.image)
|
||||
image_part: dict[str, Any] = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{encoded_image}"
|
||||
},
|
||||
}
|
||||
if item.detail:
|
||||
image_part["image_url"]["detail"] = item.detail
|
||||
content_parts.append(image_part)
|
||||
elif isinstance(item, ToolUseContent):
|
||||
# OpenAI doesn't have tool_use in content, it's a separate field
|
||||
# We'll handle this by adding a tool_calls field to the message
|
||||
pass
|
||||
elif isinstance(item, ToolResultContent):
|
||||
# OpenAI handles tool results as separate "tool" role messages
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": item.tool_use_id,
|
||||
"content": item.content,
|
||||
}
|
||||
)
|
||||
continue
|
||||
elif isinstance(item, ThinkingContent):
|
||||
# OpenAI doesn't have native thinking support in most models
|
||||
# We can add it as text with a special marker
|
||||
content_parts.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"[Thinking: {item.thinking}]",
|
||||
}
|
||||
)
|
||||
|
||||
# Check if this message has tool calls
|
||||
tool_calls = [
|
||||
item for item in msg.content if isinstance(item, ToolUseContent)
|
||||
]
|
||||
|
||||
message_dict: dict[str, Any] = {"role": msg.role.value}
|
||||
|
||||
if content_parts:
|
||||
message_dict["content"] = content_parts
|
||||
|
||||
if tool_calls:
|
||||
message_dict["tool_calls"] = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {"name": tc.name, "arguments": str(tc.input)},
|
||||
}
|
||||
for tc in tool_calls
|
||||
]
|
||||
|
||||
openai_messages.append(message_dict)
|
||||
|
||||
return openai_messages
|
||||
|
||||
def _convert_tools(
|
||||
self, tools: Optional[list[ToolDefinition]]
|
||||
) -> Optional[list[dict[str, Any]]]:
|
||||
"""
|
||||
Convert our tool definitions to OpenAI format.
|
||||
|
||||
Args:
|
||||
tools: List of tool definitions
|
||||
|
||||
Returns:
|
||||
List of tools in OpenAI format
|
||||
"""
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.input_schema,
|
||||
},
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
|
||||
def generate(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list[ToolDefinition]] = None,
|
||||
settings: Optional[LLMSettings] = None,
|
||||
) -> str:
|
||||
"""Generate a non-streaming response."""
|
||||
settings = settings or LLMSettings()
|
||||
|
||||
openai_messages = self._convert_messages(messages)
|
||||
|
||||
# Add system prompt as first message if provided
|
||||
if system_prompt:
|
||||
openai_messages.insert(
|
||||
0, {"role": "system", "content": system_prompt}
|
||||
)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": openai_messages,
|
||||
"temperature": settings.temperature,
|
||||
"max_tokens": settings.max_tokens,
|
||||
"top_p": settings.top_p,
|
||||
}
|
||||
|
||||
if settings.stop_sequences:
|
||||
kwargs["stop"] = settings.stop_sequences
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._convert_tools(tools)
|
||||
kwargs["tool_choice"] = "auto"
|
||||
|
||||
try:
|
||||
response = self.client.chat.completions.create(**kwargs)
|
||||
return response.choices[0].message.content or ""
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
raise
|
||||
|
||||
def stream(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list[ToolDefinition]] = None,
|
||||
settings: Optional[LLMSettings] = None,
|
||||
) -> Iterator[StreamEvent]:
|
||||
"""Generate a streaming response."""
|
||||
settings = settings or LLMSettings()
|
||||
|
||||
openai_messages = self._convert_messages(messages)
|
||||
|
||||
# Add system prompt as first message if provided
|
||||
if system_prompt:
|
||||
openai_messages.insert(
|
||||
0, {"role": "system", "content": system_prompt}
|
||||
)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": openai_messages,
|
||||
"temperature": settings.temperature,
|
||||
"max_tokens": settings.max_tokens,
|
||||
"top_p": settings.top_p,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
if settings.stop_sequences:
|
||||
kwargs["stop"] = settings.stop_sequences
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._convert_tools(tools)
|
||||
kwargs["tool_choice"] = "auto"
|
||||
|
||||
try:
|
||||
stream = self.client.chat.completions.create(**kwargs)
|
||||
|
||||
current_tool_call: Optional[dict[str, Any]] = None
|
||||
|
||||
for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
# Handle text content
|
||||
if delta.content:
|
||||
yield StreamEvent(type="text", data=delta.content)
|
||||
|
||||
# Handle tool calls
|
||||
if delta.tool_calls:
|
||||
for tool_call in delta.tool_calls:
|
||||
if tool_call.id:
|
||||
# New tool call starting
|
||||
if current_tool_call:
|
||||
# Yield the previous one
|
||||
yield StreamEvent(
|
||||
type="tool_use", data=current_tool_call
|
||||
)
|
||||
current_tool_call = {
|
||||
"id": tool_call.id,
|
||||
"name": tool_call.function.name or "",
|
||||
"arguments": tool_call.function.arguments or "",
|
||||
}
|
||||
elif current_tool_call and tool_call.function.arguments:
|
||||
# Continue building the current tool call
|
||||
current_tool_call["arguments"] += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
|
||||
# Check if stream is finished
|
||||
if chunk.choices[0].finish_reason:
|
||||
if current_tool_call:
|
||||
yield StreamEvent(type="tool_use", data=current_tool_call)
|
||||
current_tool_call = None
|
||||
|
||||
yield StreamEvent(type="done")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI streaming error: {e}")
|
||||
yield StreamEvent(type="error", data=str(e))
|
||||
|
||||
async def agenerate(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list[ToolDefinition]] = None,
|
||||
settings: Optional[LLMSettings] = None,
|
||||
) -> str:
|
||||
"""Generate a non-streaming response asynchronously."""
|
||||
settings = settings or LLMSettings()
|
||||
|
||||
# Use async client
|
||||
async_client = openai.AsyncOpenAI(api_key=self.api_key)
|
||||
|
||||
openai_messages = self._convert_messages(messages)
|
||||
|
||||
# Add system prompt as first message if provided
|
||||
if system_prompt:
|
||||
openai_messages.insert(
|
||||
0, {"role": "system", "content": system_prompt}
|
||||
)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": openai_messages,
|
||||
"temperature": settings.temperature,
|
||||
"max_tokens": settings.max_tokens,
|
||||
"top_p": settings.top_p,
|
||||
}
|
||||
|
||||
if settings.stop_sequences:
|
||||
kwargs["stop"] = settings.stop_sequences
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._convert_tools(tools)
|
||||
kwargs["tool_choice"] = "auto"
|
||||
|
||||
try:
|
||||
response = await async_client.chat.completions.create(**kwargs)
|
||||
return response.choices[0].message.content or ""
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
raise
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list[ToolDefinition]] = None,
|
||||
settings: Optional[LLMSettings] = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Generate a streaming response asynchronously."""
|
||||
settings = settings or LLMSettings()
|
||||
|
||||
# Use async client
|
||||
async_client = openai.AsyncOpenAI(api_key=self.api_key)
|
||||
|
||||
openai_messages = self._convert_messages(messages)
|
||||
|
||||
# Add system prompt as first message if provided
|
||||
if system_prompt:
|
||||
openai_messages.insert(
|
||||
0, {"role": "system", "content": system_prompt}
|
||||
)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": openai_messages,
|
||||
"temperature": settings.temperature,
|
||||
"max_tokens": settings.max_tokens,
|
||||
"top_p": settings.top_p,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
if settings.stop_sequences:
|
||||
kwargs["stop"] = settings.stop_sequences
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._convert_tools(tools)
|
||||
kwargs["tool_choice"] = "auto"
|
||||
|
||||
try:
|
||||
stream = await async_client.chat.completions.create(**kwargs)
|
||||
|
||||
current_tool_call: Optional[dict[str, Any]] = None
|
||||
|
||||
async for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
# Handle text content
|
||||
if delta.content:
|
||||
yield StreamEvent(type="text", data=delta.content)
|
||||
|
||||
# Handle tool calls
|
||||
if delta.tool_calls:
|
||||
for tool_call in delta.tool_calls:
|
||||
if tool_call.id:
|
||||
# New tool call starting
|
||||
if current_tool_call:
|
||||
# Yield the previous one
|
||||
yield StreamEvent(
|
||||
type="tool_use", data=current_tool_call
|
||||
)
|
||||
current_tool_call = {
|
||||
"id": tool_call.id,
|
||||
"name": tool_call.function.name or "",
|
||||
"arguments": tool_call.function.arguments or "",
|
||||
}
|
||||
elif current_tool_call and tool_call.function.arguments:
|
||||
# Continue building the current tool call
|
||||
current_tool_call["arguments"] += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
|
||||
# Check if stream is finished
|
||||
if chunk.choices[0].finish_reason:
|
||||
if current_tool_call:
|
||||
yield StreamEvent(type="tool_use", data=current_tool_call)
|
||||
current_tool_call = None
|
||||
|
||||
yield StreamEvent(type="done")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI streaming error: {e}")
|
||||
yield StreamEvent(type="error", data=str(e))
|
36
src/memory/common/llms/tools/__init__.py
Normal file
36
src/memory/common/llms/tools/__init__.py
Normal file
@ -0,0 +1,36 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, TypedDict
|
||||
|
||||
|
||||
ToolInput = str | dict[str, Any] | None
|
||||
ToolHandler = Callable[[ToolInput], str]
|
||||
|
||||
|
||||
class ToolCall(TypedDict):
|
||||
"""A call to a tool."""
|
||||
|
||||
name: str
|
||||
id: str
|
||||
input: ToolInput
|
||||
|
||||
|
||||
class ToolResult(TypedDict):
|
||||
"""A result from a tool call."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
input: ToolInput
|
||||
output: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
"""Definition of a tool that can be called by the LLM."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
input_schema: dict[str, Any] # JSON Schema for the tool's parameters
|
||||
function: ToolHandler
|
||||
|
||||
def __call__(self, input: ToolInput) -> str:
|
||||
return self.function(input)
|
42
src/memory/common/llms/tools/ping.py
Normal file
42
src/memory/common/llms/tools/ping.py
Normal file
@ -0,0 +1,42 @@
|
||||
"""Ping tool for testing LLM tool integration."""
|
||||
|
||||
from memory.common.llms.tools import ToolDefinition, ToolInput
|
||||
|
||||
|
||||
def handle_ping_call(message: ToolInput = None) -> str:
|
||||
"""
|
||||
Handle a ping tool call.
|
||||
|
||||
Args:
|
||||
message: Optional message to include in response
|
||||
|
||||
Returns:
|
||||
Response string
|
||||
"""
|
||||
if message:
|
||||
return f"pong: {message}"
|
||||
return "pong"
|
||||
|
||||
|
||||
def get_ping_tool() -> ToolDefinition:
|
||||
"""
|
||||
Get a ping tool definition for testing tool calls.
|
||||
|
||||
Returns a simple tool that takes no required parameters and can be used
|
||||
to verify that tool calling is working correctly.
|
||||
"""
|
||||
return ToolDefinition(
|
||||
name="ping",
|
||||
description="A simple test tool that returns 'pong'. Use this to verify tool calling is working.",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "Optional message to echo back",
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
function=handle_ping_call,
|
||||
)
|
9
src/memory/common/messages.py
Normal file
9
src/memory/common/messages.py
Normal file
@ -0,0 +1,9 @@
|
||||
def process_message(
|
||||
msg: str,
|
||||
history: list[str],
|
||||
model: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
allowed_tools: list[str] | None = None,
|
||||
disallowed_tools: list[str] | None = None,
|
||||
) -> str:
|
||||
return "asd"
|
@ -172,6 +172,7 @@ DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat")
|
||||
DISCORD_NOTIFICATIONS_ENABLED = bool(
|
||||
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
|
||||
)
|
||||
DISCORD_PROCESS_MESSAGES = boolean_env("DISCORD_PROCESS_MESSAGES", True)
|
||||
|
||||
# Discord collector settings
|
||||
DISCORD_COLLECTOR_ENABLED = boolean_env("DISCORD_COLLECTOR_ENABLED", True)
|
||||
|
@ -105,7 +105,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
|
||||
prompt = llms.truncate(prompt, MAX_TOKENS - 20)
|
||||
|
||||
try:
|
||||
response = llms.call(prompt, settings.SUMMARIZER_MODEL)
|
||||
response = llms.summarize(prompt, settings.SUMMARIZER_MODEL)
|
||||
result = parse_response(response)
|
||||
|
||||
summary = result.get("summary", "")
|
||||
|
@ -160,7 +160,7 @@ def should_track_message(
|
||||
return False
|
||||
|
||||
if channel.channel_type in ("dm", "group_dm"):
|
||||
return bool(user.allow_dm_tracking)
|
||||
return bool(user.track_messages)
|
||||
|
||||
# Default: track the message
|
||||
return True
|
||||
|
@ -16,7 +16,11 @@ from memory.workers.tasks.content_processing import (
|
||||
create_task_result,
|
||||
process_content_item,
|
||||
)
|
||||
from memory.common.celery_app import ADD_DISCORD_MESSAGE, EDIT_DISCORD_MESSAGE
|
||||
from memory.common.celery_app import (
|
||||
ADD_DISCORD_MESSAGE,
|
||||
EDIT_DISCORD_MESSAGE,
|
||||
PROCESS_DISCORD_MESSAGE,
|
||||
)
|
||||
from memory.common import settings
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
|
||||
@ -40,6 +44,41 @@ def get_prev(
|
||||
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
|
||||
|
||||
|
||||
def should_process(message: DiscordMessage) -> bool:
|
||||
return (
|
||||
settings.DISCORD_PROCESS_MESSAGES
|
||||
and settings.DISCORD_NOTIFICATIONS_ENABLED
|
||||
and not (
|
||||
(message.server and message.server.ignore_messages)
|
||||
or (message.channel and message.channel.ignore_messages)
|
||||
or (message.discord_user and message.discord_user.ignore_messages)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@app.task(name=PROCESS_DISCORD_MESSAGE)
|
||||
@safe_task_execution
|
||||
def process_discord_message(message_id: int) -> dict[str, Any]:
|
||||
logger.info(f"Processing Discord message {message_id}")
|
||||
|
||||
with make_session() as session:
|
||||
discord_message = session.query(DiscordMessage).get(message_id)
|
||||
if not discord_message:
|
||||
logger.info(f"Discord message not found: {message_id}")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "Message not found",
|
||||
"message_id": message_id,
|
||||
}
|
||||
|
||||
print("Processing message", discord_message.id, discord_message.content)
|
||||
|
||||
return {
|
||||
"status": "processed",
|
||||
"message_id": message_id,
|
||||
}
|
||||
|
||||
|
||||
@app.task(name=ADD_DISCORD_MESSAGE)
|
||||
@safe_task_execution
|
||||
def add_discord_message(
|
||||
@ -87,11 +126,8 @@ def add_discord_message(
|
||||
discord_message.messages_before = get_prev(session, channel_id, sent_at_dt)
|
||||
|
||||
result = process_content_item(discord_message, session)
|
||||
|
||||
logger.info(
|
||||
f"Discord message ID after process_content_item: {discord_message.id}"
|
||||
)
|
||||
logger.info(f"Process result: {result}")
|
||||
if should_process(discord_message):
|
||||
process_discord_message.delay(discord_message.id)
|
||||
|
||||
return result
|
||||
|
||||
|
@ -88,11 +88,10 @@ def execute_scheduled_call(self, scheduled_call_id: str):
|
||||
|
||||
# Make the LLM call
|
||||
if scheduled_call.model:
|
||||
response = llms.call(
|
||||
response = llms.summarize(
|
||||
prompt=cast(str, scheduled_call.message),
|
||||
model=cast(str, scheduled_call.model),
|
||||
system_prompt=cast(str, scheduled_call.system_prompt)
|
||||
or llms.SYSTEM_PROMPT,
|
||||
system_prompt=cast(str, scheduled_call.system_prompt),
|
||||
)
|
||||
else:
|
||||
response = cast(str, scheduled_call.message)
|
||||
|
@ -273,6 +273,27 @@ def mock_anthropic_client():
|
||||
with patch.object(anthropic, "Anthropic", autospec=True) as mock_client:
|
||||
client = mock_client()
|
||||
client.messages = Mock()
|
||||
|
||||
# Mock stream as a context manager
|
||||
mock_stream = Mock()
|
||||
mock_stream.__enter__ = Mock(
|
||||
return_value=Mock(
|
||||
__iter__=lambda self: iter(
|
||||
[
|
||||
Mock(
|
||||
type="content_block_delta",
|
||||
delta=Mock(
|
||||
type="text_delta",
|
||||
text="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
)
|
||||
mock_stream.__exit__ = Mock(return_value=False)
|
||||
client.messages.stream = Mock(return_value=mock_stream)
|
||||
|
||||
client.messages.create = Mock(
|
||||
return_value=Mock(
|
||||
content=[
|
||||
|
@ -2,318 +2,250 @@ import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
import requests
|
||||
|
||||
from memory.common import discord, settings
|
||||
from memory.common import discord
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_request():
|
||||
with patch("requests.Session.request") as mock:
|
||||
yield mock
|
||||
def mock_api_url():
|
||||
"""Mock the API URL to avoid using actual settings"""
|
||||
with patch(
|
||||
"memory.common.discord.get_api_url", return_value="http://localhost:8000"
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_channels_response():
|
||||
return [
|
||||
{"name": "memory-errors", "id": "error_channel_id"},
|
||||
{"name": "memory-activity", "id": "activity_channel_id"},
|
||||
{"name": "memory-discoveries", "id": "discovery_channel_id"},
|
||||
{"name": "memory-chat", "id": "chat_channel_id"},
|
||||
]
|
||||
@patch("memory.common.settings.DISCORD_COLLECTOR_SERVER_URL", "testhost")
|
||||
@patch("memory.common.settings.DISCORD_COLLECTOR_PORT", 9999)
|
||||
def test_get_api_url():
|
||||
"""Test API URL construction"""
|
||||
assert discord.get_api_url() == "http://testhost:9999"
|
||||
|
||||
|
||||
def test_discord_server_init(mock_session_request, mock_get_channels_response):
|
||||
# Mock the channels API call
|
||||
@patch("requests.post")
|
||||
def test_send_dm_success(mock_post, mock_api_url):
|
||||
"""Test successful DM sending"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = mock_get_channels_response
|
||||
mock_response.json.return_value = {"success": True}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_session_request.return_value = mock_response
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
server = discord.DiscordServer("server123", "Test Server")
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
|
||||
assert server.server_id == "server123"
|
||||
assert server.server_name == "Test Server"
|
||||
assert hasattr(server, "channels")
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "memory-errors")
|
||||
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "memory-activity")
|
||||
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "memory-discoveries")
|
||||
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "memory-chat")
|
||||
def test_setup_channels_existing(mock_session_request, mock_get_channels_response):
|
||||
# Mock the channels API call
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = mock_get_channels_response
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_session_request.return_value = mock_response
|
||||
|
||||
server = discord.DiscordServer("server123", "Test Server")
|
||||
|
||||
assert server.channels[discord.ERROR_CHANNEL] == "error_channel_id"
|
||||
assert server.channels[discord.ACTIVITY_CHANNEL] == "activity_channel_id"
|
||||
assert server.channels[discord.DISCOVERY_CHANNEL] == "discovery_channel_id"
|
||||
assert server.channels[discord.CHAT_CHANNEL] == "chat_channel_id"
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "new-error-channel")
|
||||
def test_setup_channels_create_missing(mock_session_request):
|
||||
# Mock get channels (empty) and create channel calls
|
||||
get_response = Mock()
|
||||
get_response.json.return_value = []
|
||||
get_response.raise_for_status.return_value = None
|
||||
|
||||
create_response = Mock()
|
||||
create_response.json.return_value = {"id": "new_channel_id"}
|
||||
create_response.raise_for_status.return_value = None
|
||||
|
||||
mock_session_request.side_effect = [
|
||||
get_response,
|
||||
create_response,
|
||||
create_response,
|
||||
create_response,
|
||||
create_response,
|
||||
]
|
||||
|
||||
server = discord.DiscordServer("server123", "Test Server")
|
||||
|
||||
assert server.channels[discord.ERROR_CHANNEL] == "new_channel_id"
|
||||
|
||||
|
||||
def test_channel_properties():
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.channels = {
|
||||
discord.ERROR_CHANNEL: "error_id",
|
||||
discord.ACTIVITY_CHANNEL: "activity_id",
|
||||
discord.DISCOVERY_CHANNEL: "discovery_id",
|
||||
discord.CHAT_CHANNEL: "chat_id",
|
||||
}
|
||||
|
||||
assert server.error_channel == "error_id"
|
||||
assert server.activity_channel == "activity_id"
|
||||
assert server.discovery_channel == "discovery_id"
|
||||
assert server.chat_channel == "chat_id"
|
||||
|
||||
|
||||
def test_channel_id_exists():
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.channels = {"test-channel": "channel123"}
|
||||
|
||||
assert server.channel_id("test-channel") == "channel123"
|
||||
|
||||
|
||||
def test_channel_id_not_found():
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.channels = {}
|
||||
|
||||
with pytest.raises(ValueError, match="Channel nonexistent not found"):
|
||||
server.channel_id("nonexistent")
|
||||
|
||||
|
||||
def test_send_message(mock_session_request):
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_session_request.return_value = mock_response
|
||||
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
|
||||
server.send_message("channel123", "Hello World")
|
||||
|
||||
mock_session_request.assert_called_with(
|
||||
"POST",
|
||||
"https://discord.com/api/v10/channels/channel123/messages",
|
||||
data=None,
|
||||
json={"content": "Hello World"},
|
||||
headers={
|
||||
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_dm",
|
||||
json={"user_identifier": "user123", "message": "Hello!"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
|
||||
def test_create_channel(mock_session_request):
|
||||
@patch("requests.post")
|
||||
def test_send_dm_api_failure(mock_post, mock_api_url):
|
||||
"""Test DM sending when API returns failure"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"id": "new_channel_id"}
|
||||
mock_response.json.return_value = {"success": False}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_session_request.return_value = mock_response
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.server_id = "server123"
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
|
||||
channel_id = server.create_channel("new-channel")
|
||||
|
||||
assert channel_id == "new_channel_id"
|
||||
mock_session_request.assert_called_with(
|
||||
"POST",
|
||||
"https://discord.com/api/v10/guilds/server123/channels",
|
||||
data=None,
|
||||
json={"name": "new-channel", "type": 0},
|
||||
headers={
|
||||
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_create_channel_custom_type(mock_session_request):
|
||||
@patch("requests.post")
|
||||
def test_send_dm_request_exception(mock_post, mock_api_url):
|
||||
"""Test DM sending when request raises exception"""
|
||||
mock_post.side_effect = requests.RequestException("Network error")
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_send_dm_http_error(mock_post, mock_api_url):
|
||||
"""Test DM sending when HTTP error occurs"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"id": "voice_channel_id"}
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_broadcast_message_success(mock_post, mock_api_url):
|
||||
"""Test successful channel message broadcast"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": True}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_session_request.return_value = mock_response
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.server_id = "server123"
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
|
||||
channel_id = server.create_channel("voice-channel", channel_type=2)
|
||||
|
||||
assert channel_id == "voice_channel_id"
|
||||
mock_session_request.assert_called_with(
|
||||
"POST",
|
||||
"https://discord.com/api/v10/guilds/server123/channels",
|
||||
data=None,
|
||||
json={"name": "voice-channel", "type": 2},
|
||||
headers={
|
||||
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_channel",
|
||||
json={"channel_name": "general", "message": "Announcement!"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
|
||||
def test_str_representation():
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.server_id = "server123"
|
||||
server.server_name = "Test Server"
|
||||
@patch("requests.post")
|
||||
def test_broadcast_message_failure(mock_post, mock_api_url):
|
||||
"""Test channel message broadcast failure"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": False}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
assert str(server) == "DiscordServer(server_id=server123, server_name=Test Server)"
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token_123")
|
||||
def test_request_adds_headers(mock_session_request):
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
@patch("requests.post")
|
||||
def test_broadcast_message_exception(mock_post, mock_api_url):
|
||||
"""Test channel message broadcast with exception"""
|
||||
mock_post.side_effect = requests.Timeout("Request timeout")
|
||||
|
||||
server.request("GET", "https://example.com", headers={"Custom": "header"})
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
|
||||
expected_headers = {
|
||||
"Custom": "header",
|
||||
"Authorization": "Bot test_token_123",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
mock_session_request.assert_called_once_with(
|
||||
"GET", "https://example.com", headers=expected_headers
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_channels_url():
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.server_id = "server123"
|
||||
|
||||
assert (
|
||||
server.channels_url == "https://discord.com/api/v10/guilds/server123/channels"
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
|
||||
@patch("requests.get")
|
||||
def test_get_bot_servers_success(mock_get):
|
||||
def test_is_collector_healthy_true(mock_get, mock_api_url):
|
||||
"""Test health check when collector is healthy"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = [
|
||||
{"id": "server1", "name": "Server 1"},
|
||||
{"id": "server2", "name": "Server 2"},
|
||||
]
|
||||
mock_response.json.return_value = {"status": "healthy"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
servers = discord.get_bot_servers()
|
||||
result = discord.is_collector_healthy()
|
||||
|
||||
assert len(servers) == 2
|
||||
assert servers[0] == {"id": "server1", "name": "Server 1"}
|
||||
mock_get.assert_called_once_with(
|
||||
"https://discord.com/api/v10/users/@me/guilds",
|
||||
headers={"Authorization": "Bot test_token"},
|
||||
)
|
||||
assert result is True
|
||||
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
|
||||
def test_get_bot_servers_no_token():
|
||||
assert discord.get_bot_servers() == []
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
|
||||
@patch("requests.get")
|
||||
def test_get_bot_servers_exception(mock_get):
|
||||
mock_get.side_effect = requests.RequestException("API Error")
|
||||
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
|
||||
"""Test health check when collector returns unhealthy status"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"status": "unhealthy"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
servers = discord.get_bot_servers()
|
||||
result = discord.is_collector_healthy()
|
||||
|
||||
assert servers == []
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("memory.common.discord.get_bot_servers")
|
||||
@patch("memory.common.discord.DiscordServer")
|
||||
def test_load_servers(mock_discord_server_class, mock_get_servers):
|
||||
mock_get_servers.return_value = [
|
||||
{"id": "server1", "name": "Server 1"},
|
||||
{"id": "server2", "name": "Server 2"},
|
||||
]
|
||||
@patch("requests.get")
|
||||
def test_is_collector_healthy_exception(mock_get, mock_api_url):
|
||||
"""Test health check when request fails"""
|
||||
mock_get.side_effect = requests.ConnectionError("Connection refused")
|
||||
|
||||
discord.load_servers()
|
||||
result = discord.is_collector_healthy()
|
||||
|
||||
assert mock_discord_server_class.call_count == 2
|
||||
mock_discord_server_class.assert_any_call("server1", "Server 1")
|
||||
mock_discord_server_class.assert_any_call("server2", "Server 2")
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_broadcast_message():
|
||||
mock_server1 = Mock()
|
||||
mock_server2 = Mock()
|
||||
discord.servers = {"1": mock_server1, "2": mock_server2}
|
||||
@patch("requests.post")
|
||||
def test_refresh_discord_metadata_success(mock_post, mock_api_url):
|
||||
"""Test successful metadata refresh"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"servers": 5,
|
||||
"channels": 20,
|
||||
"users": 100,
|
||||
}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
discord.broadcast_message("test-channel", "Hello")
|
||||
result = discord.refresh_discord_metadata()
|
||||
|
||||
mock_server1.send_message.assert_called_once_with(
|
||||
mock_server1.channel_id.return_value, "Hello"
|
||||
)
|
||||
mock_server2.send_message.assert_called_once_with(
|
||||
mock_server2.channel_id.return_value, "Hello"
|
||||
assert result == {"servers": 5, "channels": 20, "users": 100}
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/refresh_metadata", timeout=30
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
def test_broadcast_message_disabled():
|
||||
mock_server = Mock()
|
||||
discord.servers = {"1": mock_server}
|
||||
@patch("requests.post")
|
||||
def test_refresh_discord_metadata_failure(mock_post, mock_api_url):
|
||||
"""Test metadata refresh failure"""
|
||||
mock_post.side_effect = requests.RequestException("Failed to connect")
|
||||
|
||||
discord.broadcast_message("test-channel", "Hello")
|
||||
result = discord.refresh_discord_metadata()
|
||||
|
||||
mock_server.send_message.assert_not_called()
|
||||
assert result is None
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_refresh_discord_metadata_http_error(mock_post, mock_api_url):
|
||||
"""Test metadata refresh with HTTP error"""
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.refresh_discord_metadata()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "errors")
|
||||
def test_send_error_message(mock_broadcast):
|
||||
discord.send_error_message("Error occurred")
|
||||
mock_broadcast.assert_called_once_with(discord.ERROR_CHANNEL, "Error occurred")
|
||||
"""Test sending error message to error channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_error_message("Something broke")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("errors", "Something broke")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "activity")
|
||||
def test_send_activity_message(mock_broadcast):
|
||||
discord.send_activity_message("Activity update")
|
||||
mock_broadcast.assert_called_once_with(discord.ACTIVITY_CHANNEL, "Activity update")
|
||||
"""Test sending activity message to activity channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_activity_message("User logged in")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("activity", "User logged in")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "discoveries")
|
||||
def test_send_discovery_message(mock_broadcast):
|
||||
discord.send_discovery_message("Discovery made")
|
||||
mock_broadcast.assert_called_once_with(discord.DISCOVERY_CHANNEL, "Discovery made")
|
||||
"""Test sending discovery message to discovery channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_discovery_message("Found interesting pattern")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "chat")
|
||||
def test_send_chat_message(mock_broadcast):
|
||||
discord.send_chat_message("Chat message")
|
||||
mock_broadcast.assert_called_once_with(discord.CHAT_CHANNEL, "Chat message")
|
||||
"""Test sending chat message to chat channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_chat_message("Hello from bot")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_basic(mock_send_error):
|
||||
"""Test basic task failure notification"""
|
||||
discord.notify_task_failure("test_task", "Something went wrong")
|
||||
|
||||
mock_send_error.assert_called_once()
|
||||
@ -323,69 +255,181 @@ def test_notify_task_failure_basic(mock_send_error):
|
||||
assert "**Error:** Something went wrong" in message
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_with_args(mock_send_error):
|
||||
"""Test task failure notification with arguments"""
|
||||
discord.notify_task_failure(
|
||||
"test_task",
|
||||
"Error message",
|
||||
task_args=("arg1", "arg2"),
|
||||
task_kwargs={"key": "value"},
|
||||
"Error occurred",
|
||||
task_args=("arg1", 42),
|
||||
task_kwargs={"key": "value", "number": 123},
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
assert "**Args:** `('arg1', 'arg2')`" in message
|
||||
assert "**Kwargs:** `{'key': 'value'}`" in message
|
||||
assert "**Args:** `('arg1', 42)" in message
|
||||
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_with_traceback(mock_send_error):
|
||||
traceback = "Traceback (most recent call last):\n File ...\nError: Something"
|
||||
"""Test task failure notification with traceback"""
|
||||
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
|
||||
|
||||
discord.notify_task_failure("test_task", "Error message", traceback_str=traceback)
|
||||
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
assert "**Traceback:**" in message
|
||||
assert traceback in message
|
||||
assert "Exception: test" in message
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_truncates_long_error(mock_send_error):
|
||||
long_error = "x" * 600 # Longer than 500 char limit
|
||||
"""Test that long error messages are truncated"""
|
||||
long_error = "x" * 600
|
||||
|
||||
discord.notify_task_failure("test_task", long_error)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
assert long_error[:500] in message
|
||||
|
||||
# Error should be truncated to 500 chars - check that the full 600 char string is not there
|
||||
assert "**Error:** " + long_error[:500] in message
|
||||
# The full 600-char error should not be present
|
||||
error_section = message.split("**Error:** ")[1].split("\n")[0]
|
||||
assert len(error_section) == 500
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_truncates_long_traceback(mock_send_error):
|
||||
long_traceback = "x" * 1000 # Longer than 800 char limit
|
||||
"""Test that long tracebacks are truncated"""
|
||||
long_traceback = "x" * 1000
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
# Traceback should show last 800 chars
|
||||
assert long_traceback[-800:] in message
|
||||
# The full 1000-char traceback should not be present
|
||||
traceback_section = message.split("**Traceback:**\n```\n")[1].split("\n```")[0]
|
||||
assert len(traceback_section) == 800
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_truncates_long_args(mock_send_error):
|
||||
"""Test that long task arguments are truncated"""
|
||||
long_args = ("x" * 300,)
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", task_args=long_args)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
# Args should be truncated to 200 chars
|
||||
assert (
|
||||
len(message.split("**Args:**")[1].split("\n")[0]) <= 210
|
||||
) # Some buffer for formatting
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
||||
"""Test that long task kwargs are truncated"""
|
||||
long_kwargs = {"key": "x" * 300}
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
# Kwargs should be truncated to 200 chars
|
||||
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
def test_notify_task_failure_disabled(mock_send_error):
|
||||
discord.notify_task_failure("test_task", "Error message")
|
||||
"""Test that notifications are not sent when disabled"""
|
||||
discord.notify_task_failure("test_task", "Error occurred")
|
||||
|
||||
mock_send_error.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
def test_notify_task_failure_send_fails(mock_send_error):
|
||||
mock_send_error.side_effect = Exception("Discord API error")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_send_error_exception(mock_send_error):
|
||||
"""Test that exceptions in send_error_message don't propagate"""
|
||||
mock_send_error.side_effect = Exception("Failed to send")
|
||||
|
||||
# Should not raise, just log the error
|
||||
discord.notify_task_failure("test_task", "Error message")
|
||||
# Should not raise
|
||||
discord.notify_task_failure("test_task", "Error occurred")
|
||||
|
||||
mock_send_error.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,channel_setting,message",
|
||||
[
|
||||
(discord.send_error_message, "DISCORD_ERROR_CHANNEL", "Error!"),
|
||||
(discord.send_activity_message, "DISCORD_ACTIVITY_CHANNEL", "Activity!"),
|
||||
(discord.send_discovery_message, "DISCORD_DISCOVERY_CHANNEL", "Discovery!"),
|
||||
(discord.send_chat_message, "DISCORD_CHAT_CHANNEL", "Chat!"),
|
||||
],
|
||||
)
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
def test_convenience_functions_use_correct_channels(
|
||||
mock_broadcast, function, channel_setting, message
|
||||
):
|
||||
"""Test that convenience functions use the correct channel settings"""
|
||||
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
|
||||
function(message)
|
||||
mock_broadcast.assert_called_once_with("test-channel", message)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_send_dm_with_special_characters(mock_post, mock_api_url):
|
||||
"""Test sending DM with special characters"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": True}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
message_with_special_chars = "Hello! 🎉 <@123> #general"
|
||||
result = discord.send_dm("user123", message_with_special_chars)
|
||||
|
||||
assert result is True
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["message"] == message_with_special_chars
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_broadcast_message_with_long_message(mock_post, mock_api_url):
|
||||
"""Test broadcasting a long message"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": True}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
long_message = "A" * 2000
|
||||
result = discord.broadcast_message("general", long_message)
|
||||
|
||||
assert result is True
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["message"] == long_message
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
|
||||
"""Test health check when response doesn't have status key"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
|
||||
assert result is False
|
||||
|
435
tests/memory/common/test_discord_integration.py
Normal file
435
tests/memory/common/test_discord_integration.py
Normal file
@ -0,0 +1,435 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
import requests
|
||||
|
||||
from memory.common import discord
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_url():
|
||||
"""Mock the API URL to avoid using actual settings"""
|
||||
with patch(
|
||||
"memory.common.discord.get_api_url", return_value="http://localhost:8000"
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_COLLECTOR_SERVER_URL", "testhost")
|
||||
@patch("memory.common.settings.DISCORD_COLLECTOR_PORT", 9999)
|
||||
def test_get_api_url():
|
||||
"""Test API URL construction"""
|
||||
assert discord.get_api_url() == "http://testhost:9999"
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_send_dm_success(mock_post, mock_api_url):
|
||||
"""Test successful DM sending"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": True}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_dm",
|
||||
json={"user_identifier": "user123", "message": "Hello!"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_send_dm_api_failure(mock_post, mock_api_url):
|
||||
"""Test DM sending when API returns failure"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": False}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_send_dm_request_exception(mock_post, mock_api_url):
|
||||
"""Test DM sending when request raises exception"""
|
||||
mock_post.side_effect = requests.RequestException("Network error")
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_send_dm_http_error(mock_post, mock_api_url):
|
||||
"""Test DM sending when HTTP error occurs"""
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_broadcast_message_success(mock_post, mock_api_url):
|
||||
"""Test successful channel message broadcast"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": True}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_channel",
|
||||
json={"channel_name": "general", "message": "Announcement!"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_broadcast_message_failure(mock_post, mock_api_url):
|
||||
"""Test channel message broadcast failure"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": False}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_broadcast_message_exception(mock_post, mock_api_url):
|
||||
"""Test channel message broadcast with exception"""
|
||||
mock_post.side_effect = requests.Timeout("Request timeout")
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_is_collector_healthy_true(mock_get, mock_api_url):
|
||||
"""Test health check when collector is healthy"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"status": "healthy"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
|
||||
assert result is True
|
||||
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
|
||||
"""Test health check when collector returns unhealthy status"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"status": "unhealthy"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_is_collector_healthy_exception(mock_get, mock_api_url):
|
||||
"""Test health check when request fails"""
|
||||
mock_get.side_effect = requests.ConnectionError("Connection refused")
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_refresh_discord_metadata_success(mock_post, mock_api_url):
|
||||
"""Test successful metadata refresh"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"servers": 5,
|
||||
"channels": 20,
|
||||
"users": 100,
|
||||
}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.refresh_discord_metadata()
|
||||
|
||||
assert result == {"servers": 5, "channels": 20, "users": 100}
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/refresh_metadata", timeout=30
|
||||
)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_refresh_discord_metadata_failure(mock_post, mock_api_url):
|
||||
"""Test metadata refresh failure"""
|
||||
mock_post.side_effect = requests.RequestException("Failed to connect")
|
||||
|
||||
result = discord.refresh_discord_metadata()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_refresh_discord_metadata_http_error(mock_post, mock_api_url):
|
||||
"""Test metadata refresh with HTTP error"""
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.refresh_discord_metadata()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "errors")
|
||||
def test_send_error_message(mock_broadcast):
|
||||
"""Test sending error message to error channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_error_message("Something broke")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("errors", "Something broke")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "activity")
|
||||
def test_send_activity_message(mock_broadcast):
|
||||
"""Test sending activity message to activity channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_activity_message("User logged in")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("activity", "User logged in")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "discoveries")
|
||||
def test_send_discovery_message(mock_broadcast):
|
||||
"""Test sending discovery message to discovery channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_discovery_message("Found interesting pattern")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "chat")
|
||||
def test_send_chat_message(mock_broadcast):
|
||||
"""Test sending chat message to chat channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_chat_message("Hello from bot")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_basic(mock_send_error):
|
||||
"""Test basic task failure notification"""
|
||||
discord.notify_task_failure("test_task", "Something went wrong")
|
||||
|
||||
mock_send_error.assert_called_once()
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
assert "🚨 **Task Failed: test_task**" in message
|
||||
assert "**Error:** Something went wrong" in message
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_with_args(mock_send_error):
|
||||
"""Test task failure notification with arguments"""
|
||||
discord.notify_task_failure(
|
||||
"test_task",
|
||||
"Error occurred",
|
||||
task_args=("arg1", 42),
|
||||
task_kwargs={"key": "value", "number": 123},
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
assert "**Args:** `('arg1', 42)" in message
|
||||
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_with_traceback(mock_send_error):
|
||||
"""Test task failure notification with traceback"""
|
||||
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
|
||||
|
||||
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
assert "**Traceback:**" in message
|
||||
assert "Exception: test" in message
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_truncates_long_error(mock_send_error):
|
||||
"""Test that long error messages are truncated"""
|
||||
long_error = "x" * 600
|
||||
|
||||
discord.notify_task_failure("test_task", long_error)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
# Error should be truncated to 500 chars - check that the full 600 char string is not there
|
||||
assert "**Error:** " + long_error[:500] in message
|
||||
# The full 600-char error should not be present
|
||||
error_section = message.split("**Error:** ")[1].split("\n")[0]
|
||||
assert len(error_section) == 500
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_truncates_long_traceback(mock_send_error):
|
||||
"""Test that long tracebacks are truncated"""
|
||||
long_traceback = "x" * 1000
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
# Traceback should show last 800 chars
|
||||
assert long_traceback[-800:] in message
|
||||
# The full 1000-char traceback should not be present
|
||||
traceback_section = message.split("**Traceback:**\n```\n")[1].split("\n```")[0]
|
||||
assert len(traceback_section) == 800
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_truncates_long_args(mock_send_error):
|
||||
"""Test that long task arguments are truncated"""
|
||||
long_args = ("x" * 300,)
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", task_args=long_args)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
# Args should be truncated to 200 chars
|
||||
assert (
|
||||
len(message.split("**Args:**")[1].split("\n")[0]) <= 210
|
||||
) # Some buffer for formatting
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
||||
"""Test that long task kwargs are truncated"""
|
||||
long_kwargs = {"key": "x" * 300}
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
# Kwargs should be truncated to 200 chars
|
||||
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
def test_notify_task_failure_disabled(mock_send_error):
|
||||
"""Test that notifications are not sent when disabled"""
|
||||
discord.notify_task_failure("test_task", "Error occurred")
|
||||
|
||||
mock_send_error.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_send_error_exception(mock_send_error):
|
||||
"""Test that exceptions in send_error_message don't propagate"""
|
||||
mock_send_error.side_effect = Exception("Failed to send")
|
||||
|
||||
# Should not raise
|
||||
discord.notify_task_failure("test_task", "Error occurred")
|
||||
|
||||
mock_send_error.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,channel_setting,message",
|
||||
[
|
||||
(discord.send_error_message, "DISCORD_ERROR_CHANNEL", "Error!"),
|
||||
(discord.send_activity_message, "DISCORD_ACTIVITY_CHANNEL", "Activity!"),
|
||||
(discord.send_discovery_message, "DISCORD_DISCOVERY_CHANNEL", "Discovery!"),
|
||||
(discord.send_chat_message, "DISCORD_CHAT_CHANNEL", "Chat!"),
|
||||
],
|
||||
)
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
def test_convenience_functions_use_correct_channels(
|
||||
mock_broadcast, function, channel_setting, message
|
||||
):
|
||||
"""Test that convenience functions use the correct channel settings"""
|
||||
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
|
||||
function(message)
|
||||
mock_broadcast.assert_called_once_with("test-channel", message)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_send_dm_with_special_characters(mock_post, mock_api_url):
|
||||
"""Test sending DM with special characters"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": True}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
message_with_special_chars = "Hello! 🎉 <@123> #general"
|
||||
result = discord.send_dm("user123", message_with_special_chars)
|
||||
|
||||
assert result is True
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["message"] == message_with_special_chars
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_broadcast_message_with_long_message(mock_post, mock_api_url):
|
||||
"""Test broadcasting a long message"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": True}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
long_message = "A" * 2000
|
||||
result = discord.broadcast_message("general", long_message)
|
||||
|
||||
assert result is True
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["message"] == long_message
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
|
||||
"""Test health check when response doesn't have status key"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
|
||||
assert result is False
|
840
tests/memory/discord/test_collector.py
Normal file
840
tests/memory/discord/test_collector.py
Normal file
@ -0,0 +1,840 @@
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
|
||||
import discord
|
||||
|
||||
from memory.discord.collector import (
|
||||
create_or_update_server,
|
||||
determine_channel_metadata,
|
||||
create_or_update_channel,
|
||||
create_or_update_user,
|
||||
determine_message_metadata,
|
||||
should_track_message,
|
||||
should_collect_bot_message,
|
||||
sync_guild_metadata,
|
||||
MessageCollector,
|
||||
)
|
||||
from memory.common.db.models.sources import (
|
||||
DiscordServer,
|
||||
DiscordChannel,
|
||||
DiscordUser,
|
||||
)
|
||||
|
||||
|
||||
# Fixtures for Discord objects
|
||||
@pytest.fixture
|
||||
def mock_guild():
|
||||
"""Mock Discord Guild object"""
|
||||
guild = Mock(spec=discord.Guild)
|
||||
guild.id = 123456789
|
||||
guild.name = "Test Server"
|
||||
guild.description = "A test server"
|
||||
guild.member_count = 42
|
||||
return guild
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_text_channel():
|
||||
"""Mock Discord TextChannel object"""
|
||||
channel = Mock(spec=discord.TextChannel)
|
||||
channel.id = 987654321
|
||||
channel.name = "general"
|
||||
guild = Mock()
|
||||
guild.id = 123456789
|
||||
channel.guild = guild
|
||||
return channel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dm_channel():
|
||||
"""Mock Discord DMChannel object"""
|
||||
channel = Mock(spec=discord.DMChannel)
|
||||
channel.id = 111222333
|
||||
recipient = Mock()
|
||||
recipient.name = "TestUser"
|
||||
channel.recipient = recipient
|
||||
return channel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Mock Discord User object"""
|
||||
user = Mock(spec=discord.User)
|
||||
user.id = 444555666
|
||||
user.name = "testuser"
|
||||
user.display_name = "Test User"
|
||||
user.bot = False
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(mock_text_channel, mock_user):
|
||||
"""Mock Discord Message object"""
|
||||
message = Mock(spec=discord.Message)
|
||||
message.id = 777888999
|
||||
message.channel = mock_text_channel
|
||||
message.author = mock_user
|
||||
message.guild = mock_text_channel.guild
|
||||
message.content = "Test message"
|
||||
message.created_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
message.reference = None
|
||||
return message
|
||||
|
||||
|
||||
# Tests for create_or_update_server
|
||||
def test_create_or_update_server_creates_new(db_session, mock_guild):
|
||||
"""Test creating a new server record"""
|
||||
result = create_or_update_server(db_session, mock_guild)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == mock_guild.id
|
||||
assert result.name == mock_guild.name
|
||||
assert result.description == mock_guild.description
|
||||
assert result.member_count == mock_guild.member_count
|
||||
|
||||
|
||||
def test_create_or_update_server_updates_existing(db_session, mock_guild):
|
||||
"""Test updating an existing server record"""
|
||||
# Create initial server
|
||||
server = DiscordServer(
|
||||
id=mock_guild.id,
|
||||
name="Old Name",
|
||||
description="Old Description",
|
||||
member_count=10,
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
# Update with new data
|
||||
mock_guild.name = "New Name"
|
||||
mock_guild.description = "New Description"
|
||||
mock_guild.member_count = 50
|
||||
|
||||
result = create_or_update_server(db_session, mock_guild)
|
||||
|
||||
assert result.name == "New Name"
|
||||
assert result.description == "New Description"
|
||||
assert result.member_count == 50
|
||||
assert result.last_sync_at is not None
|
||||
|
||||
|
||||
def test_create_or_update_server_none_guild(db_session):
|
||||
"""Test with None guild"""
|
||||
result = create_or_update_server(db_session, None)
|
||||
assert result is None
|
||||
|
||||
|
||||
# Tests for determine_channel_metadata
|
||||
def test_determine_channel_metadata_dm():
|
||||
"""Test metadata for DM channel"""
|
||||
channel = Mock(spec=discord.DMChannel)
|
||||
channel.recipient = Mock()
|
||||
channel.recipient.name = "TestUser"
|
||||
|
||||
channel_type, server_id, name = determine_channel_metadata(channel)
|
||||
|
||||
assert channel_type == "dm"
|
||||
assert server_id is None
|
||||
assert "DM with TestUser" in name
|
||||
|
||||
|
||||
def test_determine_channel_metadata_dm_no_recipient():
|
||||
"""Test metadata for DM channel without recipient"""
|
||||
channel = Mock(spec=discord.DMChannel)
|
||||
channel.recipient = None
|
||||
|
||||
channel_type, server_id, name = determine_channel_metadata(channel)
|
||||
|
||||
assert channel_type == "dm"
|
||||
assert name == "Unknown DM"
|
||||
|
||||
|
||||
def test_determine_channel_metadata_group_dm():
|
||||
"""Test metadata for group DM channel"""
|
||||
channel = Mock(spec=discord.GroupChannel)
|
||||
channel.name = "Group Chat"
|
||||
|
||||
channel_type, server_id, name = determine_channel_metadata(channel)
|
||||
|
||||
assert channel_type == "group_dm"
|
||||
assert server_id is None
|
||||
assert name == "Group Chat"
|
||||
|
||||
|
||||
def test_determine_channel_metadata_group_dm_no_name():
|
||||
"""Test metadata for group DM without name"""
|
||||
channel = Mock(spec=discord.GroupChannel)
|
||||
channel.name = None
|
||||
|
||||
channel_type, server_id, name = determine_channel_metadata(channel)
|
||||
|
||||
assert name == "Group DM"
|
||||
|
||||
|
||||
def test_determine_channel_metadata_text_channel():
|
||||
"""Test metadata for text channel"""
|
||||
channel = Mock(spec=discord.TextChannel)
|
||||
channel.name = "general"
|
||||
channel.guild = Mock()
|
||||
channel.guild.id = 123
|
||||
|
||||
channel_type, server_id, name = determine_channel_metadata(channel)
|
||||
|
||||
assert channel_type == "text"
|
||||
assert server_id == 123
|
||||
assert name == "general"
|
||||
|
||||
|
||||
def test_determine_channel_metadata_voice_channel():
|
||||
"""Test metadata for voice channel"""
|
||||
channel = Mock(spec=discord.VoiceChannel)
|
||||
channel.name = "voice-chat"
|
||||
channel.guild = Mock()
|
||||
channel.guild.id = 456
|
||||
|
||||
channel_type, server_id, name = determine_channel_metadata(channel)
|
||||
|
||||
assert channel_type == "voice"
|
||||
assert server_id == 456
|
||||
assert name == "voice-chat"
|
||||
|
||||
|
||||
def test_determine_channel_metadata_thread():
|
||||
"""Test metadata for thread"""
|
||||
channel = Mock(spec=discord.Thread)
|
||||
channel.name = "thread-1"
|
||||
channel.guild = Mock()
|
||||
channel.guild.id = 789
|
||||
|
||||
channel_type, server_id, name = determine_channel_metadata(channel)
|
||||
|
||||
assert channel_type == "thread"
|
||||
assert server_id == 789
|
||||
assert name == "thread-1"
|
||||
|
||||
|
||||
def test_determine_channel_metadata_unknown():
|
||||
"""Test metadata for unknown channel type"""
|
||||
channel = Mock()
|
||||
channel.id = 999
|
||||
# Ensure the mock doesn't have a 'name' attribute
|
||||
del channel.name
|
||||
|
||||
channel_type, server_id, name = determine_channel_metadata(channel)
|
||||
|
||||
assert channel_type == "unknown"
|
||||
assert name == "Unknown-999"
|
||||
|
||||
|
||||
# Tests for create_or_update_channel
|
||||
def test_create_or_update_channel_creates_new(
|
||||
db_session, mock_text_channel, mock_guild
|
||||
):
|
||||
"""Test creating a new channel record"""
|
||||
# Create the server first to satisfy foreign key constraint
|
||||
create_or_update_server(db_session, mock_guild)
|
||||
|
||||
result = create_or_update_channel(db_session, mock_text_channel)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == mock_text_channel.id
|
||||
assert result.name == mock_text_channel.name
|
||||
assert result.channel_type == "text"
|
||||
|
||||
|
||||
def test_create_or_update_channel_updates_existing(db_session, mock_text_channel):
|
||||
"""Test updating an existing channel record"""
|
||||
# Create initial channel
|
||||
channel = DiscordChannel(
|
||||
id=mock_text_channel.id,
|
||||
name="old-name",
|
||||
channel_type="text",
|
||||
)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
|
||||
# Update with new name
|
||||
mock_text_channel.name = "new-name"
|
||||
|
||||
result = create_or_update_channel(db_session, mock_text_channel)
|
||||
|
||||
assert result.name == "new-name"
|
||||
|
||||
|
||||
def test_create_or_update_channel_none_channel(db_session):
|
||||
"""Test with None channel"""
|
||||
result = create_or_update_channel(db_session, None)
|
||||
assert result is None
|
||||
|
||||
|
||||
# Tests for create_or_update_user
|
||||
def test_create_or_update_user_creates_new(db_session, mock_user):
|
||||
"""Test creating a new user record"""
|
||||
result = create_or_update_user(db_session, mock_user)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == mock_user.id
|
||||
assert result.username == mock_user.name
|
||||
assert result.display_name == mock_user.display_name
|
||||
|
||||
|
||||
def test_create_or_update_user_updates_existing(db_session, mock_user):
|
||||
"""Test updating an existing user record"""
|
||||
# Create initial user
|
||||
user = DiscordUser(
|
||||
id=mock_user.id,
|
||||
username="oldname",
|
||||
display_name="Old Display Name",
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
# Update with new data
|
||||
mock_user.name = "newname"
|
||||
mock_user.display_name = "New Display Name"
|
||||
|
||||
result = create_or_update_user(db_session, mock_user)
|
||||
|
||||
assert result.username == "newname"
|
||||
assert result.display_name == "New Display Name"
|
||||
|
||||
|
||||
def test_create_or_update_user_none_user(db_session):
|
||||
"""Test with None user"""
|
||||
result = create_or_update_user(db_session, None)
|
||||
assert result is None
|
||||
|
||||
|
||||
# Tests for determine_message_metadata
|
||||
def test_determine_message_metadata_default():
|
||||
"""Test metadata for default message"""
|
||||
message = Mock()
|
||||
message.reference = None
|
||||
message.channel = Mock()
|
||||
# Ensure channel doesn't have parent attribute
|
||||
del message.channel.parent
|
||||
|
||||
message_type, reply_to_id, thread_id = determine_message_metadata(message)
|
||||
|
||||
assert message_type == "default"
|
||||
assert reply_to_id is None
|
||||
assert thread_id is None
|
||||
|
||||
|
||||
def test_determine_message_metadata_reply():
|
||||
"""Test metadata for reply message"""
|
||||
message = Mock()
|
||||
message.reference = Mock()
|
||||
message.reference.message_id = 123456
|
||||
message.channel = Mock()
|
||||
|
||||
message_type, reply_to_id, thread_id = determine_message_metadata(message)
|
||||
|
||||
assert message_type == "reply"
|
||||
assert reply_to_id == 123456
|
||||
|
||||
|
||||
def test_determine_message_metadata_thread():
|
||||
"""Test metadata for message in thread"""
|
||||
message = Mock()
|
||||
message.reference = None
|
||||
message.channel = Mock()
|
||||
message.channel.id = 999
|
||||
message.channel.parent = Mock() # Has parent means it's a thread
|
||||
|
||||
message_type, reply_to_id, thread_id = determine_message_metadata(message)
|
||||
|
||||
assert thread_id == 999
|
||||
|
||||
|
||||
# Tests for should_track_message
|
||||
def test_should_track_message_server_disabled(db_session):
|
||||
"""Test when server has tracking disabled"""
|
||||
server = DiscordServer(id=1, name="Server", track_messages=False)
|
||||
channel = DiscordChannel(id=2, name="Channel", channel_type="text")
|
||||
user = DiscordUser(id=3, username="User")
|
||||
|
||||
result = should_track_message(server, channel, user)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_should_track_message_channel_disabled(db_session):
|
||||
"""Test when channel has tracking disabled"""
|
||||
server = DiscordServer(id=1, name="Server", track_messages=True)
|
||||
channel = DiscordChannel(
|
||||
id=2, name="Channel", channel_type="text", track_messages=False
|
||||
)
|
||||
user = DiscordUser(id=3, username="User")
|
||||
|
||||
result = should_track_message(server, channel, user)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_should_track_message_dm_allowed(db_session):
|
||||
"""Test DM tracking when user allows it"""
|
||||
channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True)
|
||||
user = DiscordUser(id=3, username="User", track_messages=True)
|
||||
|
||||
result = should_track_message(None, channel, user)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_should_track_message_dm_not_allowed(db_session):
|
||||
"""Test DM tracking when user doesn't allow it"""
|
||||
channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True)
|
||||
user = DiscordUser(id=3, username="User", track_messages=False)
|
||||
|
||||
result = should_track_message(None, channel, user)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_should_track_message_default_true(db_session):
|
||||
"""Test default tracking behavior"""
|
||||
server = DiscordServer(id=1, name="Server", track_messages=True)
|
||||
channel = DiscordChannel(
|
||||
id=2, name="Channel", channel_type="text", track_messages=True
|
||||
)
|
||||
user = DiscordUser(id=3, username="User")
|
||||
|
||||
result = should_track_message(server, channel, user)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
# Tests for should_collect_bot_message
|
||||
@patch("memory.common.settings.DISCORD_COLLECT_BOTS", False)
|
||||
def test_should_collect_bot_message_bot_not_allowed():
|
||||
"""Test bot message collection when disabled"""
|
||||
message = Mock()
|
||||
message.author = Mock()
|
||||
message.author.bot = True
|
||||
|
||||
result = should_collect_bot_message(message)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_COLLECT_BOTS", True)
|
||||
def test_should_collect_bot_message_bot_allowed():
|
||||
"""Test bot message collection when enabled"""
|
||||
message = Mock()
|
||||
message.author = Mock()
|
||||
message.author.bot = True
|
||||
|
||||
result = should_collect_bot_message(message)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_should_collect_bot_message_human():
|
||||
"""Test human message collection"""
|
||||
message = Mock()
|
||||
message.author = Mock()
|
||||
message.author.bot = False
|
||||
|
||||
result = should_collect_bot_message(message)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
# Tests for sync_guild_metadata
|
||||
@patch("memory.discord.collector.make_session")
|
||||
def test_sync_guild_metadata(mock_make_session, mock_guild):
|
||||
"""Test syncing guild metadata"""
|
||||
mock_session = Mock()
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=mock_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
# Mock session.query().get() to return None (new server)
|
||||
mock_session.query.return_value.get.return_value = None
|
||||
|
||||
# Mock channels
|
||||
text_channel = Mock(spec=discord.TextChannel)
|
||||
text_channel.id = 1
|
||||
text_channel.name = "general"
|
||||
text_channel.guild = mock_guild
|
||||
|
||||
voice_channel = Mock(spec=discord.VoiceChannel)
|
||||
voice_channel.id = 2
|
||||
voice_channel.name = "voice"
|
||||
voice_channel.guild = mock_guild
|
||||
|
||||
mock_guild.channels = [text_channel, voice_channel]
|
||||
|
||||
sync_guild_metadata(mock_guild)
|
||||
|
||||
# Verify session.commit was called
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
# Tests for MessageCollector class
|
||||
def test_message_collector_init():
|
||||
"""Test MessageCollector initialization"""
|
||||
collector = MessageCollector()
|
||||
|
||||
assert collector.command_prefix == "!memory_"
|
||||
assert collector.help_command is None
|
||||
assert collector.intents.message_content is True
|
||||
assert collector.intents.guilds is True
|
||||
assert collector.intents.members is True
|
||||
assert collector.intents.dm_messages is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_ready():
|
||||
"""Test on_ready event handler"""
|
||||
collector = MessageCollector()
|
||||
collector.user = Mock()
|
||||
collector.user.name = "TestBot"
|
||||
collector.guilds = [Mock(), Mock()]
|
||||
collector.sync_servers_and_channels = AsyncMock()
|
||||
|
||||
await collector.on_ready()
|
||||
|
||||
collector.sync_servers_and_channels.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.discord.collector.make_session")
|
||||
@patch("memory.discord.collector.add_discord_message")
|
||||
async def test_on_message_success(mock_add_task, mock_make_session, mock_message):
|
||||
"""Test successful message handling"""
|
||||
mock_session = Mock()
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=mock_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=None)
|
||||
mock_session.query.return_value.get.return_value = None # New entities
|
||||
|
||||
collector = MessageCollector()
|
||||
await collector.on_message(mock_message)
|
||||
|
||||
# Verify task was queued
|
||||
mock_add_task.delay.assert_called_once()
|
||||
call_kwargs = mock_add_task.delay.call_args[1]
|
||||
assert call_kwargs["message_id"] == mock_message.id
|
||||
assert call_kwargs["channel_id"] == mock_message.channel.id
|
||||
assert call_kwargs["author_id"] == mock_message.author.id
|
||||
assert call_kwargs["content"] == mock_message.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.discord.collector.make_session")
|
||||
async def test_on_message_bot_message_filtered(mock_make_session, mock_message):
|
||||
"""Test bot message filtering"""
|
||||
mock_message.author.bot = True
|
||||
|
||||
with patch(
|
||||
"memory.discord.collector.should_collect_bot_message", return_value=False
|
||||
):
|
||||
collector = MessageCollector()
|
||||
await collector.on_message(mock_message)
|
||||
|
||||
# Should not create session or queue task
|
||||
mock_make_session.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.discord.collector.make_session")
|
||||
async def test_on_message_error_handling(mock_make_session, mock_message):
|
||||
"""Test error handling in on_message"""
|
||||
mock_make_session.side_effect = Exception("Database error")
|
||||
|
||||
collector = MessageCollector()
|
||||
# Should not raise
|
||||
await collector.on_message(mock_message)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.discord.collector.edit_discord_message")
|
||||
async def test_on_message_edit(mock_edit_task):
|
||||
"""Test message edit handler"""
|
||||
before = Mock()
|
||||
after = Mock()
|
||||
after.id = 123
|
||||
after.content = "Edited content"
|
||||
after.edited_at = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
collector = MessageCollector()
|
||||
await collector.on_message_edit(before, after)
|
||||
|
||||
mock_edit_task.delay.assert_called_once()
|
||||
call_kwargs = mock_edit_task.delay.call_args[1]
|
||||
assert call_kwargs["message_id"] == 123
|
||||
assert call_kwargs["content"] == "Edited content"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_edit_error_handling():
|
||||
"""Test error handling in on_message_edit"""
|
||||
before = Mock()
|
||||
after = Mock()
|
||||
after.id = 123
|
||||
after.content = "Edited"
|
||||
after.edited_at = None # Will trigger datetime.now
|
||||
|
||||
with patch("memory.discord.collector.edit_discord_message") as mock_edit:
|
||||
mock_edit.delay.side_effect = Exception("Task error")
|
||||
|
||||
collector = MessageCollector()
|
||||
# Should not raise
|
||||
await collector.on_message_edit(before, after)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_servers_and_channels():
|
||||
"""Test syncing servers and channels"""
|
||||
guild1 = Mock()
|
||||
guild2 = Mock()
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild1, guild2]
|
||||
|
||||
with patch("memory.discord.collector.sync_guild_metadata") as mock_sync:
|
||||
await collector.sync_servers_and_channels()
|
||||
|
||||
assert mock_sync.call_count == 2
|
||||
mock_sync.assert_any_call(guild1)
|
||||
mock_sync.assert_any_call(guild2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.discord.collector.make_session")
|
||||
async def test_refresh_metadata(mock_make_session):
|
||||
"""Test metadata refresh"""
|
||||
mock_session = Mock()
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=mock_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=None)
|
||||
mock_session.query.return_value.get.return_value = None
|
||||
|
||||
guild = Mock()
|
||||
guild.id = 123
|
||||
guild.name = "Test"
|
||||
guild.channels = []
|
||||
guild.members = []
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild]
|
||||
collector.intents = Mock()
|
||||
collector.intents.members = False
|
||||
|
||||
result = await collector.refresh_metadata()
|
||||
|
||||
assert result["servers_updated"] == 1
|
||||
assert result["channels_updated"] == 0
|
||||
assert result["users_updated"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_id():
|
||||
"""Test getting user by ID"""
|
||||
user = Mock()
|
||||
user.id = 123
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.get_user = Mock(return_value=user)
|
||||
|
||||
result = await collector.get_user(123)
|
||||
|
||||
assert result == user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_username():
|
||||
"""Test getting user by username"""
|
||||
member = Mock()
|
||||
member.name = "testuser"
|
||||
member.display_name = "Test User"
|
||||
member.discriminator = "1234"
|
||||
|
||||
guild = Mock()
|
||||
guild.members = [member]
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild]
|
||||
|
||||
result = await collector.get_user("testuser")
|
||||
|
||||
assert result == member
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_not_found():
|
||||
"""Test getting non-existent user"""
|
||||
collector = MessageCollector()
|
||||
collector.guilds = []
|
||||
|
||||
with patch.object(collector, "get_user", return_value=None):
|
||||
with patch.object(
|
||||
collector, "fetch_user", side_effect=discord.NotFound(Mock(), Mock())
|
||||
):
|
||||
result = await collector.get_user(999)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_channel_by_name():
|
||||
"""Test getting channel by name"""
|
||||
channel = Mock(spec=discord.TextChannel)
|
||||
channel.name = "general"
|
||||
|
||||
guild = Mock()
|
||||
guild.channels = [channel]
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild]
|
||||
|
||||
result = await collector.get_channel_by_name("general")
|
||||
|
||||
assert result == channel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_channel_by_name_not_found():
|
||||
"""Test getting non-existent channel"""
|
||||
guild = Mock()
|
||||
guild.channels = []
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild]
|
||||
|
||||
result = await collector.get_channel_by_name("nonexistent")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_channel():
|
||||
"""Test creating a channel"""
|
||||
guild = Mock()
|
||||
guild.name = "Test Server"
|
||||
new_channel = Mock()
|
||||
guild.create_text_channel = AsyncMock(return_value=new_channel)
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.get_guild = Mock(return_value=guild)
|
||||
|
||||
result = await collector.create_channel("new-channel", guild_id=123)
|
||||
|
||||
assert result == new_channel
|
||||
guild.create_text_channel.assert_called_once_with("new-channel")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_channel_no_guild():
|
||||
"""Test creating channel when no guild available"""
|
||||
collector = MessageCollector()
|
||||
collector.get_guild = Mock(return_value=None)
|
||||
collector.guilds = []
|
||||
|
||||
result = await collector.create_channel("new-channel")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_dm_success():
|
||||
"""Test sending DM successfully"""
|
||||
user = Mock()
|
||||
user.send = AsyncMock()
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.get_user = AsyncMock(return_value=user)
|
||||
|
||||
result = await collector.send_dm(123, "Hello!")
|
||||
|
||||
assert result is True
|
||||
user.send.assert_called_once_with("Hello!")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_dm_user_not_found():
|
||||
"""Test sending DM when user not found"""
|
||||
collector = MessageCollector()
|
||||
collector.get_user = AsyncMock(return_value=None)
|
||||
|
||||
result = await collector.send_dm(123, "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_dm_exception():
|
||||
"""Test sending DM with exception"""
|
||||
user = Mock()
|
||||
user.send = AsyncMock(side_effect=Exception("Send failed"))
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.get_user = AsyncMock(return_value=user)
|
||||
|
||||
result = await collector.send_dm(123, "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
async def test_send_to_channel_success():
|
||||
"""Test sending to channel successfully"""
|
||||
channel = Mock()
|
||||
channel.send = AsyncMock()
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.get_channel_by_name = AsyncMock(return_value=channel)
|
||||
|
||||
result = await collector.send_to_channel("general", "Announcement!")
|
||||
|
||||
assert result is True
|
||||
channel.send.assert_called_once_with("Announcement!")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
async def test_send_to_channel_notifications_disabled():
|
||||
"""Test sending to channel when notifications disabled"""
|
||||
collector = MessageCollector()
|
||||
|
||||
result = await collector.send_to_channel("general", "Announcement!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
async def test_send_to_channel_not_found():
|
||||
"""Test sending to non-existent channel"""
|
||||
collector = MessageCollector()
|
||||
collector.get_channel_by_name = AsyncMock(return_value=None)
|
||||
|
||||
result = await collector.send_to_channel("nonexistent", "Message")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
|
||||
async def test_run_collector():
|
||||
"""Test running the collector"""
|
||||
from memory.discord.collector import run_collector
|
||||
|
||||
with patch("memory.discord.collector.MessageCollector") as mock_collector_class:
|
||||
mock_collector = Mock()
|
||||
mock_collector.start = AsyncMock()
|
||||
mock_collector_class.return_value = mock_collector
|
||||
|
||||
await run_collector()
|
||||
|
||||
mock_collector.start.assert_called_once_with("test_token")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
|
||||
async def test_run_collector_no_token():
|
||||
"""Test running collector without token"""
|
||||
from memory.discord.collector import run_collector
|
||||
|
||||
# Should return early without raising
|
||||
await run_collector()
|
607
tests/memory/workers/tasks/test_discord_tasks.py
Normal file
607
tests/memory/workers/tasks/test_discord_tasks.py
Normal file
@ -0,0 +1,607 @@
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from memory.common.db.models import (
|
||||
DiscordMessage,
|
||||
DiscordUser,
|
||||
DiscordServer,
|
||||
DiscordChannel,
|
||||
)
|
||||
from memory.workers.tasks import discord
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_user(db_session):
|
||||
"""Create a Discord user for testing."""
|
||||
user = DiscordUser(
|
||||
id=123456789,
|
||||
username="testuser",
|
||||
ignore_messages=False,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_server(db_session):
|
||||
"""Create a Discord server for testing."""
|
||||
server = DiscordServer(
|
||||
id=987654321,
|
||||
name="Test Server",
|
||||
ignore_messages=False,
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_channel(db_session, mock_discord_server):
|
||||
"""Create a Discord channel for testing."""
|
||||
channel = DiscordChannel(
|
||||
id=111222333,
|
||||
name="test-channel",
|
||||
channel_type="text",
|
||||
server_id=mock_discord_server.id,
|
||||
ignore_messages=False,
|
||||
)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
return channel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_message_data(mock_discord_user, mock_discord_channel):
|
||||
"""Sample message data for testing."""
|
||||
return {
|
||||
"message_id": 999888777,
|
||||
"channel_id": mock_discord_channel.id,
|
||||
"author_id": mock_discord_user.id,
|
||||
"content": "This is a test Discord message with enough content to be processed.",
|
||||
"sent_at": "2024-01-01T12:00:00Z",
|
||||
"server_id": None,
|
||||
"message_reference_id": None,
|
||||
}
|
||||
|
||||
|
||||
def test_get_prev_returns_previous_messages(
|
||||
db_session, mock_discord_user, mock_discord_channel
|
||||
):
|
||||
"""Test that get_prev returns previous messages in order."""
|
||||
# Create previous messages
|
||||
msg1 = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
content="First message",
|
||||
sent_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash1" + bytes(26),
|
||||
)
|
||||
msg2 = DiscordMessage(
|
||||
message_id=2,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
content="Second message",
|
||||
sent_at=datetime(2024, 1, 1, 10, 5, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash2" + bytes(26),
|
||||
)
|
||||
msg3 = DiscordMessage(
|
||||
message_id=3,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
content="Third message",
|
||||
sent_at=datetime(2024, 1, 1, 10, 10, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash3" + bytes(26),
|
||||
)
|
||||
db_session.add_all([msg1, msg2, msg3])
|
||||
db_session.commit()
|
||||
|
||||
# Get previous messages before 10:15
|
||||
result = discord.get_prev(
|
||||
db_session,
|
||||
mock_discord_channel.id,
|
||||
datetime(2024, 1, 1, 10, 15, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
assert len(result) == 3
|
||||
assert result[0] == "testuser: First message"
|
||||
assert result[1] == "testuser: Second message"
|
||||
assert result[2] == "testuser: Third message"
|
||||
|
||||
|
||||
def test_get_prev_limits_context_window(
|
||||
db_session, mock_discord_user, mock_discord_channel
|
||||
):
|
||||
"""Test that get_prev respects DISCORD_CONTEXT_WINDOW setting."""
|
||||
# Create 15 messages (more than the default context window of 10)
|
||||
for i in range(15):
|
||||
msg = DiscordMessage(
|
||||
message_id=i,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
content=f"Message {i}",
|
||||
sent_at=datetime(2024, 1, 1, 10, i, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
sha256=f"hash{i}".encode() + bytes(27),
|
||||
)
|
||||
db_session.add(msg)
|
||||
db_session.commit()
|
||||
|
||||
result = discord.get_prev(
|
||||
db_session,
|
||||
mock_discord_channel.id,
|
||||
datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
# Should only return last 10 messages
|
||||
assert len(result) == 10
|
||||
assert result[0] == "testuser: Message 5" # Oldest in window
|
||||
assert result[-1] == "testuser: Message 14" # Most recent
|
||||
|
||||
|
||||
def test_get_prev_empty_channel(db_session, mock_discord_channel):
|
||||
"""Test get_prev with no previous messages."""
|
||||
result = discord.get_prev(
|
||||
db_session,
|
||||
mock_discord_channel.id,
|
||||
datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_should_process_normal_message(
|
||||
db_session, mock_discord_user, mock_discord_server, mock_discord_channel
|
||||
):
|
||||
"""Test should_process returns True for normal messages."""
|
||||
message = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
server_id=mock_discord_server.id,
|
||||
content="Test",
|
||||
sent_at=datetime.now(timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash" + bytes(27),
|
||||
)
|
||||
db_session.add(message)
|
||||
db_session.commit()
|
||||
db_session.refresh(message)
|
||||
|
||||
assert discord.should_process(message) is True
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
|
||||
def test_should_process_disabled():
|
||||
"""Test should_process returns False when processing is disabled."""
|
||||
message = Mock()
|
||||
assert discord.should_process(message) is False
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
def test_should_process_notifications_disabled():
|
||||
"""Test should_process returns False when notifications are disabled."""
|
||||
message = Mock()
|
||||
assert discord.should_process(message) is False
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_should_process_server_ignored(
|
||||
db_session, mock_discord_user, mock_discord_channel
|
||||
):
|
||||
"""Test should_process returns False when server has ignore_messages=True."""
|
||||
server = DiscordServer(
|
||||
id=123,
|
||||
name="Ignored Server",
|
||||
ignore_messages=True,
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
message = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
server_id=server.id,
|
||||
content="Test",
|
||||
sent_at=datetime.now(timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash" + bytes(27),
|
||||
)
|
||||
db_session.add(message)
|
||||
db_session.commit()
|
||||
db_session.refresh(message)
|
||||
|
||||
assert discord.should_process(message) is False
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_should_process_channel_ignored(
|
||||
db_session, mock_discord_user, mock_discord_server
|
||||
):
|
||||
"""Test should_process returns False when channel has ignore_messages=True."""
|
||||
channel = DiscordChannel(
|
||||
id=456,
|
||||
name="ignored-channel",
|
||||
channel_type="text",
|
||||
server_id=mock_discord_server.id,
|
||||
ignore_messages=True,
|
||||
)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
|
||||
message = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
server_id=mock_discord_server.id,
|
||||
content="Test",
|
||||
sent_at=datetime.now(timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash" + bytes(27),
|
||||
)
|
||||
db_session.add(message)
|
||||
db_session.commit()
|
||||
db_session.refresh(message)
|
||||
|
||||
assert discord.should_process(message) is False
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_should_process_user_ignored(
|
||||
db_session, mock_discord_server, mock_discord_channel
|
||||
):
|
||||
"""Test should_process returns False when user has ignore_messages=True."""
|
||||
user = DiscordUser(
|
||||
id=789,
|
||||
username="ignoreduser",
|
||||
ignore_messages=True,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
message = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=user.id,
|
||||
server_id=mock_discord_server.id,
|
||||
content="Test",
|
||||
sent_at=datetime.now(timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash" + bytes(27),
|
||||
)
|
||||
db_session.add(message)
|
||||
db_session.commit()
|
||||
db_session.refresh(message)
|
||||
|
||||
assert discord.should_process(message) is False
|
||||
|
||||
|
||||
def test_add_discord_message_success(db_session, sample_message_data, qdrant):
|
||||
"""Test successful Discord message addition."""
|
||||
result = discord.add_discord_message(**sample_message_data)
|
||||
|
||||
assert result["status"] == "processed"
|
||||
assert "discordmessage_id" in result
|
||||
|
||||
# Verify the message was created in the database
|
||||
message = (
|
||||
db_session.query(DiscordMessage)
|
||||
.filter_by(message_id=sample_message_data["message_id"])
|
||||
.first()
|
||||
)
|
||||
assert message is not None
|
||||
assert message.content == sample_message_data["content"]
|
||||
assert message.message_type == "default"
|
||||
assert message.reply_to_message_id is None
|
||||
|
||||
|
||||
def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant):
|
||||
"""Test adding a Discord message that is a reply."""
|
||||
sample_message_data["message_reference_id"] = 111222333
|
||||
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
message = (
|
||||
db_session.query(DiscordMessage)
|
||||
.filter_by(message_id=sample_message_data["message_id"])
|
||||
.first()
|
||||
)
|
||||
assert message.message_type == "reply"
|
||||
assert message.reply_to_message_id == 111222333
|
||||
|
||||
|
||||
def test_add_discord_message_already_exists(db_session, sample_message_data, qdrant):
|
||||
"""Test adding a message that already exists."""
|
||||
# Add the message once
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
# Try to add it again
|
||||
result = discord.add_discord_message(**sample_message_data)
|
||||
|
||||
assert result["status"] == "already_exists"
|
||||
assert result["message_id"] == sample_message_data["message_id"]
|
||||
|
||||
# Verify no duplicate was created
|
||||
messages = (
|
||||
db_session.query(DiscordMessage)
|
||||
.filter_by(message_id=sample_message_data["message_id"])
|
||||
.all()
|
||||
)
|
||||
assert len(messages) == 1
|
||||
|
||||
|
||||
def test_add_discord_message_with_context(
|
||||
db_session, sample_message_data, mock_discord_user, qdrant
|
||||
):
|
||||
"""Test that message is added successfully when previous messages exist."""
|
||||
# Add a previous message
|
||||
prev_msg = DiscordMessage(
|
||||
message_id=111111,
|
||||
channel_id=sample_message_data["channel_id"],
|
||||
discord_user_id=mock_discord_user.id,
|
||||
content="Previous message",
|
||||
sent_at=datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"prev" + bytes(28),
|
||||
)
|
||||
db_session.add(prev_msg)
|
||||
db_session.commit()
|
||||
|
||||
result = discord.add_discord_message(**sample_message_data)
|
||||
|
||||
message = (
|
||||
db_session.query(DiscordMessage)
|
||||
.filter_by(message_id=sample_message_data["message_id"])
|
||||
.first()
|
||||
)
|
||||
assert message is not None
|
||||
assert result["status"] == "processed"
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.discord.process_discord_message")
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_add_discord_message_triggers_processing(
|
||||
mock_process,
|
||||
db_session,
|
||||
sample_message_data,
|
||||
mock_discord_server,
|
||||
mock_discord_channel,
|
||||
qdrant,
|
||||
):
|
||||
"""Test that add_discord_message triggers process_discord_message when conditions are met."""
|
||||
mock_process.delay = Mock()
|
||||
sample_message_data["server_id"] = mock_discord_server.id
|
||||
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
# Verify process_discord_message.delay was called
|
||||
mock_process.delay.assert_called_once()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.discord.process_discord_message")
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
|
||||
def test_add_discord_message_no_processing_when_disabled(
|
||||
mock_process, db_session, sample_message_data, qdrant
|
||||
):
|
||||
"""Test that process_discord_message is not called when processing is disabled."""
|
||||
mock_process.delay = Mock()
|
||||
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
mock_process.delay.assert_not_called()
|
||||
|
||||
|
||||
def test_edit_discord_message_success(db_session, sample_message_data, qdrant):
|
||||
"""Test successful Discord message edit."""
|
||||
# First add the message
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
# Edit it
|
||||
new_content = (
|
||||
"This is the edited content with enough text to be meaningful and processed."
|
||||
)
|
||||
edited_at = "2024-01-01T13:00:00Z"
|
||||
|
||||
result = discord.edit_discord_message(
|
||||
sample_message_data["message_id"],
|
||||
new_content,
|
||||
edited_at,
|
||||
)
|
||||
|
||||
assert result["status"] == "processed"
|
||||
|
||||
# Verify the message was updated
|
||||
message = (
|
||||
db_session.query(DiscordMessage)
|
||||
.filter_by(message_id=sample_message_data["message_id"])
|
||||
.first()
|
||||
)
|
||||
assert message.content == new_content
|
||||
assert message.edited_at is not None
|
||||
|
||||
|
||||
def test_edit_discord_message_not_found(db_session):
|
||||
"""Test editing a message that doesn't exist."""
|
||||
result = discord.edit_discord_message(
|
||||
999999,
|
||||
"New content",
|
||||
"2024-01-01T13:00:00Z",
|
||||
)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert result["error"] == "Message not found"
|
||||
assert result["message_id"] == 999999
|
||||
|
||||
|
||||
def test_edit_discord_message_updates_context(
|
||||
db_session, sample_message_data, mock_discord_user, qdrant
|
||||
):
|
||||
"""Test that editing a message works correctly."""
|
||||
# Add previous message and the message to be edited
|
||||
prev_msg = DiscordMessage(
|
||||
message_id=111111,
|
||||
channel_id=sample_message_data["channel_id"],
|
||||
discord_user_id=mock_discord_user.id,
|
||||
content="Previous message",
|
||||
sent_at=datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"prev" + bytes(28),
|
||||
)
|
||||
db_session.add(prev_msg)
|
||||
db_session.commit()
|
||||
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
# Edit the message
|
||||
result = discord.edit_discord_message(
|
||||
sample_message_data["message_id"],
|
||||
"Edited content that should have context updated properly.",
|
||||
"2024-01-01T13:00:00Z",
|
||||
)
|
||||
|
||||
# Verify message was updated
|
||||
message = (
|
||||
db_session.query(DiscordMessage)
|
||||
.filter_by(message_id=sample_message_data["message_id"])
|
||||
.first()
|
||||
)
|
||||
assert (
|
||||
message.content == "Edited content that should have context updated properly."
|
||||
)
|
||||
assert result["status"] == "processed"
|
||||
|
||||
|
||||
def test_process_discord_message_success(db_session, sample_message_data, qdrant):
|
||||
"""Test processing a Discord message."""
|
||||
# Add a message first
|
||||
add_result = discord.add_discord_message(**sample_message_data)
|
||||
message_id = add_result["discordmessage_id"]
|
||||
|
||||
# Process it
|
||||
result = discord.process_discord_message(message_id)
|
||||
|
||||
assert result["status"] == "processed"
|
||||
assert result["message_id"] == message_id
|
||||
|
||||
|
||||
def test_process_discord_message_not_found(db_session):
|
||||
"""Test processing a message that doesn't exist."""
|
||||
result = discord.process_discord_message(999999)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert result["error"] == "Message not found"
|
||||
assert result["message_id"] == 999999
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sent_at_str,expected_hour",
|
||||
[
|
||||
("2024-01-01T12:00:00Z", 12),
|
||||
("2024-01-01T00:00:00+00:00", 0),
|
||||
("2024-01-01T23:59:59Z", 23),
|
||||
],
|
||||
)
|
||||
def test_add_discord_message_datetime_parsing(
|
||||
db_session, sample_message_data, sent_at_str, expected_hour, qdrant
|
||||
):
|
||||
"""Test that various datetime formats are parsed correctly."""
|
||||
sample_message_data["sent_at"] = sent_at_str
|
||||
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
message = (
|
||||
db_session.query(DiscordMessage)
|
||||
.filter_by(message_id=sample_message_data["message_id"])
|
||||
.first()
|
||||
)
|
||||
assert message.sent_at.hour == expected_hour
|
||||
|
||||
|
||||
def test_add_discord_message_unique_hash(db_session, sample_message_data, qdrant):
|
||||
"""Test that message hash includes message_id for uniqueness."""
|
||||
# Add first message
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
# Try to add another message with same content but different message_id
|
||||
sample_message_data["message_id"] = 888777666
|
||||
|
||||
result = discord.add_discord_message(**sample_message_data)
|
||||
|
||||
# Should succeed because hash includes message_id
|
||||
assert result["status"] == "processed"
|
||||
|
||||
# Verify both messages exist
|
||||
messages = (
|
||||
db_session.query(DiscordMessage)
|
||||
.filter_by(content=sample_message_data["content"])
|
||||
.all()
|
||||
)
|
||||
assert len(messages) == 2
|
||||
|
||||
|
||||
def test_get_prev_only_returns_messages_from_same_channel(
|
||||
db_session, mock_discord_user, mock_discord_server
|
||||
):
|
||||
"""Test that get_prev only returns messages from the specified channel."""
|
||||
# Create two channels
|
||||
channel1 = DiscordChannel(
|
||||
id=111,
|
||||
name="channel-1",
|
||||
channel_type="text",
|
||||
server_id=mock_discord_server.id,
|
||||
)
|
||||
channel2 = DiscordChannel(
|
||||
id=222,
|
||||
name="channel-2",
|
||||
channel_type="text",
|
||||
server_id=mock_discord_server.id,
|
||||
)
|
||||
db_session.add_all([channel1, channel2])
|
||||
db_session.commit()
|
||||
|
||||
# Add messages to both channels
|
||||
msg1 = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=channel1.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
content="Message in channel 1",
|
||||
sent_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash1" + bytes(26),
|
||||
)
|
||||
msg2 = DiscordMessage(
|
||||
message_id=2,
|
||||
channel_id=channel2.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
content="Message in channel 2",
|
||||
sent_at=datetime(2024, 1, 1, 10, 5, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash2" + bytes(26),
|
||||
)
|
||||
db_session.add_all([msg1, msg2])
|
||||
db_session.commit()
|
||||
|
||||
# Get previous messages for channel 1
|
||||
result = discord.get_prev(
|
||||
db_session,
|
||||
channel1.id, # type: ignore
|
||||
datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
# Should only return message from channel 1
|
||||
assert len(result) == 1
|
||||
assert "Message in channel 1" in result[0]
|
||||
assert "Message in channel 2" not in result[0]
|
Loading…
x
Reference in New Issue
Block a user