From 99d3843f477fb92838c8653464669fc7205262ff Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Mon, 13 Oct 2025 03:23:20 +0200 Subject: [PATCH] move to general LLM providers --- .../20251012_222827_add_discord_models.py | 17 +- requirements/requirements-common.txt | 2 +- src/memory/api/search/scorer.py | 2 +- src/memory/common/celery_app.py | 1 + src/memory/common/db/models/sources.py | 17 +- src/memory/common/llms.py | 122 --- src/memory/common/llms/__init__.py | 79 ++ src/memory/common/llms/anthropic_provider.py | 451 ++++++++++ src/memory/common/llms/base.py | 561 ++++++++++++ src/memory/common/llms/openai_provider.py | 388 ++++++++ src/memory/common/llms/tools/__init__.py | 36 + src/memory/common/llms/tools/ping.py | 42 + src/memory/common/messages.py | 9 + src/memory/common/settings.py | 1 + src/memory/common/summarizer.py | 2 +- src/memory/discord/collector.py | 2 +- src/memory/workers/tasks/discord.py | 48 +- src/memory/workers/tasks/scheduled_calls.py | 5 +- tests/conftest.py | 21 + tests/memory/common/test_discord.py | 556 ++++++------ .../memory/common/test_discord_integration.py | 435 +++++++++ tests/memory/discord/test_collector.py | 840 ++++++++++++++++++ .../workers/tasks/test_discord_tasks.py | 607 +++++++++++++ 23 files changed, 3844 insertions(+), 400 deletions(-) delete mode 100644 src/memory/common/llms.py create mode 100644 src/memory/common/llms/__init__.py create mode 100644 src/memory/common/llms/anthropic_provider.py create mode 100644 src/memory/common/llms/base.py create mode 100644 src/memory/common/llms/openai_provider.py create mode 100644 src/memory/common/llms/tools/__init__.py create mode 100644 src/memory/common/llms/tools/ping.py create mode 100644 src/memory/common/messages.py create mode 100644 tests/memory/common/test_discord_integration.py create mode 100644 tests/memory/discord/test_collector.py create mode 100644 tests/memory/workers/tasks/test_discord_tasks.py diff --git a/db/migrations/versions/20251012_222827_add_discord_models.py b/db/migrations/versions/20251012_222827_add_discord_models.py index 91d661b..deb73e1 100644 --- a/db/migrations/versions/20251012_222827_add_discord_models.py +++ b/db/migrations/versions/20251012_222827_add_discord_models.py @@ -26,9 +26,12 @@ def upgrade() -> None: sa.Column("name", sa.Text(), nullable=False), sa.Column("description", sa.Text(), nullable=True), sa.Column("member_count", sa.Integer(), nullable=True), + sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True), sa.Column( - "track_messages", sa.Boolean(), server_default="true", nullable=False + "ignore_messages", sa.Boolean(), server_default="false", nullable=True ), + sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True), + sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True), sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True), sa.Column( "created_at", @@ -56,7 +59,12 @@ def upgrade() -> None: sa.Column("server_id", sa.BigInteger(), nullable=True), sa.Column("name", sa.Text(), nullable=False), sa.Column("channel_type", sa.Text(), nullable=False), - sa.Column("track_messages", sa.Boolean(), nullable=True), + sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True), + sa.Column( + "ignore_messages", sa.Boolean(), server_default="false", nullable=True + ), + sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True), + sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True), sa.Column( "created_at", sa.DateTime(timezone=True), @@ -84,9 +92,12 @@ def upgrade() -> None: sa.Column("username", sa.Text(), nullable=False), sa.Column("display_name", sa.Text(), nullable=True), sa.Column("system_user_id", sa.Integer(), nullable=True), + sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True), sa.Column( - "allow_dm_tracking", sa.Boolean(), server_default="true", nullable=False + "ignore_messages", sa.Boolean(), server_default="false", nullable=True ), + sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True), + sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True), sa.Column( "created_at", sa.DateTime(timezone=True), diff --git a/requirements/requirements-common.txt b/requirements/requirements-common.txt index 8ecc0b7..f88aec2 100644 --- a/requirements/requirements-common.txt +++ b/requirements/requirements-common.txt @@ -5,7 +5,7 @@ alembic==1.13.1 dotenv==0.9.9 voyageai==0.3.2 qdrant-client==1.9.0 -anthropic==0.18.1 +anthropic==0.69.0 openai==1.25.0 # Pin the httpx version, as newer versions break the anthropic client httpx==0.27.0 diff --git a/src/memory/api/search/scorer.py b/src/memory/api/search/scorer.py index d978f98..5049d95 100644 --- a/src/memory/api/search/scorer.py +++ b/src/memory/api/search/scorer.py @@ -40,7 +40,7 @@ async def score_chunk(query: str, chunk: Chunk) -> Chunk: prompt = SCORE_CHUNK_PROMPT.format(query=query, chunk=chunk_text) try: response = await asyncio.to_thread( - llms.call, + llms.summarize, prompt, settings.RANKER_MODEL, images=images, diff --git a/src/memory/common/celery_app.py b/src/memory/common/celery_app.py index 75a3861..cc930da 100644 --- a/src/memory/common/celery_app.py +++ b/src/memory/common/celery_app.py @@ -15,6 +15,7 @@ SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls" DISCORD_ROOT = "memory.workers.tasks.discord" ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message" EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message" +PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message" SYNC_NOTES = f"{NOTES_ROOT}.sync_notes" SYNC_NOTE = f"{NOTES_ROOT}.sync_note" diff --git a/src/memory/common/db/models/sources.py b/src/memory/common/db/models/sources.py index 9d018e5..1cbb9e4 100644 --- a/src/memory/common/db/models/sources.py +++ b/src/memory/common/db/models/sources.py @@ -127,7 +127,15 @@ class EmailAccount(Base): ) -class DiscordServer(Base): +class MessageProcessor: + track_messages = Column(Boolean, nullable=False, server_default="true") + ignore_messages = Column(Boolean, nullable=True, default=False) + + allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}") + disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}") + + +class DiscordServer(Base, MessageProcessor): """Discord server configuration and metadata""" __tablename__ = "discord_servers" @@ -138,7 +146,6 @@ class DiscordServer(Base): member_count = Column(Integer) # Collection settings - track_messages = Column(Boolean, nullable=False, server_default="true") last_sync_at = Column(DateTime(timezone=True)) created_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = Column(DateTime(timezone=True), server_default=func.now()) @@ -152,7 +159,7 @@ class DiscordServer(Base): ) -class DiscordChannel(Base): +class DiscordChannel(Base, MessageProcessor): """Discord channel metadata and configuration""" __tablename__ = "discord_channels" @@ -163,7 +170,6 @@ class DiscordChannel(Base): channel_type = Column(Text, nullable=False) # "text", "voice", "dm", "group_dm" # Collection settings (null = inherit from server) - track_messages = Column(Boolean, nullable=True) created_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = Column(DateTime(timezone=True), server_default=func.now()) @@ -171,7 +177,7 @@ class DiscordChannel(Base): __table_args__ = (Index("discord_channels_server_idx", "server_id"),) -class DiscordUser(Base): +class DiscordUser(Base, MessageProcessor): """Discord user metadata and preferences""" __tablename__ = "discord_users" @@ -184,7 +190,6 @@ class DiscordUser(Base): system_user_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Basic DM settings - allow_dm_tracking = Column(Boolean, nullable=False, server_default="true") created_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = Column(DateTime(timezone=True), server_default=func.now()) diff --git a/src/memory/common/llms.py b/src/memory/common/llms.py deleted file mode 100644 index 8b88b75..0000000 --- a/src/memory/common/llms.py +++ /dev/null @@ -1,122 +0,0 @@ -import logging -import base64 -import io -from typing import Any -from PIL import Image - -from memory.common import settings, tokens - -logger = logging.getLogger(__name__) - -SYSTEM_PROMPT = """ -You are a helpful assistant that creates concise summaries and identifies key topics. -""" - - -def encode_image(image: Image.Image) -> str: - """Encode PIL Image to base64 string.""" - buffer = io.BytesIO() - # Convert to RGB if necessary (for RGBA, etc.) - if image.mode != "RGB": - image = image.convert("RGB") - image.save(buffer, format="JPEG") - return base64.b64encode(buffer.getvalue()).decode("utf-8") - - -def call_openai( - prompt: str, - model: str, - images: list[Image.Image] = [], - system_prompt: str = SYSTEM_PROMPT, -) -> str: - """Call OpenAI API for summarization.""" - import openai - - client = openai.OpenAI(api_key=settings.OPENAI_API_KEY) - try: - user_content: Any = [{"type": "text", "text": prompt}] - if images: - for image in images: - encoded_image = encode_image(image) - user_content.append( - { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}, - } - ) - - response = client.chat.completions.create( - model=model.split("/")[1], - messages=[ - { - "role": "system", - "content": system_prompt, - }, - {"role": "user", "content": user_content}, - ], - temperature=0.3, - max_tokens=2048, - ) - return response.choices[0].message.content or "" - except Exception as e: - logger.error(f"OpenAI API error: {e}") - raise - - -def call_anthropic( - prompt: str, - model: str, - images: list[Image.Image] = [], - system_prompt: str = SYSTEM_PROMPT, -) -> str: - """Call Anthropic API for summarization.""" - import anthropic - - client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY) - try: - # Prepare the message content - content: Any = [{"type": "text", "text": prompt}] - if images: - # Add images if provided - for image in images: - encoded_image = encode_image(image) - content.append( - { # type: ignore - "type": "image", - "source": { - "type": "base64", - "media_type": "image/jpeg", - "data": encoded_image, - }, - } - ) - - response = client.messages.create( - model=model.split("/")[1], - messages=[{"role": "user", "content": content}], # type: ignore - system=system_prompt, - temperature=0.3, - max_tokens=2048, - ) - return response.content[0].text - except Exception as e: - logger.error(f"Anthropic API error: {e}") - raise - - -def call( - prompt: str, - model: str, - images: list[Image.Image] = [], - system_prompt: str = SYSTEM_PROMPT, -) -> str: - if model.startswith("anthropic"): - return call_anthropic(prompt, model, images, system_prompt) - return call_openai(prompt, model, images, system_prompt) - - -def truncate(content: str, target_tokens: int) -> str: - target_chars = target_tokens * tokens.CHARS_PER_TOKEN - if len(content) > target_chars: - return content[:target_chars].rsplit(" ", 1)[0] + "..." - return content diff --git a/src/memory/common/llms/__init__.py b/src/memory/common/llms/__init__.py new file mode 100644 index 0000000..68b6bd9 --- /dev/null +++ b/src/memory/common/llms/__init__.py @@ -0,0 +1,79 @@ +"""LLM provider module for unified LLM access.""" + +# Legacy imports for backwards compatibility +import logging + +from PIL import Image + + +# New provider system +from memory.common.llms.base import ( + BaseLLMProvider, + ImageContent, + LLMSettings, + Message, + MessageContent, + MessageRole, + StreamEvent, + TextContent, + ThinkingContent, + ToolDefinition, + ToolResultContent, + ToolUseContent, + create_provider, +) +from memory.common import tokens + +__all__ = [ + "BaseLLMProvider", + "Message", + "MessageRole", + "MessageContent", + "TextContent", + "ImageContent", + "ToolUseContent", + "ToolResultContent", + "ThinkingContent", + "ToolDefinition", + "StreamEvent", + "LLMSettings", + "create_provider", +] + +logger = logging.getLogger(__name__) + + +def summarize( + prompt: str, + model: str, + images: list[Image.Image] = [], + system_prompt: str = "", +) -> str: + provider = create_provider(model=model) + try: + # Build message content + content: list[MessageContent] = [TextContent(text=prompt)] + for image in images: + content.append(ImageContent(image=image)) + + messages = [Message(role=MessageRole.USER, content=content)] + settings_obj = LLMSettings(temperature=0.3, max_tokens=2048) + + res = provider.run_with_tools( + messages=messages, + system_prompt=system_prompt + or "You are a helpful assistant that creates concise summaries and identifies key topics.", + settings=settings_obj, + tools={}, + ) + return res.response or "" + except Exception as e: + logger.error(f"Anthropic API error: {e}") + raise + + +def truncate(content: str, target_tokens: int) -> str: + target_chars = target_tokens * tokens.CHARS_PER_TOKEN + if len(content) > target_chars: + return content[:target_chars].rsplit(" ", 1)[0] + "..." + return content diff --git a/src/memory/common/llms/anthropic_provider.py b/src/memory/common/llms/anthropic_provider.py new file mode 100644 index 0000000..a46236f --- /dev/null +++ b/src/memory/common/llms/anthropic_provider.py @@ -0,0 +1,451 @@ +"""Anthropic LLM provider implementation.""" + +import json +import logging +from typing import Any, AsyncIterator, Iterator, Optional + +import anthropic + +from memory.common.llms.base import ( + BaseLLMProvider, + ImageContent, + LLMSettings, + Message, + MessageRole, + StreamEvent, + ToolDefinition, + ToolUseContent, + ThinkingContent, + TextContent, +) + +logger = logging.getLogger(__name__) + + +class AnthropicProvider(BaseLLMProvider): + """Anthropic LLM provider with streaming, tool support, and extended thinking.""" + + # Models that support extended thinking + THINKING_MODELS = { + "claude-opus-4", + "claude-opus-4-1", + "claude-sonnet-4-0", + "claude-sonnet-3-7", + "claude-sonnet-4-5", + } + + def __init__(self, api_key: str, model: str, enable_thinking: bool = False): + """ + Initialize the Anthropic provider. + + Args: + api_key: Anthropic API key + model: Model identifier + enable_thinking: Enable extended thinking for supported models + """ + super().__init__(api_key, model) + self.enable_thinking = enable_thinking + self._async_client: Optional[anthropic.AsyncAnthropic] = None + + def _initialize_client(self) -> anthropic.Anthropic: + """Initialize the Anthropic client.""" + return anthropic.Anthropic(api_key=self.api_key) + + @property + def async_client(self) -> anthropic.AsyncAnthropic: + """Lazy-load the async client.""" + if self._async_client is None: + self._async_client = anthropic.AsyncAnthropic(api_key=self.api_key) + return self._async_client + + def _convert_image_content(self, content: ImageContent) -> dict[str, Any]: + """Convert ImageContent to Anthropic's base64 source format.""" + encoded_image = self.encode_image(content.image) + return { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": encoded_image, + }, + } + + def _convert_message(self, message: Message) -> dict[str, Any]: + converted = message.to_dict() + if converted["role"] == MessageRole.ASSISTANT and isinstance( + converted["content"], list + ): + content = sorted( + converted["content"], key=lambda x: x["type"] != "thinking" + ) + return converted | {"content": content} + return converted + + def _should_include_message(self, message: Message) -> bool: + """Filter out system messages (handled separately in Anthropic).""" + return message.role != MessageRole.SYSTEM + + def _supports_thinking(self) -> bool: + """Check if the current model supports extended thinking.""" + model_lower = self.model.lower() + return any(supported in model_lower for supported in self.THINKING_MODELS) + + def _build_request_kwargs( + self, + messages: list[Message], + system_prompt: str | None, + tools: list[ToolDefinition] | None, + settings: LLMSettings, + ) -> dict[str, Any]: + """Build common request kwargs for API calls.""" + anthropic_messages = self._convert_messages(messages) + + kwargs: dict[str, Any] = { + "model": self.model, + "messages": anthropic_messages, + "temperature": settings.temperature, + "max_tokens": settings.max_tokens, + } + + # Only include top_p if explicitly set + if settings.top_p is not None: + kwargs["top_p"] = settings.top_p + + if system_prompt: + kwargs["system"] = system_prompt + + if settings.stop_sequences: + kwargs["stop_sequences"] = settings.stop_sequences + + if tools: + kwargs["tools"] = self._convert_tools(tools) + + # Enable extended thinking if requested and model supports it + if self.enable_thinking and self._supports_thinking(): + thinking_budget = min(10000, settings.max_tokens - 1024) + if thinking_budget >= 1024: + kwargs["thinking"] = { + "type": "enabled", + "budget_tokens": thinking_budget, + } + # When thinking is enabled: temperature must be 1, can't use top_p + kwargs["temperature"] = 1.0 + kwargs.pop("top_p", None) + + return kwargs + + def _handle_stream_event( + self, event: Any, current_tool_use: dict[str, Any] | None + ) -> tuple[StreamEvent | None, dict[str, Any] | None]: + """ + Handle a streaming event and return StreamEvent and updated tool state. + + Returns: + Tuple of (StreamEvent or None, updated current_tool_use or None) + """ + event_type = getattr(event, "type", None) + + # Handle error events + if event_type == "error": + error = getattr(event, "error", None) + error_msg = str(error) if error else "Unknown error" + return StreamEvent(type="error", data=error_msg), current_tool_use + + if event_type == "content_block_start": + block = getattr(event, "content_block", None) + if not block: + return None, current_tool_use + + block_type = getattr(block, "type", None) + + # Handle various tool types (tool_use, mcp_tool_use, server_tool_use) + if block_type in ("tool_use", "mcp_tool_use", "server_tool_use"): + # In content_block_start, input may already be present (empty dict) + block_input = getattr(block, "input", None) + current_tool_use = { + "id": getattr(block, "id", ""), + "name": getattr(block, "name", ""), + "input": block_input if block_input is not None else "", + "server_name": getattr(block, "server_name", None), + "is_server_call": block_type != "tool_use", + } + + # Handle tool result blocks + elif hasattr(block, "tool_use_id"): + tool_result = { + "id": getattr(block, "tool_use_id", ""), + "result": getattr(block, "content", ""), + } + return StreamEvent( + type="tool_result", data=tool_result + ), current_tool_use + + # For non-tool blocks (text, thinking), we don't need to track state + return None, current_tool_use + + elif event_type == "content_block_delta": + delta = getattr(event, "delta", None) + if not delta: + return None, current_tool_use + + delta_type = getattr(delta, "type", None) + + if delta_type == "text_delta": + text = getattr(delta, "text", "") + return StreamEvent(type="text", data=text), current_tool_use + + elif delta_type == "thinking_delta": + thinking = getattr(delta, "thinking", "") + return StreamEvent(type="thinking", data=thinking), current_tool_use + + elif delta_type == "signature_delta": + # Handle thinking signature for extended thinking + signature = getattr(delta, "signature", "") + return StreamEvent( + type="thinking", signature=signature + ), current_tool_use + + elif delta_type == "input_json_delta": + if current_tool_use is None: + # Edge case: received input_json_delta without tool_use start + logger.warning("Received input_json_delta without tool_use context") + return None, None + + # Only accumulate if input is still a string (being built up) + if isinstance(current_tool_use.get("input"), str): + partial_json = getattr(delta, "partial_json", "") + current_tool_use["input"] += partial_json + # else: input was already set as a dict in content_block_start + + return None, current_tool_use + + elif event_type == "content_block_stop": + if current_tool_use: + # Use the parsed input from the content block if available + # This handles empty inputs {} more reliably than parsing + content_block = getattr(event, "content_block", None) + if content_block and hasattr(content_block, "input"): + current_tool_use["input"] = content_block.input + else: + # Fallback: parse accumulated JSON string + input_str = current_tool_use.get("input", "") + if isinstance(input_str, str): + # Need to parse the accumulated string + if not input_str or input_str.isspace(): + # Empty or whitespace-only input + current_tool_use["input"] = {} + else: + try: + current_tool_use["input"] = json.loads(input_str) + except json.JSONDecodeError as e: + logger.warning( + f"Failed to parse tool input '{input_str}': {e}" + ) + current_tool_use["input"] = {} + # else: input is already parsed + + tool_data = { + "id": current_tool_use.get("id", ""), + "name": current_tool_use.get("name", ""), + "input": current_tool_use.get("input", {}), + } + # Include server info if present + if current_tool_use.get("server_name"): + tool_data["server_name"] = current_tool_use["server_name"] + if current_tool_use.get("is_server_call"): + tool_data["is_server_call"] = current_tool_use["is_server_call"] + + return StreamEvent(type="tool_use", data=tool_data), None + + elif event_type == "message_delta": + delta = getattr(event, "delta", None) + if delta: + stop_reason = getattr(delta, "stop_reason", None) + if stop_reason == "max_tokens": + return StreamEvent( + type="error", data="Max tokens reached" + ), current_tool_use + + # Handle token usage information + usage = getattr(event, "usage", None) + if usage: + usage_data = { + "input_tokens": getattr(usage, "input_tokens", 0), + "output_tokens": getattr(usage, "output_tokens", 0), + "cache_creation_input_tokens": getattr( + usage, "cache_creation_input_tokens", None + ), + "cache_read_input_tokens": getattr( + usage, "cache_read_input_tokens", None + ), + } + # Could emit this as a separate event type if needed + logger.debug(f"Token usage: {usage_data}") + + return None, current_tool_use + + elif event_type == "message_stop": + # Final event - clean up any pending state + if current_tool_use: + logger.warning( + f"Message ended with incomplete tool use: {current_tool_use}" + ) + return StreamEvent(type="done"), None + + # Unknown event type - log but don't fail + if event_type and event_type not in ( + "message_start", + "message_delta", + "content_block_start", + "content_block_delta", + "content_block_stop", + "message_stop", + ): + logger.debug(f"Unknown event type: {event_type}") + + return None, current_tool_use + + def generate( + self, + messages: list[Message], + system_prompt: str | None = None, + tools: list[ToolDefinition] | None = None, + settings: LLMSettings | None = None, + ) -> str: + """Generate a non-streaming response.""" + settings = settings or LLMSettings() + kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings) + + try: + response = self.client.messages.create(**kwargs) + return "".join( + block.text for block in response.content if block.type == "text" + ) + except Exception as e: + logger.error(f"Anthropic API error: {e}") + raise + + def stream( + self, + messages: list[Message], + system_prompt: str | None = None, + tools: list[ToolDefinition] | None = None, + settings: LLMSettings | None = None, + ) -> Iterator[StreamEvent]: + """Generate a streaming response.""" + settings = settings or LLMSettings() + kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings) + + try: + with self.client.messages.stream(**kwargs) as stream: + current_tool_use: dict[str, Any] | None = None + + for event in stream: + stream_event, current_tool_use = self._handle_stream_event( + event, current_tool_use + ) + if stream_event: + yield stream_event + + except Exception as e: + logger.error(f"Anthropic streaming error: {e}") + yield StreamEvent(type="error", data=str(e)) + + async def agenerate( + self, + messages: list[Message], + system_prompt: str | None = None, + tools: list[ToolDefinition] | None = None, + settings: LLMSettings | None = None, + ) -> str: + """Generate a non-streaming response asynchronously.""" + settings = settings or LLMSettings() + kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings) + + try: + response = await self.async_client.messages.create(**kwargs) + return "".join( + block.text for block in response.content if block.type == "text" + ) + except Exception as e: + logger.error(f"Anthropic API error: {e}") + raise + + async def astream( + self, + messages: list[Message], + system_prompt: str | None = None, + tools: list[ToolDefinition] | None = None, + settings: LLMSettings | None = None, + ) -> AsyncIterator[StreamEvent]: + """Generate a streaming response asynchronously.""" + settings = settings or LLMSettings() + kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings) + + try: + async with self.async_client.messages.stream(**kwargs) as stream: + current_tool_use: dict[str, Any] | None = None + + async for event in stream: + stream_event, current_tool_use = self._handle_stream_event( + event, current_tool_use + ) + if stream_event: + yield stream_event + + except Exception as e: + logger.error(f"Anthropic streaming error: {e}") + yield StreamEvent(type="error", data=str(e)) + + def stream_with_tools( + self, + messages: list[Message], + tools: dict[str, ToolDefinition], + settings: LLMSettings | None = None, + system_prompt: str | None = None, + max_iterations: int = 10, + ) -> Iterator[StreamEvent]: + if max_iterations <= 0: + return + + response = TextContent(text="") + thinking = ThinkingContent(thinking="", signature="") + + for event in self.stream( + messages=messages, + system_prompt=system_prompt, + tools=list(tools.values()), + settings=settings, + ): + if event.type == "text": + response.text += event.data + yield event + elif event.type == "thinking" and event.signature: + thinking.signature = event.signature + elif event.type == "thinking": + thinking.thinking += event.data + yield event + elif event.type == "tool_use": + yield event + tool_result = self.execute_tool(event.data, tools) + yield StreamEvent(type="tool_result", data=tool_result.to_dict()) + messages.append( + Message.assistant( + response, + thinking, + ToolUseContent( + id=event.data["id"], + name=event.data["name"], + input=event.data["input"], + ), + ) + ) + messages.append(Message.user(tool_result=tool_result)) + yield from self.stream_with_tools( + messages, tools, settings, system_prompt, max_iterations - 1 + ) + elif event.type == "tool_result": + yield event + elif event.type == "error": + logger.error(f"LLM error: {event.data}") + raise RuntimeError(f"LLM error: {event.data}") diff --git a/src/memory/common/llms/base.py b/src/memory/common/llms/base.py new file mode 100644 index 0000000..aceb7f8 --- /dev/null +++ b/src/memory/common/llms/base.py @@ -0,0 +1,561 @@ +"""Base classes and types for LLM providers.""" + +import base64 +import io +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Any, AsyncIterator, Iterator, Literal, Optional, Union + +from PIL import Image + +from memory.common import settings +from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult + +logger = logging.getLogger(__name__) + + +class MessageRole(str, Enum): + """Message roles for chat history.""" + + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" + + +@dataclass +class TextContent: + """Text content in a message.""" + + type: Literal["text"] = "text" + text: str = "" + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary format.""" + return {"type": "text", "text": self.text} + + +@dataclass +class ImageContent: + """Image content in a message.""" + + type: Literal["image"] = "image" + image: Image.Image = None # type: ignore + detail: Optional[str] = None # For OpenAI: "low", "high", "auto" + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary format.""" + # Note: Image will be encoded by provider-specific implementation + return {"type": "image", "image": self.image} + + +@dataclass +class ToolUseContent: + """Tool use request from the assistant.""" + + type: Literal["tool_use"] = "tool_use" + id: str = "" + name: str = "" + input: dict[str, Any] = None # type: ignore + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary format.""" + return { + "type": "tool_use", + "id": self.id, + "name": self.name, + "input": self.input, + } + + +@dataclass +class ToolResultContent: + """Tool result from tool execution.""" + + type: Literal["tool_result"] = "tool_result" + tool_use_id: str = "" + content: str = "" + is_error: bool = False + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary format.""" + return { + "type": "tool_result", + "tool_use_id": self.tool_use_id, + "content": self.content, + "is_error": self.is_error, + } + + +@dataclass +class ThinkingContent: + """Thinking/reasoning content from the assistant (extended thinking).""" + + type: Literal["thinking"] = "thinking" + thinking: str = "" + signature: str | None = None + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary format.""" + return { + "type": "thinking", + "thinking": self.thinking, + "signature": self.signature, + } + + +MessageContent = Union[ + TextContent, ImageContent, ToolUseContent, ToolResultContent, ThinkingContent +] + + +@dataclass +class Turn: + """A turn in the conversation.""" + + response: str | None + thinking: str | None + tool_calls: dict[str, ToolResult] | None + + +@dataclass +class Message: + """A message in the conversation history.""" + + role: MessageRole + content: Union[str, list[MessageContent]] + + def to_dict(self) -> dict[str, Any]: + """Convert message to dictionary format.""" + if isinstance(self.content, str): + return {"role": self.role.value, "content": self.content} + content_list = [item.to_dict() for item in self.content] + return {"role": self.role.value, "content": content_list} + + @staticmethod + def assistant( + text: TextContent | None = None, + thinking: ThinkingContent | None = None, + tool_use: ToolUseContent | None = None, + ) -> "Message": + parts = [] + if text: + parts.append(text) + if thinking: + parts.append(thinking) + if tool_use: + parts.append(tool_use) + return Message(role=MessageRole.ASSISTANT, content=parts) + + @staticmethod + def user( + text: str | None = None, tool_result: ToolResultContent | None = None + ) -> "Message": + parts = [] + if text: + parts.append(TextContent(text=text)) + if tool_result: + parts.append(tool_result) + return Message(role=MessageRole.USER, content=parts) + + +@dataclass +class StreamEvent: + """An event from the streaming response.""" + + type: Literal["text", "tool_use", "tool_result", "thinking", "error", "done"] + data: Any = None + signature: str | None = None + + +@dataclass +class LLMSettings: + """Settings for LLM API calls.""" + + temperature: float = 0.7 + max_tokens: int = 2048 + # Don't set by default - some models don't allow both temp and top_p + top_p: float | None = None + stop_sequences: list[str] | None = None + stream: bool = False + + +class BaseLLMProvider(ABC): + """Base class for LLM providers.""" + + def __init__(self, api_key: str, model: str): + """ + Initialize the LLM provider. + + Args: + api_key: API key for the provider + model: Model identifier + """ + self.api_key = api_key + self.model = model + self._client: Any = None + + @abstractmethod + def _initialize_client(self) -> Any: + """Initialize the provider-specific client.""" + pass + + @property + def client(self) -> Any: + """Lazy-load the client.""" + if self._client is None: + self._client = self._initialize_client() + return self._client + + def execute_tool( + self, + tool_call: ToolCall, + tool_handlers: dict[str, ToolDefinition], + ) -> ToolResultContent: + """ + Execute a tool call. + + Args: + tool_call: Tool call + tool_handlers: Dict mapping tool names to handler functions + + Returns: + ToolResultContent with result or error + """ + name = tool_call.get("name") + tool_use_id = tool_call.get("id") + input = tool_call.get("input") + + if not name: + return ToolResultContent( + tool_use_id=tool_use_id, + content="Tool name missing", + is_error=True, + ) + + if not (tool := tool_handlers.get(name)): + return ToolResultContent( + tool_use_id=tool_use_id, + content=f"Tool '{name}' not found", + is_error=True, + ) + + try: + return ToolResultContent( + tool_use_id=tool_use_id, + content=tool(input), + is_error=False, + ) + except Exception as e: + logger.error(f"Tool '{name}' failed: {e}", exc_info=True) + return ToolResultContent( + tool_use_id=tool_use_id, + content=str(e), + is_error=True, + ) + + def encode_image(self, image: Image.Image) -> str: + """ + Encode PIL Image to base64 string. + + Args: + image: PIL Image to encode + + Returns: + Base64 encoded string + """ + buffer = io.BytesIO() + # Convert to RGB if necessary (for RGBA, etc.) + if image.mode != "RGB": + image = image.convert("RGB") + image.save(buffer, format="JPEG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def _convert_text_content(self, content: TextContent) -> dict[str, Any]: + """Convert TextContent to provider format. Override for custom format.""" + return content.to_dict() + + def _convert_image_content(self, content: ImageContent) -> dict[str, Any]: + """Convert ImageContent to provider format. Override for custom format.""" + return content.to_dict() + + def _convert_tool_use_content(self, content: ToolUseContent) -> dict[str, Any]: + """Convert ToolUseContent to provider format. Override for custom format.""" + return content.to_dict() + + def _convert_tool_result_content( + self, content: ToolResultContent + ) -> dict[str, Any]: + """Convert ToolResultContent to provider format. Override for custom format.""" + return content.to_dict() + + def _convert_thinking_content(self, content: ThinkingContent) -> dict[str, Any]: + """Convert ThinkingContent to provider format. Override for custom format.""" + return content.to_dict() + + def _convert_message_content(self, content: MessageContent) -> dict[str, Any]: + """ + Convert a MessageContent item to provider format. + + Dispatches to type-specific converters that can be overridden. + """ + if isinstance(content, TextContent): + return self._convert_text_content(content) + elif isinstance(content, ImageContent): + return self._convert_image_content(content) + elif isinstance(content, ToolUseContent): + return self._convert_tool_use_content(content) + elif isinstance(content, ToolResultContent): + return self._convert_tool_result_content(content) + elif isinstance(content, ThinkingContent): + return self._convert_thinking_content(content) + else: + raise ValueError(f"Unknown content type: {type(content)}") + + def _convert_message(self, message: Message) -> dict[str, Any]: + """ + Convert a Message to provider format. + + Can be overridden for provider-specific handling (e.g., filtering system messages). + """ + return message.to_dict() + + def _should_include_message(self, message: Message) -> bool: + """ + Determine if a message should be included in the request. + + Override to filter messages (e.g., Anthropic filters SYSTEM messages). + + Args: + message: Message to check + + Returns: + True if message should be included + """ + return True + + def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]: + """ + Convert a list of messages to provider format. + + Uses _should_include_message for filtering and _convert_message for conversion. + """ + return [ + self._convert_message(msg) + for msg in messages + if self._should_include_message(msg) + ] + + def _convert_tool(self, tool: ToolDefinition) -> dict[str, Any]: + """ + Convert a single ToolDefinition to provider format. + + Default format matches Anthropic. Override for other providers (e.g., OpenAI uses functions). + """ + return { + "name": tool.name, + "description": tool.description, + "input_schema": tool.input_schema, + } + + def _convert_tools( + self, tools: list[ToolDefinition] | None + ) -> Optional[list[dict[str, Any]]]: + """Convert tool definitions to provider format.""" + if not tools: + return None + return [self._convert_tool(tool) for tool in tools] + + @abstractmethod + def generate( + self, + messages: list[Message], + system_prompt: str | None = None, + tools: list[ToolDefinition] | None = None, + settings: LLMSettings | None = None, + ) -> str: + """ + Generate a non-streaming response. + + Args: + messages: Conversation history + system_prompt: Optional system prompt + tools: Optional list of tools the LLM can use + settings: Optional settings for the generation + + Returns: + Generated text response + """ + pass + + @abstractmethod + def stream( + self, + messages: list[Message], + system_prompt: str | None = None, + tools: list[ToolDefinition] | None = None, + settings: LLMSettings | None = None, + ) -> Iterator[StreamEvent]: + """ + Generate a streaming response. + + Args: + messages: Conversation history + system_prompt: Optional system prompt + tools: Optional list of tools the LLM can use + settings: Optional settings for the generation + + Yields: + StreamEvent objects containing text chunks, tool uses, or errors + """ + pass + + @abstractmethod + async def agenerate( + self, + messages: list[Message], + system_prompt: str | None = None, + tools: list[ToolDefinition] | None = None, + settings: LLMSettings | None = None, + ) -> str: + """ + Generate a non-streaming response asynchronously. + + Args: + messages: Conversation history + system_prompt: Optional system prompt + tools: Optional list of tools the LLM can use + settings: Optional settings for the generation + + Returns: + Generated text response + """ + pass + + @abstractmethod + async def astream( + self, + messages: list[Message], + system_prompt: str | None = None, + tools: list[ToolDefinition] | None = None, + settings: LLMSettings | None = None, + ) -> AsyncIterator[StreamEvent]: + """ + Generate a streaming response asynchronously. + + Args: + messages: Conversation history + system_prompt: Optional system prompt + tools: Optional list of tools the LLM can use + settings: Optional settings for the generation + + Yields: + StreamEvent objects containing text chunks, tool uses, or errors + """ + pass + + @abstractmethod + def stream_with_tools( + self, + messages: list[Message], + tools: dict[str, ToolDefinition], + settings: LLMSettings | None = None, + system_prompt: str | None = None, + max_iterations: int = 10, + ) -> Iterator[StreamEvent]: + pass + + def run_with_tools( + self, + messages: list[Message], + tools: dict[str, ToolDefinition], + settings: LLMSettings | None = None, + system_prompt: str | None = None, + max_iterations: int = 10, + ) -> Turn: + thinking, response, tool_calls = "", "", {} + for event in self.stream_with_tools( + messages=messages, + tools=tools, + settings=settings, + system_prompt=system_prompt, + max_iterations=max_iterations, + ): + if event.type == "thinking": + thinking += event.data + elif event.type == "tool_use": + tool_calls[event.data["id"]] = { + "name": event.data["name"], + "input": event.data["input"], + "output": "", + } + elif event.type == "text": + response += event.data + elif event.type == "tool_result": + current = tool_calls.get(event.data["tool_use_id"]) or {} + tool_calls[event.data["tool_use_id"]] = { + "name": event.data.get("name") or current.get("name"), + "input": event.data.get("input") or current.get("input"), + "output": event.data.get("content"), + } + return Turn( + thinking=thinking or None, + response=response or None, + tool_calls=tool_calls or None, + ) + + +def create_provider( + model: str | None = None, + api_key: str | None = None, + enable_thinking: bool = False, +) -> BaseLLMProvider: + """ + Create an LLM provider based on the model name. + + Args: + model: Model identifier (e.g., "claude-3-opus-20240229", "gpt-4"). + If not provided, uses SUMMARIZER_MODEL from settings. + api_key: Optional API key. If not provided, uses keys from settings. + enable_thinking: Enable extended thinking for supported models (Claude Opus 4+, Sonnet 4+, Sonnet 3.7) + + Returns: + An initialized LLM provider + + Raises: + ValueError: If the provider cannot be determined from the model name + """ + # Use default model from settings if not provided + if model is None: + model = settings.SUMMARIZER_MODEL + + provider, model = model.split("/", 1) + + if provider == "anthropic": + # Anthropic models + if api_key is None: + api_key = settings.ANTHROPIC_API_KEY + + if not api_key: + raise ValueError( + "ANTHROPIC_API_KEY not found in settings. " + "Please set it in your .env file." + ) + + from memory.common.llms.anthropic_provider import AnthropicProvider + + return AnthropicProvider( + api_key=api_key, model=model, enable_thinking=enable_thinking + ) + + # Could add OpenAI support here in the future + # elif "gpt" in model_lower or model.startswith("openai"): + # ... + + else: + raise ValueError( + f"Unknown provider for model: {model}. " + f"Supported providers: Anthropic (claude-*)" + ) diff --git a/src/memory/common/llms/openai_provider.py b/src/memory/common/llms/openai_provider.py new file mode 100644 index 0000000..2f230fb --- /dev/null +++ b/src/memory/common/llms/openai_provider.py @@ -0,0 +1,388 @@ +"""OpenAI LLM provider implementation.""" + +import logging +from typing import Any, AsyncIterator, Iterator, Optional + +import openai + +from memory.common.llms.base import ( + BaseLLMProvider, + ImageContent, + LLMSettings, + Message, + MessageContent, + MessageRole, + StreamEvent, + TextContent, + ThinkingContent, + ToolDefinition, + ToolResultContent, + ToolUseContent, +) + +logger = logging.getLogger(__name__) + + +class OpenAIProvider(BaseLLMProvider): + """OpenAI LLM provider with streaming and tool support.""" + + def _initialize_client(self) -> openai.OpenAI: + """Initialize the OpenAI client.""" + return openai.OpenAI(api_key=self.api_key) + + def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]: + """ + Convert our Message format to OpenAI format. + + Args: + messages: List of messages in our format + + Returns: + List of messages in OpenAI format + """ + openai_messages = [] + + for msg in messages: + if isinstance(msg.content, str): + openai_messages.append({"role": msg.role.value, "content": msg.content}) + else: + # Handle multi-part content + content_parts = [] + for item in msg.content: + if isinstance(item, TextContent): + content_parts.append({"type": "text", "text": item.text}) + elif isinstance(item, ImageContent): + encoded_image = self.encode_image(item.image) + image_part: dict[str, Any] = { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{encoded_image}" + }, + } + if item.detail: + image_part["image_url"]["detail"] = item.detail + content_parts.append(image_part) + elif isinstance(item, ToolUseContent): + # OpenAI doesn't have tool_use in content, it's a separate field + # We'll handle this by adding a tool_calls field to the message + pass + elif isinstance(item, ToolResultContent): + # OpenAI handles tool results as separate "tool" role messages + openai_messages.append( + { + "role": "tool", + "tool_call_id": item.tool_use_id, + "content": item.content, + } + ) + continue + elif isinstance(item, ThinkingContent): + # OpenAI doesn't have native thinking support in most models + # We can add it as text with a special marker + content_parts.append( + { + "type": "text", + "text": f"[Thinking: {item.thinking}]", + } + ) + + # Check if this message has tool calls + tool_calls = [ + item for item in msg.content if isinstance(item, ToolUseContent) + ] + + message_dict: dict[str, Any] = {"role": msg.role.value} + + if content_parts: + message_dict["content"] = content_parts + + if tool_calls: + message_dict["tool_calls"] = [ + { + "id": tc.id, + "type": "function", + "function": {"name": tc.name, "arguments": str(tc.input)}, + } + for tc in tool_calls + ] + + openai_messages.append(message_dict) + + return openai_messages + + def _convert_tools( + self, tools: Optional[list[ToolDefinition]] + ) -> Optional[list[dict[str, Any]]]: + """ + Convert our tool definitions to OpenAI format. + + Args: + tools: List of tool definitions + + Returns: + List of tools in OpenAI format + """ + if not tools: + return None + + return [ + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.input_schema, + }, + } + for tool in tools + ] + + def generate( + self, + messages: list[Message], + system_prompt: Optional[str] = None, + tools: Optional[list[ToolDefinition]] = None, + settings: Optional[LLMSettings] = None, + ) -> str: + """Generate a non-streaming response.""" + settings = settings or LLMSettings() + + openai_messages = self._convert_messages(messages) + + # Add system prompt as first message if provided + if system_prompt: + openai_messages.insert( + 0, {"role": "system", "content": system_prompt} + ) + + kwargs: dict[str, Any] = { + "model": self.model, + "messages": openai_messages, + "temperature": settings.temperature, + "max_tokens": settings.max_tokens, + "top_p": settings.top_p, + } + + if settings.stop_sequences: + kwargs["stop"] = settings.stop_sequences + + if tools: + kwargs["tools"] = self._convert_tools(tools) + kwargs["tool_choice"] = "auto" + + try: + response = self.client.chat.completions.create(**kwargs) + return response.choices[0].message.content or "" + except Exception as e: + logger.error(f"OpenAI API error: {e}") + raise + + def stream( + self, + messages: list[Message], + system_prompt: Optional[str] = None, + tools: Optional[list[ToolDefinition]] = None, + settings: Optional[LLMSettings] = None, + ) -> Iterator[StreamEvent]: + """Generate a streaming response.""" + settings = settings or LLMSettings() + + openai_messages = self._convert_messages(messages) + + # Add system prompt as first message if provided + if system_prompt: + openai_messages.insert( + 0, {"role": "system", "content": system_prompt} + ) + + kwargs: dict[str, Any] = { + "model": self.model, + "messages": openai_messages, + "temperature": settings.temperature, + "max_tokens": settings.max_tokens, + "top_p": settings.top_p, + "stream": True, + } + + if settings.stop_sequences: + kwargs["stop"] = settings.stop_sequences + + if tools: + kwargs["tools"] = self._convert_tools(tools) + kwargs["tool_choice"] = "auto" + + try: + stream = self.client.chat.completions.create(**kwargs) + + current_tool_call: Optional[dict[str, Any]] = None + + for chunk in stream: + if not chunk.choices: + continue + + delta = chunk.choices[0].delta + + # Handle text content + if delta.content: + yield StreamEvent(type="text", data=delta.content) + + # Handle tool calls + if delta.tool_calls: + for tool_call in delta.tool_calls: + if tool_call.id: + # New tool call starting + if current_tool_call: + # Yield the previous one + yield StreamEvent( + type="tool_use", data=current_tool_call + ) + current_tool_call = { + "id": tool_call.id, + "name": tool_call.function.name or "", + "arguments": tool_call.function.arguments or "", + } + elif current_tool_call and tool_call.function.arguments: + # Continue building the current tool call + current_tool_call["arguments"] += ( + tool_call.function.arguments + ) + + # Check if stream is finished + if chunk.choices[0].finish_reason: + if current_tool_call: + yield StreamEvent(type="tool_use", data=current_tool_call) + current_tool_call = None + + yield StreamEvent(type="done") + + except Exception as e: + logger.error(f"OpenAI streaming error: {e}") + yield StreamEvent(type="error", data=str(e)) + + async def agenerate( + self, + messages: list[Message], + system_prompt: Optional[str] = None, + tools: Optional[list[ToolDefinition]] = None, + settings: Optional[LLMSettings] = None, + ) -> str: + """Generate a non-streaming response asynchronously.""" + settings = settings or LLMSettings() + + # Use async client + async_client = openai.AsyncOpenAI(api_key=self.api_key) + + openai_messages = self._convert_messages(messages) + + # Add system prompt as first message if provided + if system_prompt: + openai_messages.insert( + 0, {"role": "system", "content": system_prompt} + ) + + kwargs: dict[str, Any] = { + "model": self.model, + "messages": openai_messages, + "temperature": settings.temperature, + "max_tokens": settings.max_tokens, + "top_p": settings.top_p, + } + + if settings.stop_sequences: + kwargs["stop"] = settings.stop_sequences + + if tools: + kwargs["tools"] = self._convert_tools(tools) + kwargs["tool_choice"] = "auto" + + try: + response = await async_client.chat.completions.create(**kwargs) + return response.choices[0].message.content or "" + except Exception as e: + logger.error(f"OpenAI API error: {e}") + raise + + async def astream( + self, + messages: list[Message], + system_prompt: Optional[str] = None, + tools: Optional[list[ToolDefinition]] = None, + settings: Optional[LLMSettings] = None, + ) -> AsyncIterator[StreamEvent]: + """Generate a streaming response asynchronously.""" + settings = settings or LLMSettings() + + # Use async client + async_client = openai.AsyncOpenAI(api_key=self.api_key) + + openai_messages = self._convert_messages(messages) + + # Add system prompt as first message if provided + if system_prompt: + openai_messages.insert( + 0, {"role": "system", "content": system_prompt} + ) + + kwargs: dict[str, Any] = { + "model": self.model, + "messages": openai_messages, + "temperature": settings.temperature, + "max_tokens": settings.max_tokens, + "top_p": settings.top_p, + "stream": True, + } + + if settings.stop_sequences: + kwargs["stop"] = settings.stop_sequences + + if tools: + kwargs["tools"] = self._convert_tools(tools) + kwargs["tool_choice"] = "auto" + + try: + stream = await async_client.chat.completions.create(**kwargs) + + current_tool_call: Optional[dict[str, Any]] = None + + async for chunk in stream: + if not chunk.choices: + continue + + delta = chunk.choices[0].delta + + # Handle text content + if delta.content: + yield StreamEvent(type="text", data=delta.content) + + # Handle tool calls + if delta.tool_calls: + for tool_call in delta.tool_calls: + if tool_call.id: + # New tool call starting + if current_tool_call: + # Yield the previous one + yield StreamEvent( + type="tool_use", data=current_tool_call + ) + current_tool_call = { + "id": tool_call.id, + "name": tool_call.function.name or "", + "arguments": tool_call.function.arguments or "", + } + elif current_tool_call and tool_call.function.arguments: + # Continue building the current tool call + current_tool_call["arguments"] += ( + tool_call.function.arguments + ) + + # Check if stream is finished + if chunk.choices[0].finish_reason: + if current_tool_call: + yield StreamEvent(type="tool_use", data=current_tool_call) + current_tool_call = None + + yield StreamEvent(type="done") + + except Exception as e: + logger.error(f"OpenAI streaming error: {e}") + yield StreamEvent(type="error", data=str(e)) diff --git a/src/memory/common/llms/tools/__init__.py b/src/memory/common/llms/tools/__init__.py new file mode 100644 index 0000000..7a4e1c0 --- /dev/null +++ b/src/memory/common/llms/tools/__init__.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass +from typing import Any, Callable, TypedDict + + +ToolInput = str | dict[str, Any] | None +ToolHandler = Callable[[ToolInput], str] + + +class ToolCall(TypedDict): + """A call to a tool.""" + + name: str + id: str + input: ToolInput + + +class ToolResult(TypedDict): + """A result from a tool call.""" + + id: str + name: str + input: ToolInput + output: str + + +@dataclass +class ToolDefinition: + """Definition of a tool that can be called by the LLM.""" + + name: str + description: str + input_schema: dict[str, Any] # JSON Schema for the tool's parameters + function: ToolHandler + + def __call__(self, input: ToolInput) -> str: + return self.function(input) diff --git a/src/memory/common/llms/tools/ping.py b/src/memory/common/llms/tools/ping.py new file mode 100644 index 0000000..5ec2ea6 --- /dev/null +++ b/src/memory/common/llms/tools/ping.py @@ -0,0 +1,42 @@ +"""Ping tool for testing LLM tool integration.""" + +from memory.common.llms.tools import ToolDefinition, ToolInput + + +def handle_ping_call(message: ToolInput = None) -> str: + """ + Handle a ping tool call. + + Args: + message: Optional message to include in response + + Returns: + Response string + """ + if message: + return f"pong: {message}" + return "pong" + + +def get_ping_tool() -> ToolDefinition: + """ + Get a ping tool definition for testing tool calls. + + Returns a simple tool that takes no required parameters and can be used + to verify that tool calling is working correctly. + """ + return ToolDefinition( + name="ping", + description="A simple test tool that returns 'pong'. Use this to verify tool calling is working.", + input_schema={ + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "Optional message to echo back", + } + }, + "required": [], + }, + function=handle_ping_call, + ) diff --git a/src/memory/common/messages.py b/src/memory/common/messages.py new file mode 100644 index 0000000..1244de3 --- /dev/null +++ b/src/memory/common/messages.py @@ -0,0 +1,9 @@ +def process_message( + msg: str, + history: list[str], + model: str | None = None, + system_prompt: str | None = None, + allowed_tools: list[str] | None = None, + disallowed_tools: list[str] | None = None, +) -> str: + return "asd" diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py index 793e698..daf98bb 100644 --- a/src/memory/common/settings.py +++ b/src/memory/common/settings.py @@ -172,6 +172,7 @@ DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat") DISCORD_NOTIFICATIONS_ENABLED = bool( boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN ) +DISCORD_PROCESS_MESSAGES = boolean_env("DISCORD_PROCESS_MESSAGES", True) # Discord collector settings DISCORD_COLLECTOR_ENABLED = boolean_env("DISCORD_COLLECTOR_ENABLED", True) diff --git a/src/memory/common/summarizer.py b/src/memory/common/summarizer.py index c2016f3..b2d2dbc 100644 --- a/src/memory/common/summarizer.py +++ b/src/memory/common/summarizer.py @@ -105,7 +105,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list prompt = llms.truncate(prompt, MAX_TOKENS - 20) try: - response = llms.call(prompt, settings.SUMMARIZER_MODEL) + response = llms.summarize(prompt, settings.SUMMARIZER_MODEL) result = parse_response(response) summary = result.get("summary", "") diff --git a/src/memory/discord/collector.py b/src/memory/discord/collector.py index f42ff76..f60a22f 100644 --- a/src/memory/discord/collector.py +++ b/src/memory/discord/collector.py @@ -160,7 +160,7 @@ def should_track_message( return False if channel.channel_type in ("dm", "group_dm"): - return bool(user.allow_dm_tracking) + return bool(user.track_messages) # Default: track the message return True diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py index 764d4f2..093acd1 100644 --- a/src/memory/workers/tasks/discord.py +++ b/src/memory/workers/tasks/discord.py @@ -16,7 +16,11 @@ from memory.workers.tasks.content_processing import ( create_task_result, process_content_item, ) -from memory.common.celery_app import ADD_DISCORD_MESSAGE, EDIT_DISCORD_MESSAGE +from memory.common.celery_app import ( + ADD_DISCORD_MESSAGE, + EDIT_DISCORD_MESSAGE, + PROCESS_DISCORD_MESSAGE, +) from memory.common import settings from sqlalchemy.orm import Session, scoped_session @@ -40,6 +44,41 @@ def get_prev( return [f"{msg.username}: {msg.content}" for msg in prev[::-1]] +def should_process(message: DiscordMessage) -> bool: + return ( + settings.DISCORD_PROCESS_MESSAGES + and settings.DISCORD_NOTIFICATIONS_ENABLED + and not ( + (message.server and message.server.ignore_messages) + or (message.channel and message.channel.ignore_messages) + or (message.discord_user and message.discord_user.ignore_messages) + ) + ) + + +@app.task(name=PROCESS_DISCORD_MESSAGE) +@safe_task_execution +def process_discord_message(message_id: int) -> dict[str, Any]: + logger.info(f"Processing Discord message {message_id}") + + with make_session() as session: + discord_message = session.query(DiscordMessage).get(message_id) + if not discord_message: + logger.info(f"Discord message not found: {message_id}") + return { + "status": "error", + "error": "Message not found", + "message_id": message_id, + } + + print("Processing message", discord_message.id, discord_message.content) + + return { + "status": "processed", + "message_id": message_id, + } + + @app.task(name=ADD_DISCORD_MESSAGE) @safe_task_execution def add_discord_message( @@ -87,11 +126,8 @@ def add_discord_message( discord_message.messages_before = get_prev(session, channel_id, sent_at_dt) result = process_content_item(discord_message, session) - - logger.info( - f"Discord message ID after process_content_item: {discord_message.id}" - ) - logger.info(f"Process result: {result}") + if should_process(discord_message): + process_discord_message.delay(discord_message.id) return result diff --git a/src/memory/workers/tasks/scheduled_calls.py b/src/memory/workers/tasks/scheduled_calls.py index 200cd2a..3248285 100644 --- a/src/memory/workers/tasks/scheduled_calls.py +++ b/src/memory/workers/tasks/scheduled_calls.py @@ -88,11 +88,10 @@ def execute_scheduled_call(self, scheduled_call_id: str): # Make the LLM call if scheduled_call.model: - response = llms.call( + response = llms.summarize( prompt=cast(str, scheduled_call.message), model=cast(str, scheduled_call.model), - system_prompt=cast(str, scheduled_call.system_prompt) - or llms.SYSTEM_PROMPT, + system_prompt=cast(str, scheduled_call.system_prompt), ) else: response = cast(str, scheduled_call.message) diff --git a/tests/conftest.py b/tests/conftest.py index 92166cb..175eedf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -273,6 +273,27 @@ def mock_anthropic_client(): with patch.object(anthropic, "Anthropic", autospec=True) as mock_client: client = mock_client() client.messages = Mock() + + # Mock stream as a context manager + mock_stream = Mock() + mock_stream.__enter__ = Mock( + return_value=Mock( + __iter__=lambda self: iter( + [ + Mock( + type="content_block_delta", + delta=Mock( + type="text_delta", + text="test summarytag1tag2", + ), + ) + ] + ) + ) + ) + mock_stream.__exit__ = Mock(return_value=False) + client.messages.stream = Mock(return_value=mock_stream) + client.messages.create = Mock( return_value=Mock( content=[ diff --git a/tests/memory/common/test_discord.py b/tests/memory/common/test_discord.py index 6624b6d..3d46a0d 100644 --- a/tests/memory/common/test_discord.py +++ b/tests/memory/common/test_discord.py @@ -2,318 +2,250 @@ import pytest from unittest.mock import Mock, patch import requests -from memory.common import discord, settings +from memory.common import discord @pytest.fixture -def mock_session_request(): - with patch("requests.Session.request") as mock: - yield mock +def mock_api_url(): + """Mock the API URL to avoid using actual settings""" + with patch( + "memory.common.discord.get_api_url", return_value="http://localhost:8000" + ): + yield -@pytest.fixture -def mock_get_channels_response(): - return [ - {"name": "memory-errors", "id": "error_channel_id"}, - {"name": "memory-activity", "id": "activity_channel_id"}, - {"name": "memory-discoveries", "id": "discovery_channel_id"}, - {"name": "memory-chat", "id": "chat_channel_id"}, - ] +@patch("memory.common.settings.DISCORD_COLLECTOR_SERVER_URL", "testhost") +@patch("memory.common.settings.DISCORD_COLLECTOR_PORT", 9999) +def test_get_api_url(): + """Test API URL construction""" + assert discord.get_api_url() == "http://testhost:9999" -def test_discord_server_init(mock_session_request, mock_get_channels_response): - # Mock the channels API call +@patch("requests.post") +def test_send_dm_success(mock_post, mock_api_url): + """Test successful DM sending""" mock_response = Mock() - mock_response.json.return_value = mock_get_channels_response + mock_response.json.return_value = {"success": True} mock_response.raise_for_status.return_value = None - mock_session_request.return_value = mock_response + mock_post.return_value = mock_response - server = discord.DiscordServer("server123", "Test Server") + result = discord.send_dm("user123", "Hello!") - assert server.server_id == "server123" - assert server.server_name == "Test Server" - assert hasattr(server, "channels") - - -@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "memory-errors") -@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "memory-activity") -@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "memory-discoveries") -@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "memory-chat") -def test_setup_channels_existing(mock_session_request, mock_get_channels_response): - # Mock the channels API call - mock_response = Mock() - mock_response.json.return_value = mock_get_channels_response - mock_response.raise_for_status.return_value = None - mock_session_request.return_value = mock_response - - server = discord.DiscordServer("server123", "Test Server") - - assert server.channels[discord.ERROR_CHANNEL] == "error_channel_id" - assert server.channels[discord.ACTIVITY_CHANNEL] == "activity_channel_id" - assert server.channels[discord.DISCOVERY_CHANNEL] == "discovery_channel_id" - assert server.channels[discord.CHAT_CHANNEL] == "chat_channel_id" - - -@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "new-error-channel") -def test_setup_channels_create_missing(mock_session_request): - # Mock get channels (empty) and create channel calls - get_response = Mock() - get_response.json.return_value = [] - get_response.raise_for_status.return_value = None - - create_response = Mock() - create_response.json.return_value = {"id": "new_channel_id"} - create_response.raise_for_status.return_value = None - - mock_session_request.side_effect = [ - get_response, - create_response, - create_response, - create_response, - create_response, - ] - - server = discord.DiscordServer("server123", "Test Server") - - assert server.channels[discord.ERROR_CHANNEL] == "new_channel_id" - - -def test_channel_properties(): - server = discord.DiscordServer.__new__(discord.DiscordServer) - server.channels = { - discord.ERROR_CHANNEL: "error_id", - discord.ACTIVITY_CHANNEL: "activity_id", - discord.DISCOVERY_CHANNEL: "discovery_id", - discord.CHAT_CHANNEL: "chat_id", - } - - assert server.error_channel == "error_id" - assert server.activity_channel == "activity_id" - assert server.discovery_channel == "discovery_id" - assert server.chat_channel == "chat_id" - - -def test_channel_id_exists(): - server = discord.DiscordServer.__new__(discord.DiscordServer) - server.channels = {"test-channel": "channel123"} - - assert server.channel_id("test-channel") == "channel123" - - -def test_channel_id_not_found(): - server = discord.DiscordServer.__new__(discord.DiscordServer) - server.channels = {} - - with pytest.raises(ValueError, match="Channel nonexistent not found"): - server.channel_id("nonexistent") - - -def test_send_message(mock_session_request): - mock_response = Mock() - mock_response.raise_for_status.return_value = None - mock_session_request.return_value = mock_response - - server = discord.DiscordServer.__new__(discord.DiscordServer) - - server.send_message("channel123", "Hello World") - - mock_session_request.assert_called_with( - "POST", - "https://discord.com/api/v10/channels/channel123/messages", - data=None, - json={"content": "Hello World"}, - headers={ - "Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}", - "Content-Type": "application/json", - }, + assert result is True + mock_post.assert_called_once_with( + "http://localhost:8000/send_dm", + json={"user_identifier": "user123", "message": "Hello!"}, + timeout=10, ) -def test_create_channel(mock_session_request): +@patch("requests.post") +def test_send_dm_api_failure(mock_post, mock_api_url): + """Test DM sending when API returns failure""" mock_response = Mock() - mock_response.json.return_value = {"id": "new_channel_id"} + mock_response.json.return_value = {"success": False} mock_response.raise_for_status.return_value = None - mock_session_request.return_value = mock_response + mock_post.return_value = mock_response - server = discord.DiscordServer.__new__(discord.DiscordServer) - server.server_id = "server123" + result = discord.send_dm("user123", "Hello!") - channel_id = server.create_channel("new-channel") - - assert channel_id == "new_channel_id" - mock_session_request.assert_called_with( - "POST", - "https://discord.com/api/v10/guilds/server123/channels", - data=None, - json={"name": "new-channel", "type": 0}, - headers={ - "Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}", - "Content-Type": "application/json", - }, - ) + assert result is False -def test_create_channel_custom_type(mock_session_request): +@patch("requests.post") +def test_send_dm_request_exception(mock_post, mock_api_url): + """Test DM sending when request raises exception""" + mock_post.side_effect = requests.RequestException("Network error") + + result = discord.send_dm("user123", "Hello!") + + assert result is False + + +@patch("requests.post") +def test_send_dm_http_error(mock_post, mock_api_url): + """Test DM sending when HTTP error occurs""" mock_response = Mock() - mock_response.json.return_value = {"id": "voice_channel_id"} + mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found") + mock_post.return_value = mock_response + + result = discord.send_dm("user123", "Hello!") + + assert result is False + + +@patch("requests.post") +def test_broadcast_message_success(mock_post, mock_api_url): + """Test successful channel message broadcast""" + mock_response = Mock() + mock_response.json.return_value = {"success": True} mock_response.raise_for_status.return_value = None - mock_session_request.return_value = mock_response + mock_post.return_value = mock_response - server = discord.DiscordServer.__new__(discord.DiscordServer) - server.server_id = "server123" + result = discord.broadcast_message("general", "Announcement!") - channel_id = server.create_channel("voice-channel", channel_type=2) - - assert channel_id == "voice_channel_id" - mock_session_request.assert_called_with( - "POST", - "https://discord.com/api/v10/guilds/server123/channels", - data=None, - json={"name": "voice-channel", "type": 2}, - headers={ - "Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}", - "Content-Type": "application/json", - }, + assert result is True + mock_post.assert_called_once_with( + "http://localhost:8000/send_channel", + json={"channel_name": "general", "message": "Announcement!"}, + timeout=10, ) -def test_str_representation(): - server = discord.DiscordServer.__new__(discord.DiscordServer) - server.server_id = "server123" - server.server_name = "Test Server" +@patch("requests.post") +def test_broadcast_message_failure(mock_post, mock_api_url): + """Test channel message broadcast failure""" + mock_response = Mock() + mock_response.json.return_value = {"success": False} + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response - assert str(server) == "DiscordServer(server_id=server123, server_name=Test Server)" + result = discord.broadcast_message("general", "Announcement!") + + assert result is False -@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token_123") -def test_request_adds_headers(mock_session_request): - server = discord.DiscordServer.__new__(discord.DiscordServer) +@patch("requests.post") +def test_broadcast_message_exception(mock_post, mock_api_url): + """Test channel message broadcast with exception""" + mock_post.side_effect = requests.Timeout("Request timeout") - server.request("GET", "https://example.com", headers={"Custom": "header"}) + result = discord.broadcast_message("general", "Announcement!") - expected_headers = { - "Custom": "header", - "Authorization": "Bot test_token_123", - "Content-Type": "application/json", - } - mock_session_request.assert_called_once_with( - "GET", "https://example.com", headers=expected_headers - ) + assert result is False -def test_channels_url(): - server = discord.DiscordServer.__new__(discord.DiscordServer) - server.server_id = "server123" - - assert ( - server.channels_url == "https://discord.com/api/v10/guilds/server123/channels" - ) - - -@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token") @patch("requests.get") -def test_get_bot_servers_success(mock_get): +def test_is_collector_healthy_true(mock_get, mock_api_url): + """Test health check when collector is healthy""" mock_response = Mock() - mock_response.json.return_value = [ - {"id": "server1", "name": "Server 1"}, - {"id": "server2", "name": "Server 2"}, - ] + mock_response.json.return_value = {"status": "healthy"} mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response - servers = discord.get_bot_servers() + result = discord.is_collector_healthy() - assert len(servers) == 2 - assert servers[0] == {"id": "server1", "name": "Server 1"} - mock_get.assert_called_once_with( - "https://discord.com/api/v10/users/@me/guilds", - headers={"Authorization": "Bot test_token"}, - ) + assert result is True + mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5) -@patch("memory.common.settings.DISCORD_BOT_TOKEN", None) -def test_get_bot_servers_no_token(): - assert discord.get_bot_servers() == [] - - -@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token") @patch("requests.get") -def test_get_bot_servers_exception(mock_get): - mock_get.side_effect = requests.RequestException("API Error") +def test_is_collector_healthy_false_status(mock_get, mock_api_url): + """Test health check when collector returns unhealthy status""" + mock_response = Mock() + mock_response.json.return_value = {"status": "unhealthy"} + mock_response.raise_for_status.return_value = None + mock_get.return_value = mock_response - servers = discord.get_bot_servers() + result = discord.is_collector_healthy() - assert servers == [] + assert result is False -@patch("memory.common.discord.get_bot_servers") -@patch("memory.common.discord.DiscordServer") -def test_load_servers(mock_discord_server_class, mock_get_servers): - mock_get_servers.return_value = [ - {"id": "server1", "name": "Server 1"}, - {"id": "server2", "name": "Server 2"}, - ] +@patch("requests.get") +def test_is_collector_healthy_exception(mock_get, mock_api_url): + """Test health check when request fails""" + mock_get.side_effect = requests.ConnectionError("Connection refused") - discord.load_servers() + result = discord.is_collector_healthy() - assert mock_discord_server_class.call_count == 2 - mock_discord_server_class.assert_any_call("server1", "Server 1") - mock_discord_server_class.assert_any_call("server2", "Server 2") + assert result is False -@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) -def test_broadcast_message(): - mock_server1 = Mock() - mock_server2 = Mock() - discord.servers = {"1": mock_server1, "2": mock_server2} +@patch("requests.post") +def test_refresh_discord_metadata_success(mock_post, mock_api_url): + """Test successful metadata refresh""" + mock_response = Mock() + mock_response.json.return_value = { + "servers": 5, + "channels": 20, + "users": 100, + } + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response - discord.broadcast_message("test-channel", "Hello") + result = discord.refresh_discord_metadata() - mock_server1.send_message.assert_called_once_with( - mock_server1.channel_id.return_value, "Hello" - ) - mock_server2.send_message.assert_called_once_with( - mock_server2.channel_id.return_value, "Hello" + assert result == {"servers": 5, "channels": 20, "users": 100} + mock_post.assert_called_once_with( + "http://localhost:8000/refresh_metadata", timeout=30 ) -@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False) -def test_broadcast_message_disabled(): - mock_server = Mock() - discord.servers = {"1": mock_server} +@patch("requests.post") +def test_refresh_discord_metadata_failure(mock_post, mock_api_url): + """Test metadata refresh failure""" + mock_post.side_effect = requests.RequestException("Failed to connect") - discord.broadcast_message("test-channel", "Hello") + result = discord.refresh_discord_metadata() - mock_server.send_message.assert_not_called() + assert result is None + + +@patch("requests.post") +def test_refresh_discord_metadata_http_error(mock_post, mock_api_url): + """Test metadata refresh with HTTP error""" + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error") + mock_post.return_value = mock_response + + result = discord.refresh_discord_metadata() + + assert result is None @patch("memory.common.discord.broadcast_message") +@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "errors") def test_send_error_message(mock_broadcast): - discord.send_error_message("Error occurred") - mock_broadcast.assert_called_once_with(discord.ERROR_CHANNEL, "Error occurred") + """Test sending error message to error channel""" + mock_broadcast.return_value = True + + result = discord.send_error_message("Something broke") + + assert result is True + mock_broadcast.assert_called_once_with("errors", "Something broke") @patch("memory.common.discord.broadcast_message") +@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "activity") def test_send_activity_message(mock_broadcast): - discord.send_activity_message("Activity update") - mock_broadcast.assert_called_once_with(discord.ACTIVITY_CHANNEL, "Activity update") + """Test sending activity message to activity channel""" + mock_broadcast.return_value = True + + result = discord.send_activity_message("User logged in") + + assert result is True + mock_broadcast.assert_called_once_with("activity", "User logged in") @patch("memory.common.discord.broadcast_message") +@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "discoveries") def test_send_discovery_message(mock_broadcast): - discord.send_discovery_message("Discovery made") - mock_broadcast.assert_called_once_with(discord.DISCOVERY_CHANNEL, "Discovery made") + """Test sending discovery message to discovery channel""" + mock_broadcast.return_value = True + + result = discord.send_discovery_message("Found interesting pattern") + + assert result is True + mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern") @patch("memory.common.discord.broadcast_message") +@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "chat") def test_send_chat_message(mock_broadcast): - discord.send_chat_message("Chat message") - mock_broadcast.assert_called_once_with(discord.CHAT_CHANNEL, "Chat message") + """Test sending chat message to chat channel""" + mock_broadcast.return_value = True + + result = discord.send_chat_message("Hello from bot") + + assert result is True + mock_broadcast.assert_called_once_with("chat", "Hello from bot") -@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) @patch("memory.common.discord.send_error_message") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) def test_notify_task_failure_basic(mock_send_error): + """Test basic task failure notification""" discord.notify_task_failure("test_task", "Something went wrong") mock_send_error.assert_called_once() @@ -323,69 +255,181 @@ def test_notify_task_failure_basic(mock_send_error): assert "**Error:** Something went wrong" in message -@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) @patch("memory.common.discord.send_error_message") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) def test_notify_task_failure_with_args(mock_send_error): + """Test task failure notification with arguments""" discord.notify_task_failure( "test_task", - "Error message", - task_args=("arg1", "arg2"), - task_kwargs={"key": "value"}, + "Error occurred", + task_args=("arg1", 42), + task_kwargs={"key": "value", "number": 123}, ) message = mock_send_error.call_args[0][0] - assert "**Args:** `('arg1', 'arg2')`" in message - assert "**Kwargs:** `{'key': 'value'}`" in message + assert "**Args:** `('arg1', 42)" in message + assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message -@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) @patch("memory.common.discord.send_error_message") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) def test_notify_task_failure_with_traceback(mock_send_error): - traceback = "Traceback (most recent call last):\n File ...\nError: Something" + """Test task failure notification with traceback""" + traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test" - discord.notify_task_failure("test_task", "Error message", traceback_str=traceback) + discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback) message = mock_send_error.call_args[0][0] + assert "**Traceback:**" in message - assert traceback in message + assert "Exception: test" in message -@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) @patch("memory.common.discord.send_error_message") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) def test_notify_task_failure_truncates_long_error(mock_send_error): - long_error = "x" * 600 # Longer than 500 char limit + """Test that long error messages are truncated""" + long_error = "x" * 600 discord.notify_task_failure("test_task", long_error) message = mock_send_error.call_args[0][0] - assert long_error[:500] in message + + # Error should be truncated to 500 chars - check that the full 600 char string is not there + assert "**Error:** " + long_error[:500] in message + # The full 600-char error should not be present + error_section = message.split("**Error:** ")[1].split("\n")[0] + assert len(error_section) == 500 -@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) @patch("memory.common.discord.send_error_message") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) def test_notify_task_failure_truncates_long_traceback(mock_send_error): - long_traceback = "x" * 1000 # Longer than 800 char limit + """Test that long tracebacks are truncated""" + long_traceback = "x" * 1000 discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback) message = mock_send_error.call_args[0][0] + + # Traceback should show last 800 chars assert long_traceback[-800:] in message + # The full 1000-char traceback should not be present + traceback_section = message.split("**Traceback:**\n```\n")[1].split("\n```")[0] + assert len(traceback_section) == 800 -@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False) @patch("memory.common.discord.send_error_message") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +def test_notify_task_failure_truncates_long_args(mock_send_error): + """Test that long task arguments are truncated""" + long_args = ("x" * 300,) + + discord.notify_task_failure("test_task", "Error", task_args=long_args) + + message = mock_send_error.call_args[0][0] + + # Args should be truncated to 200 chars + assert ( + len(message.split("**Args:**")[1].split("\n")[0]) <= 210 + ) # Some buffer for formatting + + +@patch("memory.common.discord.send_error_message") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +def test_notify_task_failure_truncates_long_kwargs(mock_send_error): + """Test that long task kwargs are truncated""" + long_kwargs = {"key": "x" * 300} + + discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs) + + message = mock_send_error.call_args[0][0] + + # Kwargs should be truncated to 200 chars + assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210 + + +@patch("memory.common.discord.send_error_message") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False) def test_notify_task_failure_disabled(mock_send_error): - discord.notify_task_failure("test_task", "Error message") + """Test that notifications are not sent when disabled""" + discord.notify_task_failure("test_task", "Error occurred") + mock_send_error.assert_not_called() -@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) @patch("memory.common.discord.send_error_message") -def test_notify_task_failure_send_fails(mock_send_error): - mock_send_error.side_effect = Exception("Discord API error") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +def test_notify_task_failure_send_error_exception(mock_send_error): + """Test that exceptions in send_error_message don't propagate""" + mock_send_error.side_effect = Exception("Failed to send") - # Should not raise, just log the error - discord.notify_task_failure("test_task", "Error message") + # Should not raise + discord.notify_task_failure("test_task", "Error occurred") mock_send_error.assert_called_once() + + +@pytest.mark.parametrize( + "function,channel_setting,message", + [ + (discord.send_error_message, "DISCORD_ERROR_CHANNEL", "Error!"), + (discord.send_activity_message, "DISCORD_ACTIVITY_CHANNEL", "Activity!"), + (discord.send_discovery_message, "DISCORD_DISCOVERY_CHANNEL", "Discovery!"), + (discord.send_chat_message, "DISCORD_CHAT_CHANNEL", "Chat!"), + ], +) +@patch("memory.common.discord.broadcast_message") +def test_convenience_functions_use_correct_channels( + mock_broadcast, function, channel_setting, message +): + """Test that convenience functions use the correct channel settings""" + with patch(f"memory.common.settings.{channel_setting}", "test-channel"): + function(message) + mock_broadcast.assert_called_once_with("test-channel", message) + + +@patch("requests.post") +def test_send_dm_with_special_characters(mock_post, mock_api_url): + """Test sending DM with special characters""" + mock_response = Mock() + mock_response.json.return_value = {"success": True} + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + message_with_special_chars = "Hello! 🎉 <@123> #general" + result = discord.send_dm("user123", message_with_special_chars) + + assert result is True + call_args = mock_post.call_args + assert call_args[1]["json"]["message"] == message_with_special_chars + + +@patch("requests.post") +def test_broadcast_message_with_long_message(mock_post, mock_api_url): + """Test broadcasting a long message""" + mock_response = Mock() + mock_response.json.return_value = {"success": True} + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + long_message = "A" * 2000 + result = discord.broadcast_message("general", long_message) + + assert result is True + call_args = mock_post.call_args + assert call_args[1]["json"]["message"] == long_message + + +@patch("requests.get") +def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url): + """Test health check when response doesn't have status key""" + mock_response = Mock() + mock_response.json.return_value = {} + mock_response.raise_for_status.return_value = None + mock_get.return_value = mock_response + + result = discord.is_collector_healthy() + + assert result is False diff --git a/tests/memory/common/test_discord_integration.py b/tests/memory/common/test_discord_integration.py new file mode 100644 index 0000000..3d46a0d --- /dev/null +++ b/tests/memory/common/test_discord_integration.py @@ -0,0 +1,435 @@ +import pytest +from unittest.mock import Mock, patch +import requests + +from memory.common import discord + + +@pytest.fixture +def mock_api_url(): + """Mock the API URL to avoid using actual settings""" + with patch( + "memory.common.discord.get_api_url", return_value="http://localhost:8000" + ): + yield + + +@patch("memory.common.settings.DISCORD_COLLECTOR_SERVER_URL", "testhost") +@patch("memory.common.settings.DISCORD_COLLECTOR_PORT", 9999) +def test_get_api_url(): + """Test API URL construction""" + assert discord.get_api_url() == "http://testhost:9999" + + +@patch("requests.post") +def test_send_dm_success(mock_post, mock_api_url): + """Test successful DM sending""" + mock_response = Mock() + mock_response.json.return_value = {"success": True} + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + result = discord.send_dm("user123", "Hello!") + + assert result is True + mock_post.assert_called_once_with( + "http://localhost:8000/send_dm", + json={"user_identifier": "user123", "message": "Hello!"}, + timeout=10, + ) + + +@patch("requests.post") +def test_send_dm_api_failure(mock_post, mock_api_url): + """Test DM sending when API returns failure""" + mock_response = Mock() + mock_response.json.return_value = {"success": False} + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + result = discord.send_dm("user123", "Hello!") + + assert result is False + + +@patch("requests.post") +def test_send_dm_request_exception(mock_post, mock_api_url): + """Test DM sending when request raises exception""" + mock_post.side_effect = requests.RequestException("Network error") + + result = discord.send_dm("user123", "Hello!") + + assert result is False + + +@patch("requests.post") +def test_send_dm_http_error(mock_post, mock_api_url): + """Test DM sending when HTTP error occurs""" + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found") + mock_post.return_value = mock_response + + result = discord.send_dm("user123", "Hello!") + + assert result is False + + +@patch("requests.post") +def test_broadcast_message_success(mock_post, mock_api_url): + """Test successful channel message broadcast""" + mock_response = Mock() + mock_response.json.return_value = {"success": True} + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + result = discord.broadcast_message("general", "Announcement!") + + assert result is True + mock_post.assert_called_once_with( + "http://localhost:8000/send_channel", + json={"channel_name": "general", "message": "Announcement!"}, + timeout=10, + ) + + +@patch("requests.post") +def test_broadcast_message_failure(mock_post, mock_api_url): + """Test channel message broadcast failure""" + mock_response = Mock() + mock_response.json.return_value = {"success": False} + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + result = discord.broadcast_message("general", "Announcement!") + + assert result is False + + +@patch("requests.post") +def test_broadcast_message_exception(mock_post, mock_api_url): + """Test channel message broadcast with exception""" + mock_post.side_effect = requests.Timeout("Request timeout") + + result = discord.broadcast_message("general", "Announcement!") + + assert result is False + + +@patch("requests.get") +def test_is_collector_healthy_true(mock_get, mock_api_url): + """Test health check when collector is healthy""" + mock_response = Mock() + mock_response.json.return_value = {"status": "healthy"} + mock_response.raise_for_status.return_value = None + mock_get.return_value = mock_response + + result = discord.is_collector_healthy() + + assert result is True + mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5) + + +@patch("requests.get") +def test_is_collector_healthy_false_status(mock_get, mock_api_url): + """Test health check when collector returns unhealthy status""" + mock_response = Mock() + mock_response.json.return_value = {"status": "unhealthy"} + mock_response.raise_for_status.return_value = None + mock_get.return_value = mock_response + + result = discord.is_collector_healthy() + + assert result is False + + +@patch("requests.get") +def test_is_collector_healthy_exception(mock_get, mock_api_url): + """Test health check when request fails""" + mock_get.side_effect = requests.ConnectionError("Connection refused") + + result = discord.is_collector_healthy() + + assert result is False + + +@patch("requests.post") +def test_refresh_discord_metadata_success(mock_post, mock_api_url): + """Test successful metadata refresh""" + mock_response = Mock() + mock_response.json.return_value = { + "servers": 5, + "channels": 20, + "users": 100, + } + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + result = discord.refresh_discord_metadata() + + assert result == {"servers": 5, "channels": 20, "users": 100} + mock_post.assert_called_once_with( + "http://localhost:8000/refresh_metadata", timeout=30 + ) + + +@patch("requests.post") +def test_refresh_discord_metadata_failure(mock_post, mock_api_url): + """Test metadata refresh failure""" + mock_post.side_effect = requests.RequestException("Failed to connect") + + result = discord.refresh_discord_metadata() + + assert result is None + + +@patch("requests.post") +def test_refresh_discord_metadata_http_error(mock_post, mock_api_url): + """Test metadata refresh with HTTP error""" + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error") + mock_post.return_value = mock_response + + result = discord.refresh_discord_metadata() + + assert result is None + + +@patch("memory.common.discord.broadcast_message") +@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "errors") +def test_send_error_message(mock_broadcast): + """Test sending error message to error channel""" + mock_broadcast.return_value = True + + result = discord.send_error_message("Something broke") + + assert result is True + mock_broadcast.assert_called_once_with("errors", "Something broke") + + +@patch("memory.common.discord.broadcast_message") +@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "activity") +def test_send_activity_message(mock_broadcast): + """Test sending activity message to activity channel""" + mock_broadcast.return_value = True + + result = discord.send_activity_message("User logged in") + + assert result is True + mock_broadcast.assert_called_once_with("activity", "User logged in") + + +@patch("memory.common.discord.broadcast_message") +@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "discoveries") +def test_send_discovery_message(mock_broadcast): + """Test sending discovery message to discovery channel""" + mock_broadcast.return_value = True + + result = discord.send_discovery_message("Found interesting pattern") + + assert result is True + mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern") + + +@patch("memory.common.discord.broadcast_message") +@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "chat") +def test_send_chat_message(mock_broadcast): + """Test sending chat message to chat channel""" + mock_broadcast.return_value = True + + result = discord.send_chat_message("Hello from bot") + + assert result is True + mock_broadcast.assert_called_once_with("chat", "Hello from bot") + + +@patch("memory.common.discord.send_error_message") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +def test_notify_task_failure_basic(mock_send_error): + """Test basic task failure notification""" + discord.notify_task_failure("test_task", "Something went wrong") + + mock_send_error.assert_called_once() + message = mock_send_error.call_args[0][0] + + assert "🚨 **Task Failed: test_task**" in message + assert "**Error:** Something went wrong" in message + + +@patch("memory.common.discord.send_error_message") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +def test_notify_task_failure_with_args(mock_send_error): + """Test task failure notification with arguments""" + discord.notify_task_failure( + "test_task", + "Error occurred", + task_args=("arg1", 42), + task_kwargs={"key": "value", "number": 123}, + ) + + message = mock_send_error.call_args[0][0] + + assert "**Args:** `('arg1', 42)" in message + assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message + + +@patch("memory.common.discord.send_error_message") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +def test_notify_task_failure_with_traceback(mock_send_error): + """Test task failure notification with traceback""" + traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test" + + discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback) + + message = mock_send_error.call_args[0][0] + + assert "**Traceback:**" in message + assert "Exception: test" in message + + +@patch("memory.common.discord.send_error_message") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +def test_notify_task_failure_truncates_long_error(mock_send_error): + """Test that long error messages are truncated""" + long_error = "x" * 600 + + discord.notify_task_failure("test_task", long_error) + + message = mock_send_error.call_args[0][0] + + # Error should be truncated to 500 chars - check that the full 600 char string is not there + assert "**Error:** " + long_error[:500] in message + # The full 600-char error should not be present + error_section = message.split("**Error:** ")[1].split("\n")[0] + assert len(error_section) == 500 + + +@patch("memory.common.discord.send_error_message") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +def test_notify_task_failure_truncates_long_traceback(mock_send_error): + """Test that long tracebacks are truncated""" + long_traceback = "x" * 1000 + + discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback) + + message = mock_send_error.call_args[0][0] + + # Traceback should show last 800 chars + assert long_traceback[-800:] in message + # The full 1000-char traceback should not be present + traceback_section = message.split("**Traceback:**\n```\n")[1].split("\n```")[0] + assert len(traceback_section) == 800 + + +@patch("memory.common.discord.send_error_message") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +def test_notify_task_failure_truncates_long_args(mock_send_error): + """Test that long task arguments are truncated""" + long_args = ("x" * 300,) + + discord.notify_task_failure("test_task", "Error", task_args=long_args) + + message = mock_send_error.call_args[0][0] + + # Args should be truncated to 200 chars + assert ( + len(message.split("**Args:**")[1].split("\n")[0]) <= 210 + ) # Some buffer for formatting + + +@patch("memory.common.discord.send_error_message") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +def test_notify_task_failure_truncates_long_kwargs(mock_send_error): + """Test that long task kwargs are truncated""" + long_kwargs = {"key": "x" * 300} + + discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs) + + message = mock_send_error.call_args[0][0] + + # Kwargs should be truncated to 200 chars + assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210 + + +@patch("memory.common.discord.send_error_message") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False) +def test_notify_task_failure_disabled(mock_send_error): + """Test that notifications are not sent when disabled""" + discord.notify_task_failure("test_task", "Error occurred") + + mock_send_error.assert_not_called() + + +@patch("memory.common.discord.send_error_message") +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +def test_notify_task_failure_send_error_exception(mock_send_error): + """Test that exceptions in send_error_message don't propagate""" + mock_send_error.side_effect = Exception("Failed to send") + + # Should not raise + discord.notify_task_failure("test_task", "Error occurred") + + mock_send_error.assert_called_once() + + +@pytest.mark.parametrize( + "function,channel_setting,message", + [ + (discord.send_error_message, "DISCORD_ERROR_CHANNEL", "Error!"), + (discord.send_activity_message, "DISCORD_ACTIVITY_CHANNEL", "Activity!"), + (discord.send_discovery_message, "DISCORD_DISCOVERY_CHANNEL", "Discovery!"), + (discord.send_chat_message, "DISCORD_CHAT_CHANNEL", "Chat!"), + ], +) +@patch("memory.common.discord.broadcast_message") +def test_convenience_functions_use_correct_channels( + mock_broadcast, function, channel_setting, message +): + """Test that convenience functions use the correct channel settings""" + with patch(f"memory.common.settings.{channel_setting}", "test-channel"): + function(message) + mock_broadcast.assert_called_once_with("test-channel", message) + + +@patch("requests.post") +def test_send_dm_with_special_characters(mock_post, mock_api_url): + """Test sending DM with special characters""" + mock_response = Mock() + mock_response.json.return_value = {"success": True} + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + message_with_special_chars = "Hello! 🎉 <@123> #general" + result = discord.send_dm("user123", message_with_special_chars) + + assert result is True + call_args = mock_post.call_args + assert call_args[1]["json"]["message"] == message_with_special_chars + + +@patch("requests.post") +def test_broadcast_message_with_long_message(mock_post, mock_api_url): + """Test broadcasting a long message""" + mock_response = Mock() + mock_response.json.return_value = {"success": True} + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + long_message = "A" * 2000 + result = discord.broadcast_message("general", long_message) + + assert result is True + call_args = mock_post.call_args + assert call_args[1]["json"]["message"] == long_message + + +@patch("requests.get") +def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url): + """Test health check when response doesn't have status key""" + mock_response = Mock() + mock_response.json.return_value = {} + mock_response.raise_for_status.return_value = None + mock_get.return_value = mock_response + + result = discord.is_collector_healthy() + + assert result is False diff --git a/tests/memory/discord/test_collector.py b/tests/memory/discord/test_collector.py new file mode 100644 index 0000000..b3624e8 --- /dev/null +++ b/tests/memory/discord/test_collector.py @@ -0,0 +1,840 @@ +import pytest +from datetime import datetime, timezone +from unittest.mock import Mock, patch, AsyncMock, MagicMock + +import discord + +from memory.discord.collector import ( + create_or_update_server, + determine_channel_metadata, + create_or_update_channel, + create_or_update_user, + determine_message_metadata, + should_track_message, + should_collect_bot_message, + sync_guild_metadata, + MessageCollector, +) +from memory.common.db.models.sources import ( + DiscordServer, + DiscordChannel, + DiscordUser, +) + + +# Fixtures for Discord objects +@pytest.fixture +def mock_guild(): + """Mock Discord Guild object""" + guild = Mock(spec=discord.Guild) + guild.id = 123456789 + guild.name = "Test Server" + guild.description = "A test server" + guild.member_count = 42 + return guild + + +@pytest.fixture +def mock_text_channel(): + """Mock Discord TextChannel object""" + channel = Mock(spec=discord.TextChannel) + channel.id = 987654321 + channel.name = "general" + guild = Mock() + guild.id = 123456789 + channel.guild = guild + return channel + + +@pytest.fixture +def mock_dm_channel(): + """Mock Discord DMChannel object""" + channel = Mock(spec=discord.DMChannel) + channel.id = 111222333 + recipient = Mock() + recipient.name = "TestUser" + channel.recipient = recipient + return channel + + +@pytest.fixture +def mock_user(): + """Mock Discord User object""" + user = Mock(spec=discord.User) + user.id = 444555666 + user.name = "testuser" + user.display_name = "Test User" + user.bot = False + return user + + +@pytest.fixture +def mock_message(mock_text_channel, mock_user): + """Mock Discord Message object""" + message = Mock(spec=discord.Message) + message.id = 777888999 + message.channel = mock_text_channel + message.author = mock_user + message.guild = mock_text_channel.guild + message.content = "Test message" + message.created_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + message.reference = None + return message + + +# Tests for create_or_update_server +def test_create_or_update_server_creates_new(db_session, mock_guild): + """Test creating a new server record""" + result = create_or_update_server(db_session, mock_guild) + + assert result is not None + assert result.id == mock_guild.id + assert result.name == mock_guild.name + assert result.description == mock_guild.description + assert result.member_count == mock_guild.member_count + + +def test_create_or_update_server_updates_existing(db_session, mock_guild): + """Test updating an existing server record""" + # Create initial server + server = DiscordServer( + id=mock_guild.id, + name="Old Name", + description="Old Description", + member_count=10, + ) + db_session.add(server) + db_session.commit() + + # Update with new data + mock_guild.name = "New Name" + mock_guild.description = "New Description" + mock_guild.member_count = 50 + + result = create_or_update_server(db_session, mock_guild) + + assert result.name == "New Name" + assert result.description == "New Description" + assert result.member_count == 50 + assert result.last_sync_at is not None + + +def test_create_or_update_server_none_guild(db_session): + """Test with None guild""" + result = create_or_update_server(db_session, None) + assert result is None + + +# Tests for determine_channel_metadata +def test_determine_channel_metadata_dm(): + """Test metadata for DM channel""" + channel = Mock(spec=discord.DMChannel) + channel.recipient = Mock() + channel.recipient.name = "TestUser" + + channel_type, server_id, name = determine_channel_metadata(channel) + + assert channel_type == "dm" + assert server_id is None + assert "DM with TestUser" in name + + +def test_determine_channel_metadata_dm_no_recipient(): + """Test metadata for DM channel without recipient""" + channel = Mock(spec=discord.DMChannel) + channel.recipient = None + + channel_type, server_id, name = determine_channel_metadata(channel) + + assert channel_type == "dm" + assert name == "Unknown DM" + + +def test_determine_channel_metadata_group_dm(): + """Test metadata for group DM channel""" + channel = Mock(spec=discord.GroupChannel) + channel.name = "Group Chat" + + channel_type, server_id, name = determine_channel_metadata(channel) + + assert channel_type == "group_dm" + assert server_id is None + assert name == "Group Chat" + + +def test_determine_channel_metadata_group_dm_no_name(): + """Test metadata for group DM without name""" + channel = Mock(spec=discord.GroupChannel) + channel.name = None + + channel_type, server_id, name = determine_channel_metadata(channel) + + assert name == "Group DM" + + +def test_determine_channel_metadata_text_channel(): + """Test metadata for text channel""" + channel = Mock(spec=discord.TextChannel) + channel.name = "general" + channel.guild = Mock() + channel.guild.id = 123 + + channel_type, server_id, name = determine_channel_metadata(channel) + + assert channel_type == "text" + assert server_id == 123 + assert name == "general" + + +def test_determine_channel_metadata_voice_channel(): + """Test metadata for voice channel""" + channel = Mock(spec=discord.VoiceChannel) + channel.name = "voice-chat" + channel.guild = Mock() + channel.guild.id = 456 + + channel_type, server_id, name = determine_channel_metadata(channel) + + assert channel_type == "voice" + assert server_id == 456 + assert name == "voice-chat" + + +def test_determine_channel_metadata_thread(): + """Test metadata for thread""" + channel = Mock(spec=discord.Thread) + channel.name = "thread-1" + channel.guild = Mock() + channel.guild.id = 789 + + channel_type, server_id, name = determine_channel_metadata(channel) + + assert channel_type == "thread" + assert server_id == 789 + assert name == "thread-1" + + +def test_determine_channel_metadata_unknown(): + """Test metadata for unknown channel type""" + channel = Mock() + channel.id = 999 + # Ensure the mock doesn't have a 'name' attribute + del channel.name + + channel_type, server_id, name = determine_channel_metadata(channel) + + assert channel_type == "unknown" + assert name == "Unknown-999" + + +# Tests for create_or_update_channel +def test_create_or_update_channel_creates_new( + db_session, mock_text_channel, mock_guild +): + """Test creating a new channel record""" + # Create the server first to satisfy foreign key constraint + create_or_update_server(db_session, mock_guild) + + result = create_or_update_channel(db_session, mock_text_channel) + + assert result is not None + assert result.id == mock_text_channel.id + assert result.name == mock_text_channel.name + assert result.channel_type == "text" + + +def test_create_or_update_channel_updates_existing(db_session, mock_text_channel): + """Test updating an existing channel record""" + # Create initial channel + channel = DiscordChannel( + id=mock_text_channel.id, + name="old-name", + channel_type="text", + ) + db_session.add(channel) + db_session.commit() + + # Update with new name + mock_text_channel.name = "new-name" + + result = create_or_update_channel(db_session, mock_text_channel) + + assert result.name == "new-name" + + +def test_create_or_update_channel_none_channel(db_session): + """Test with None channel""" + result = create_or_update_channel(db_session, None) + assert result is None + + +# Tests for create_or_update_user +def test_create_or_update_user_creates_new(db_session, mock_user): + """Test creating a new user record""" + result = create_or_update_user(db_session, mock_user) + + assert result is not None + assert result.id == mock_user.id + assert result.username == mock_user.name + assert result.display_name == mock_user.display_name + + +def test_create_or_update_user_updates_existing(db_session, mock_user): + """Test updating an existing user record""" + # Create initial user + user = DiscordUser( + id=mock_user.id, + username="oldname", + display_name="Old Display Name", + ) + db_session.add(user) + db_session.commit() + + # Update with new data + mock_user.name = "newname" + mock_user.display_name = "New Display Name" + + result = create_or_update_user(db_session, mock_user) + + assert result.username == "newname" + assert result.display_name == "New Display Name" + + +def test_create_or_update_user_none_user(db_session): + """Test with None user""" + result = create_or_update_user(db_session, None) + assert result is None + + +# Tests for determine_message_metadata +def test_determine_message_metadata_default(): + """Test metadata for default message""" + message = Mock() + message.reference = None + message.channel = Mock() + # Ensure channel doesn't have parent attribute + del message.channel.parent + + message_type, reply_to_id, thread_id = determine_message_metadata(message) + + assert message_type == "default" + assert reply_to_id is None + assert thread_id is None + + +def test_determine_message_metadata_reply(): + """Test metadata for reply message""" + message = Mock() + message.reference = Mock() + message.reference.message_id = 123456 + message.channel = Mock() + + message_type, reply_to_id, thread_id = determine_message_metadata(message) + + assert message_type == "reply" + assert reply_to_id == 123456 + + +def test_determine_message_metadata_thread(): + """Test metadata for message in thread""" + message = Mock() + message.reference = None + message.channel = Mock() + message.channel.id = 999 + message.channel.parent = Mock() # Has parent means it's a thread + + message_type, reply_to_id, thread_id = determine_message_metadata(message) + + assert thread_id == 999 + + +# Tests for should_track_message +def test_should_track_message_server_disabled(db_session): + """Test when server has tracking disabled""" + server = DiscordServer(id=1, name="Server", track_messages=False) + channel = DiscordChannel(id=2, name="Channel", channel_type="text") + user = DiscordUser(id=3, username="User") + + result = should_track_message(server, channel, user) + + assert result is False + + +def test_should_track_message_channel_disabled(db_session): + """Test when channel has tracking disabled""" + server = DiscordServer(id=1, name="Server", track_messages=True) + channel = DiscordChannel( + id=2, name="Channel", channel_type="text", track_messages=False + ) + user = DiscordUser(id=3, username="User") + + result = should_track_message(server, channel, user) + + assert result is False + + +def test_should_track_message_dm_allowed(db_session): + """Test DM tracking when user allows it""" + channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True) + user = DiscordUser(id=3, username="User", track_messages=True) + + result = should_track_message(None, channel, user) + + assert result is True + + +def test_should_track_message_dm_not_allowed(db_session): + """Test DM tracking when user doesn't allow it""" + channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True) + user = DiscordUser(id=3, username="User", track_messages=False) + + result = should_track_message(None, channel, user) + + assert result is False + + +def test_should_track_message_default_true(db_session): + """Test default tracking behavior""" + server = DiscordServer(id=1, name="Server", track_messages=True) + channel = DiscordChannel( + id=2, name="Channel", channel_type="text", track_messages=True + ) + user = DiscordUser(id=3, username="User") + + result = should_track_message(server, channel, user) + + assert result is True + + +# Tests for should_collect_bot_message +@patch("memory.common.settings.DISCORD_COLLECT_BOTS", False) +def test_should_collect_bot_message_bot_not_allowed(): + """Test bot message collection when disabled""" + message = Mock() + message.author = Mock() + message.author.bot = True + + result = should_collect_bot_message(message) + + assert result is False + + +@patch("memory.common.settings.DISCORD_COLLECT_BOTS", True) +def test_should_collect_bot_message_bot_allowed(): + """Test bot message collection when enabled""" + message = Mock() + message.author = Mock() + message.author.bot = True + + result = should_collect_bot_message(message) + + assert result is True + + +def test_should_collect_bot_message_human(): + """Test human message collection""" + message = Mock() + message.author = Mock() + message.author.bot = False + + result = should_collect_bot_message(message) + + assert result is True + + +# Tests for sync_guild_metadata +@patch("memory.discord.collector.make_session") +def test_sync_guild_metadata(mock_make_session, mock_guild): + """Test syncing guild metadata""" + mock_session = Mock() + mock_make_session.return_value.__enter__ = Mock(return_value=mock_session) + mock_make_session.return_value.__exit__ = Mock(return_value=None) + + # Mock session.query().get() to return None (new server) + mock_session.query.return_value.get.return_value = None + + # Mock channels + text_channel = Mock(spec=discord.TextChannel) + text_channel.id = 1 + text_channel.name = "general" + text_channel.guild = mock_guild + + voice_channel = Mock(spec=discord.VoiceChannel) + voice_channel.id = 2 + voice_channel.name = "voice" + voice_channel.guild = mock_guild + + mock_guild.channels = [text_channel, voice_channel] + + sync_guild_metadata(mock_guild) + + # Verify session.commit was called + mock_session.commit.assert_called_once() + + +# Tests for MessageCollector class +def test_message_collector_init(): + """Test MessageCollector initialization""" + collector = MessageCollector() + + assert collector.command_prefix == "!memory_" + assert collector.help_command is None + assert collector.intents.message_content is True + assert collector.intents.guilds is True + assert collector.intents.members is True + assert collector.intents.dm_messages is True + + +@pytest.mark.asyncio +async def test_on_ready(): + """Test on_ready event handler""" + collector = MessageCollector() + collector.user = Mock() + collector.user.name = "TestBot" + collector.guilds = [Mock(), Mock()] + collector.sync_servers_and_channels = AsyncMock() + + await collector.on_ready() + + collector.sync_servers_and_channels.assert_called_once() + + +@pytest.mark.asyncio +@patch("memory.discord.collector.make_session") +@patch("memory.discord.collector.add_discord_message") +async def test_on_message_success(mock_add_task, mock_make_session, mock_message): + """Test successful message handling""" + mock_session = Mock() + mock_make_session.return_value.__enter__ = Mock(return_value=mock_session) + mock_make_session.return_value.__exit__ = Mock(return_value=None) + mock_session.query.return_value.get.return_value = None # New entities + + collector = MessageCollector() + await collector.on_message(mock_message) + + # Verify task was queued + mock_add_task.delay.assert_called_once() + call_kwargs = mock_add_task.delay.call_args[1] + assert call_kwargs["message_id"] == mock_message.id + assert call_kwargs["channel_id"] == mock_message.channel.id + assert call_kwargs["author_id"] == mock_message.author.id + assert call_kwargs["content"] == mock_message.content + + +@pytest.mark.asyncio +@patch("memory.discord.collector.make_session") +async def test_on_message_bot_message_filtered(mock_make_session, mock_message): + """Test bot message filtering""" + mock_message.author.bot = True + + with patch( + "memory.discord.collector.should_collect_bot_message", return_value=False + ): + collector = MessageCollector() + await collector.on_message(mock_message) + + # Should not create session or queue task + mock_make_session.assert_not_called() + + +@pytest.mark.asyncio +@patch("memory.discord.collector.make_session") +async def test_on_message_error_handling(mock_make_session, mock_message): + """Test error handling in on_message""" + mock_make_session.side_effect = Exception("Database error") + + collector = MessageCollector() + # Should not raise + await collector.on_message(mock_message) + + +@pytest.mark.asyncio +@patch("memory.discord.collector.edit_discord_message") +async def test_on_message_edit(mock_edit_task): + """Test message edit handler""" + before = Mock() + after = Mock() + after.id = 123 + after.content = "Edited content" + after.edited_at = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc) + + collector = MessageCollector() + await collector.on_message_edit(before, after) + + mock_edit_task.delay.assert_called_once() + call_kwargs = mock_edit_task.delay.call_args[1] + assert call_kwargs["message_id"] == 123 + assert call_kwargs["content"] == "Edited content" + + +@pytest.mark.asyncio +async def test_on_message_edit_error_handling(): + """Test error handling in on_message_edit""" + before = Mock() + after = Mock() + after.id = 123 + after.content = "Edited" + after.edited_at = None # Will trigger datetime.now + + with patch("memory.discord.collector.edit_discord_message") as mock_edit: + mock_edit.delay.side_effect = Exception("Task error") + + collector = MessageCollector() + # Should not raise + await collector.on_message_edit(before, after) + + +@pytest.mark.asyncio +async def test_sync_servers_and_channels(): + """Test syncing servers and channels""" + guild1 = Mock() + guild2 = Mock() + + collector = MessageCollector() + collector.guilds = [guild1, guild2] + + with patch("memory.discord.collector.sync_guild_metadata") as mock_sync: + await collector.sync_servers_and_channels() + + assert mock_sync.call_count == 2 + mock_sync.assert_any_call(guild1) + mock_sync.assert_any_call(guild2) + + +@pytest.mark.asyncio +@patch("memory.discord.collector.make_session") +async def test_refresh_metadata(mock_make_session): + """Test metadata refresh""" + mock_session = Mock() + mock_make_session.return_value.__enter__ = Mock(return_value=mock_session) + mock_make_session.return_value.__exit__ = Mock(return_value=None) + mock_session.query.return_value.get.return_value = None + + guild = Mock() + guild.id = 123 + guild.name = "Test" + guild.channels = [] + guild.members = [] + + collector = MessageCollector() + collector.guilds = [guild] + collector.intents = Mock() + collector.intents.members = False + + result = await collector.refresh_metadata() + + assert result["servers_updated"] == 1 + assert result["channels_updated"] == 0 + assert result["users_updated"] == 0 + + +@pytest.mark.asyncio +async def test_get_user_by_id(): + """Test getting user by ID""" + user = Mock() + user.id = 123 + + collector = MessageCollector() + collector.get_user = Mock(return_value=user) + + result = await collector.get_user(123) + + assert result == user + + +@pytest.mark.asyncio +async def test_get_user_by_username(): + """Test getting user by username""" + member = Mock() + member.name = "testuser" + member.display_name = "Test User" + member.discriminator = "1234" + + guild = Mock() + guild.members = [member] + + collector = MessageCollector() + collector.guilds = [guild] + + result = await collector.get_user("testuser") + + assert result == member + + +@pytest.mark.asyncio +async def test_get_user_not_found(): + """Test getting non-existent user""" + collector = MessageCollector() + collector.guilds = [] + + with patch.object(collector, "get_user", return_value=None): + with patch.object( + collector, "fetch_user", side_effect=discord.NotFound(Mock(), Mock()) + ): + result = await collector.get_user(999) + assert result is None + + +@pytest.mark.asyncio +async def test_get_channel_by_name(): + """Test getting channel by name""" + channel = Mock(spec=discord.TextChannel) + channel.name = "general" + + guild = Mock() + guild.channels = [channel] + + collector = MessageCollector() + collector.guilds = [guild] + + result = await collector.get_channel_by_name("general") + + assert result == channel + + +@pytest.mark.asyncio +async def test_get_channel_by_name_not_found(): + """Test getting non-existent channel""" + guild = Mock() + guild.channels = [] + + collector = MessageCollector() + collector.guilds = [guild] + + result = await collector.get_channel_by_name("nonexistent") + + assert result is None + + +@pytest.mark.asyncio +async def test_create_channel(): + """Test creating a channel""" + guild = Mock() + guild.name = "Test Server" + new_channel = Mock() + guild.create_text_channel = AsyncMock(return_value=new_channel) + + collector = MessageCollector() + collector.get_guild = Mock(return_value=guild) + + result = await collector.create_channel("new-channel", guild_id=123) + + assert result == new_channel + guild.create_text_channel.assert_called_once_with("new-channel") + + +@pytest.mark.asyncio +async def test_create_channel_no_guild(): + """Test creating channel when no guild available""" + collector = MessageCollector() + collector.get_guild = Mock(return_value=None) + collector.guilds = [] + + result = await collector.create_channel("new-channel") + + assert result is None + + +@pytest.mark.asyncio +async def test_send_dm_success(): + """Test sending DM successfully""" + user = Mock() + user.send = AsyncMock() + + collector = MessageCollector() + collector.get_user = AsyncMock(return_value=user) + + result = await collector.send_dm(123, "Hello!") + + assert result is True + user.send.assert_called_once_with("Hello!") + + +@pytest.mark.asyncio +async def test_send_dm_user_not_found(): + """Test sending DM when user not found""" + collector = MessageCollector() + collector.get_user = AsyncMock(return_value=None) + + result = await collector.send_dm(123, "Hello!") + + assert result is False + + +@pytest.mark.asyncio +async def test_send_dm_exception(): + """Test sending DM with exception""" + user = Mock() + user.send = AsyncMock(side_effect=Exception("Send failed")) + + collector = MessageCollector() + collector.get_user = AsyncMock(return_value=user) + + result = await collector.send_dm(123, "Hello!") + + assert result is False + + +@pytest.mark.asyncio +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +async def test_send_to_channel_success(): + """Test sending to channel successfully""" + channel = Mock() + channel.send = AsyncMock() + + collector = MessageCollector() + collector.get_channel_by_name = AsyncMock(return_value=channel) + + result = await collector.send_to_channel("general", "Announcement!") + + assert result is True + channel.send.assert_called_once_with("Announcement!") + + +@pytest.mark.asyncio +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False) +async def test_send_to_channel_notifications_disabled(): + """Test sending to channel when notifications disabled""" + collector = MessageCollector() + + result = await collector.send_to_channel("general", "Announcement!") + + assert result is False + + +@pytest.mark.asyncio +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +async def test_send_to_channel_not_found(): + """Test sending to non-existent channel""" + collector = MessageCollector() + collector.get_channel_by_name = AsyncMock(return_value=None) + + result = await collector.send_to_channel("nonexistent", "Message") + + assert result is False + + +@pytest.mark.asyncio +@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token") +async def test_run_collector(): + """Test running the collector""" + from memory.discord.collector import run_collector + + with patch("memory.discord.collector.MessageCollector") as mock_collector_class: + mock_collector = Mock() + mock_collector.start = AsyncMock() + mock_collector_class.return_value = mock_collector + + await run_collector() + + mock_collector.start.assert_called_once_with("test_token") + + +@pytest.mark.asyncio +@patch("memory.common.settings.DISCORD_BOT_TOKEN", None) +async def test_run_collector_no_token(): + """Test running collector without token""" + from memory.discord.collector import run_collector + + # Should return early without raising + await run_collector() diff --git a/tests/memory/workers/tasks/test_discord_tasks.py b/tests/memory/workers/tasks/test_discord_tasks.py new file mode 100644 index 0000000..5ce5765 --- /dev/null +++ b/tests/memory/workers/tasks/test_discord_tasks.py @@ -0,0 +1,607 @@ +import pytest +from datetime import datetime, timezone +from unittest.mock import Mock, patch + +from memory.common.db.models import ( + DiscordMessage, + DiscordUser, + DiscordServer, + DiscordChannel, +) +from memory.workers.tasks import discord + + +@pytest.fixture +def mock_discord_user(db_session): + """Create a Discord user for testing.""" + user = DiscordUser( + id=123456789, + username="testuser", + ignore_messages=False, + ) + db_session.add(user) + db_session.commit() + return user + + +@pytest.fixture +def mock_discord_server(db_session): + """Create a Discord server for testing.""" + server = DiscordServer( + id=987654321, + name="Test Server", + ignore_messages=False, + ) + db_session.add(server) + db_session.commit() + return server + + +@pytest.fixture +def mock_discord_channel(db_session, mock_discord_server): + """Create a Discord channel for testing.""" + channel = DiscordChannel( + id=111222333, + name="test-channel", + channel_type="text", + server_id=mock_discord_server.id, + ignore_messages=False, + ) + db_session.add(channel) + db_session.commit() + return channel + + +@pytest.fixture +def sample_message_data(mock_discord_user, mock_discord_channel): + """Sample message data for testing.""" + return { + "message_id": 999888777, + "channel_id": mock_discord_channel.id, + "author_id": mock_discord_user.id, + "content": "This is a test Discord message with enough content to be processed.", + "sent_at": "2024-01-01T12:00:00Z", + "server_id": None, + "message_reference_id": None, + } + + +def test_get_prev_returns_previous_messages( + db_session, mock_discord_user, mock_discord_channel +): + """Test that get_prev returns previous messages in order.""" + # Create previous messages + msg1 = DiscordMessage( + message_id=1, + channel_id=mock_discord_channel.id, + discord_user_id=mock_discord_user.id, + content="First message", + sent_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc), + modality="text", + sha256=b"hash1" + bytes(26), + ) + msg2 = DiscordMessage( + message_id=2, + channel_id=mock_discord_channel.id, + discord_user_id=mock_discord_user.id, + content="Second message", + sent_at=datetime(2024, 1, 1, 10, 5, 0, tzinfo=timezone.utc), + modality="text", + sha256=b"hash2" + bytes(26), + ) + msg3 = DiscordMessage( + message_id=3, + channel_id=mock_discord_channel.id, + discord_user_id=mock_discord_user.id, + content="Third message", + sent_at=datetime(2024, 1, 1, 10, 10, 0, tzinfo=timezone.utc), + modality="text", + sha256=b"hash3" + bytes(26), + ) + db_session.add_all([msg1, msg2, msg3]) + db_session.commit() + + # Get previous messages before 10:15 + result = discord.get_prev( + db_session, + mock_discord_channel.id, + datetime(2024, 1, 1, 10, 15, 0, tzinfo=timezone.utc), + ) + + assert len(result) == 3 + assert result[0] == "testuser: First message" + assert result[1] == "testuser: Second message" + assert result[2] == "testuser: Third message" + + +def test_get_prev_limits_context_window( + db_session, mock_discord_user, mock_discord_channel +): + """Test that get_prev respects DISCORD_CONTEXT_WINDOW setting.""" + # Create 15 messages (more than the default context window of 10) + for i in range(15): + msg = DiscordMessage( + message_id=i, + channel_id=mock_discord_channel.id, + discord_user_id=mock_discord_user.id, + content=f"Message {i}", + sent_at=datetime(2024, 1, 1, 10, i, 0, tzinfo=timezone.utc), + modality="text", + sha256=f"hash{i}".encode() + bytes(27), + ) + db_session.add(msg) + db_session.commit() + + result = discord.get_prev( + db_session, + mock_discord_channel.id, + datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc), + ) + + # Should only return last 10 messages + assert len(result) == 10 + assert result[0] == "testuser: Message 5" # Oldest in window + assert result[-1] == "testuser: Message 14" # Most recent + + +def test_get_prev_empty_channel(db_session, mock_discord_channel): + """Test get_prev with no previous messages.""" + result = discord.get_prev( + db_session, + mock_discord_channel.id, + datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc), + ) + + assert result == [] + + +@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True) +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +def test_should_process_normal_message( + db_session, mock_discord_user, mock_discord_server, mock_discord_channel +): + """Test should_process returns True for normal messages.""" + message = DiscordMessage( + message_id=1, + channel_id=mock_discord_channel.id, + discord_user_id=mock_discord_user.id, + server_id=mock_discord_server.id, + content="Test", + sent_at=datetime.now(timezone.utc), + modality="text", + sha256=b"hash" + bytes(27), + ) + db_session.add(message) + db_session.commit() + db_session.refresh(message) + + assert discord.should_process(message) is True + + +@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False) +def test_should_process_disabled(): + """Test should_process returns False when processing is disabled.""" + message = Mock() + assert discord.should_process(message) is False + + +@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True) +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False) +def test_should_process_notifications_disabled(): + """Test should_process returns False when notifications are disabled.""" + message = Mock() + assert discord.should_process(message) is False + + +@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True) +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +def test_should_process_server_ignored( + db_session, mock_discord_user, mock_discord_channel +): + """Test should_process returns False when server has ignore_messages=True.""" + server = DiscordServer( + id=123, + name="Ignored Server", + ignore_messages=True, + ) + db_session.add(server) + db_session.commit() + + message = DiscordMessage( + message_id=1, + channel_id=mock_discord_channel.id, + discord_user_id=mock_discord_user.id, + server_id=server.id, + content="Test", + sent_at=datetime.now(timezone.utc), + modality="text", + sha256=b"hash" + bytes(27), + ) + db_session.add(message) + db_session.commit() + db_session.refresh(message) + + assert discord.should_process(message) is False + + +@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True) +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +def test_should_process_channel_ignored( + db_session, mock_discord_user, mock_discord_server +): + """Test should_process returns False when channel has ignore_messages=True.""" + channel = DiscordChannel( + id=456, + name="ignored-channel", + channel_type="text", + server_id=mock_discord_server.id, + ignore_messages=True, + ) + db_session.add(channel) + db_session.commit() + + message = DiscordMessage( + message_id=1, + channel_id=channel.id, + discord_user_id=mock_discord_user.id, + server_id=mock_discord_server.id, + content="Test", + sent_at=datetime.now(timezone.utc), + modality="text", + sha256=b"hash" + bytes(27), + ) + db_session.add(message) + db_session.commit() + db_session.refresh(message) + + assert discord.should_process(message) is False + + +@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True) +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +def test_should_process_user_ignored( + db_session, mock_discord_server, mock_discord_channel +): + """Test should_process returns False when user has ignore_messages=True.""" + user = DiscordUser( + id=789, + username="ignoreduser", + ignore_messages=True, + ) + db_session.add(user) + db_session.commit() + + message = DiscordMessage( + message_id=1, + channel_id=mock_discord_channel.id, + discord_user_id=user.id, + server_id=mock_discord_server.id, + content="Test", + sent_at=datetime.now(timezone.utc), + modality="text", + sha256=b"hash" + bytes(27), + ) + db_session.add(message) + db_session.commit() + db_session.refresh(message) + + assert discord.should_process(message) is False + + +def test_add_discord_message_success(db_session, sample_message_data, qdrant): + """Test successful Discord message addition.""" + result = discord.add_discord_message(**sample_message_data) + + assert result["status"] == "processed" + assert "discordmessage_id" in result + + # Verify the message was created in the database + message = ( + db_session.query(DiscordMessage) + .filter_by(message_id=sample_message_data["message_id"]) + .first() + ) + assert message is not None + assert message.content == sample_message_data["content"] + assert message.message_type == "default" + assert message.reply_to_message_id is None + + +def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant): + """Test adding a Discord message that is a reply.""" + sample_message_data["message_reference_id"] = 111222333 + + discord.add_discord_message(**sample_message_data) + + message = ( + db_session.query(DiscordMessage) + .filter_by(message_id=sample_message_data["message_id"]) + .first() + ) + assert message.message_type == "reply" + assert message.reply_to_message_id == 111222333 + + +def test_add_discord_message_already_exists(db_session, sample_message_data, qdrant): + """Test adding a message that already exists.""" + # Add the message once + discord.add_discord_message(**sample_message_data) + + # Try to add it again + result = discord.add_discord_message(**sample_message_data) + + assert result["status"] == "already_exists" + assert result["message_id"] == sample_message_data["message_id"] + + # Verify no duplicate was created + messages = ( + db_session.query(DiscordMessage) + .filter_by(message_id=sample_message_data["message_id"]) + .all() + ) + assert len(messages) == 1 + + +def test_add_discord_message_with_context( + db_session, sample_message_data, mock_discord_user, qdrant +): + """Test that message is added successfully when previous messages exist.""" + # Add a previous message + prev_msg = DiscordMessage( + message_id=111111, + channel_id=sample_message_data["channel_id"], + discord_user_id=mock_discord_user.id, + content="Previous message", + sent_at=datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc), + modality="text", + sha256=b"prev" + bytes(28), + ) + db_session.add(prev_msg) + db_session.commit() + + result = discord.add_discord_message(**sample_message_data) + + message = ( + db_session.query(DiscordMessage) + .filter_by(message_id=sample_message_data["message_id"]) + .first() + ) + assert message is not None + assert result["status"] == "processed" + + +@patch("memory.workers.tasks.discord.process_discord_message") +@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True) +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +def test_add_discord_message_triggers_processing( + mock_process, + db_session, + sample_message_data, + mock_discord_server, + mock_discord_channel, + qdrant, +): + """Test that add_discord_message triggers process_discord_message when conditions are met.""" + mock_process.delay = Mock() + sample_message_data["server_id"] = mock_discord_server.id + + discord.add_discord_message(**sample_message_data) + + # Verify process_discord_message.delay was called + mock_process.delay.assert_called_once() + + +@patch("memory.workers.tasks.discord.process_discord_message") +@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False) +def test_add_discord_message_no_processing_when_disabled( + mock_process, db_session, sample_message_data, qdrant +): + """Test that process_discord_message is not called when processing is disabled.""" + mock_process.delay = Mock() + + discord.add_discord_message(**sample_message_data) + + mock_process.delay.assert_not_called() + + +def test_edit_discord_message_success(db_session, sample_message_data, qdrant): + """Test successful Discord message edit.""" + # First add the message + discord.add_discord_message(**sample_message_data) + + # Edit it + new_content = ( + "This is the edited content with enough text to be meaningful and processed." + ) + edited_at = "2024-01-01T13:00:00Z" + + result = discord.edit_discord_message( + sample_message_data["message_id"], + new_content, + edited_at, + ) + + assert result["status"] == "processed" + + # Verify the message was updated + message = ( + db_session.query(DiscordMessage) + .filter_by(message_id=sample_message_data["message_id"]) + .first() + ) + assert message.content == new_content + assert message.edited_at is not None + + +def test_edit_discord_message_not_found(db_session): + """Test editing a message that doesn't exist.""" + result = discord.edit_discord_message( + 999999, + "New content", + "2024-01-01T13:00:00Z", + ) + + assert result["status"] == "error" + assert result["error"] == "Message not found" + assert result["message_id"] == 999999 + + +def test_edit_discord_message_updates_context( + db_session, sample_message_data, mock_discord_user, qdrant +): + """Test that editing a message works correctly.""" + # Add previous message and the message to be edited + prev_msg = DiscordMessage( + message_id=111111, + channel_id=sample_message_data["channel_id"], + discord_user_id=mock_discord_user.id, + content="Previous message", + sent_at=datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc), + modality="text", + sha256=b"prev" + bytes(28), + ) + db_session.add(prev_msg) + db_session.commit() + + discord.add_discord_message(**sample_message_data) + + # Edit the message + result = discord.edit_discord_message( + sample_message_data["message_id"], + "Edited content that should have context updated properly.", + "2024-01-01T13:00:00Z", + ) + + # Verify message was updated + message = ( + db_session.query(DiscordMessage) + .filter_by(message_id=sample_message_data["message_id"]) + .first() + ) + assert ( + message.content == "Edited content that should have context updated properly." + ) + assert result["status"] == "processed" + + +def test_process_discord_message_success(db_session, sample_message_data, qdrant): + """Test processing a Discord message.""" + # Add a message first + add_result = discord.add_discord_message(**sample_message_data) + message_id = add_result["discordmessage_id"] + + # Process it + result = discord.process_discord_message(message_id) + + assert result["status"] == "processed" + assert result["message_id"] == message_id + + +def test_process_discord_message_not_found(db_session): + """Test processing a message that doesn't exist.""" + result = discord.process_discord_message(999999) + + assert result["status"] == "error" + assert result["error"] == "Message not found" + assert result["message_id"] == 999999 + + +@pytest.mark.parametrize( + "sent_at_str,expected_hour", + [ + ("2024-01-01T12:00:00Z", 12), + ("2024-01-01T00:00:00+00:00", 0), + ("2024-01-01T23:59:59Z", 23), + ], +) +def test_add_discord_message_datetime_parsing( + db_session, sample_message_data, sent_at_str, expected_hour, qdrant +): + """Test that various datetime formats are parsed correctly.""" + sample_message_data["sent_at"] = sent_at_str + + discord.add_discord_message(**sample_message_data) + + message = ( + db_session.query(DiscordMessage) + .filter_by(message_id=sample_message_data["message_id"]) + .first() + ) + assert message.sent_at.hour == expected_hour + + +def test_add_discord_message_unique_hash(db_session, sample_message_data, qdrant): + """Test that message hash includes message_id for uniqueness.""" + # Add first message + discord.add_discord_message(**sample_message_data) + + # Try to add another message with same content but different message_id + sample_message_data["message_id"] = 888777666 + + result = discord.add_discord_message(**sample_message_data) + + # Should succeed because hash includes message_id + assert result["status"] == "processed" + + # Verify both messages exist + messages = ( + db_session.query(DiscordMessage) + .filter_by(content=sample_message_data["content"]) + .all() + ) + assert len(messages) == 2 + + +def test_get_prev_only_returns_messages_from_same_channel( + db_session, mock_discord_user, mock_discord_server +): + """Test that get_prev only returns messages from the specified channel.""" + # Create two channels + channel1 = DiscordChannel( + id=111, + name="channel-1", + channel_type="text", + server_id=mock_discord_server.id, + ) + channel2 = DiscordChannel( + id=222, + name="channel-2", + channel_type="text", + server_id=mock_discord_server.id, + ) + db_session.add_all([channel1, channel2]) + db_session.commit() + + # Add messages to both channels + msg1 = DiscordMessage( + message_id=1, + channel_id=channel1.id, + discord_user_id=mock_discord_user.id, + content="Message in channel 1", + sent_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc), + modality="text", + sha256=b"hash1" + bytes(26), + ) + msg2 = DiscordMessage( + message_id=2, + channel_id=channel2.id, + discord_user_id=mock_discord_user.id, + content="Message in channel 2", + sent_at=datetime(2024, 1, 1, 10, 5, 0, tzinfo=timezone.utc), + modality="text", + sha256=b"hash2" + bytes(26), + ) + db_session.add_all([msg1, msg2]) + db_session.commit() + + # Get previous messages for channel 1 + result = discord.get_prev( + db_session, + channel1.id, # type: ignore + datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc), + ) + + # Should only return message from channel 1 + assert len(result) == 1 + assert "Message in channel 1" in result[0] + assert "Message in channel 2" not in result[0]