diff --git a/db/migrations/versions/20251012_222827_add_discord_models.py b/db/migrations/versions/20251012_222827_add_discord_models.py
index 91d661b..deb73e1 100644
--- a/db/migrations/versions/20251012_222827_add_discord_models.py
+++ b/db/migrations/versions/20251012_222827_add_discord_models.py
@@ -26,9 +26,12 @@ def upgrade() -> None:
sa.Column("name", sa.Text(), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("member_count", sa.Integer(), nullable=True),
+ sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
sa.Column(
- "track_messages", sa.Boolean(), server_default="true", nullable=False
+ "ignore_messages", sa.Boolean(), server_default="false", nullable=True
),
+ sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
+ sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
@@ -56,7 +59,12 @@ def upgrade() -> None:
sa.Column("server_id", sa.BigInteger(), nullable=True),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("channel_type", sa.Text(), nullable=False),
- sa.Column("track_messages", sa.Boolean(), nullable=True),
+ sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
+ sa.Column(
+ "ignore_messages", sa.Boolean(), server_default="false", nullable=True
+ ),
+ sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
+ sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
@@ -84,9 +92,12 @@ def upgrade() -> None:
sa.Column("username", sa.Text(), nullable=False),
sa.Column("display_name", sa.Text(), nullable=True),
sa.Column("system_user_id", sa.Integer(), nullable=True),
+ sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
sa.Column(
- "allow_dm_tracking", sa.Boolean(), server_default="true", nullable=False
+ "ignore_messages", sa.Boolean(), server_default="false", nullable=True
),
+ sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
+ sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
diff --git a/requirements/requirements-common.txt b/requirements/requirements-common.txt
index 8ecc0b7..f88aec2 100644
--- a/requirements/requirements-common.txt
+++ b/requirements/requirements-common.txt
@@ -5,7 +5,7 @@ alembic==1.13.1
dotenv==0.9.9
voyageai==0.3.2
qdrant-client==1.9.0
-anthropic==0.18.1
+anthropic==0.69.0
openai==1.25.0
# Pin the httpx version, as newer versions break the anthropic client
httpx==0.27.0
diff --git a/src/memory/api/search/scorer.py b/src/memory/api/search/scorer.py
index d978f98..5049d95 100644
--- a/src/memory/api/search/scorer.py
+++ b/src/memory/api/search/scorer.py
@@ -40,7 +40,7 @@ async def score_chunk(query: str, chunk: Chunk) -> Chunk:
prompt = SCORE_CHUNK_PROMPT.format(query=query, chunk=chunk_text)
try:
response = await asyncio.to_thread(
- llms.call,
+ llms.summarize,
prompt,
settings.RANKER_MODEL,
images=images,
diff --git a/src/memory/common/celery_app.py b/src/memory/common/celery_app.py
index 75a3861..cc930da 100644
--- a/src/memory/common/celery_app.py
+++ b/src/memory/common/celery_app.py
@@ -15,6 +15,7 @@ SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
DISCORD_ROOT = "memory.workers.tasks.discord"
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
+PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
SYNC_NOTES = f"{NOTES_ROOT}.sync_notes"
SYNC_NOTE = f"{NOTES_ROOT}.sync_note"
diff --git a/src/memory/common/db/models/sources.py b/src/memory/common/db/models/sources.py
index 9d018e5..1cbb9e4 100644
--- a/src/memory/common/db/models/sources.py
+++ b/src/memory/common/db/models/sources.py
@@ -127,7 +127,15 @@ class EmailAccount(Base):
)
-class DiscordServer(Base):
+class MessageProcessor:
+ track_messages = Column(Boolean, nullable=False, server_default="true")
+ ignore_messages = Column(Boolean, nullable=True, default=False)
+
+ allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
+ disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
+
+
+class DiscordServer(Base, MessageProcessor):
"""Discord server configuration and metadata"""
__tablename__ = "discord_servers"
@@ -138,7 +146,6 @@ class DiscordServer(Base):
member_count = Column(Integer)
# Collection settings
- track_messages = Column(Boolean, nullable=False, server_default="true")
last_sync_at = Column(DateTime(timezone=True))
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
@@ -152,7 +159,7 @@ class DiscordServer(Base):
)
-class DiscordChannel(Base):
+class DiscordChannel(Base, MessageProcessor):
"""Discord channel metadata and configuration"""
__tablename__ = "discord_channels"
@@ -163,7 +170,6 @@ class DiscordChannel(Base):
channel_type = Column(Text, nullable=False) # "text", "voice", "dm", "group_dm"
# Collection settings (null = inherit from server)
- track_messages = Column(Boolean, nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
@@ -171,7 +177,7 @@ class DiscordChannel(Base):
__table_args__ = (Index("discord_channels_server_idx", "server_id"),)
-class DiscordUser(Base):
+class DiscordUser(Base, MessageProcessor):
"""Discord user metadata and preferences"""
__tablename__ = "discord_users"
@@ -184,7 +190,6 @@ class DiscordUser(Base):
system_user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
# Basic DM settings
- allow_dm_tracking = Column(Boolean, nullable=False, server_default="true")
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
diff --git a/src/memory/common/llms.py b/src/memory/common/llms.py
deleted file mode 100644
index 8b88b75..0000000
--- a/src/memory/common/llms.py
+++ /dev/null
@@ -1,122 +0,0 @@
-import logging
-import base64
-import io
-from typing import Any
-from PIL import Image
-
-from memory.common import settings, tokens
-
-logger = logging.getLogger(__name__)
-
-SYSTEM_PROMPT = """
-You are a helpful assistant that creates concise summaries and identifies key topics.
-"""
-
-
-def encode_image(image: Image.Image) -> str:
- """Encode PIL Image to base64 string."""
- buffer = io.BytesIO()
- # Convert to RGB if necessary (for RGBA, etc.)
- if image.mode != "RGB":
- image = image.convert("RGB")
- image.save(buffer, format="JPEG")
- return base64.b64encode(buffer.getvalue()).decode("utf-8")
-
-
-def call_openai(
- prompt: str,
- model: str,
- images: list[Image.Image] = [],
- system_prompt: str = SYSTEM_PROMPT,
-) -> str:
- """Call OpenAI API for summarization."""
- import openai
-
- client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
- try:
- user_content: Any = [{"type": "text", "text": prompt}]
- if images:
- for image in images:
- encoded_image = encode_image(image)
- user_content.append(
- {
- "type": "image_url",
- "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
- }
- )
-
- response = client.chat.completions.create(
- model=model.split("/")[1],
- messages=[
- {
- "role": "system",
- "content": system_prompt,
- },
- {"role": "user", "content": user_content},
- ],
- temperature=0.3,
- max_tokens=2048,
- )
- return response.choices[0].message.content or ""
- except Exception as e:
- logger.error(f"OpenAI API error: {e}")
- raise
-
-
-def call_anthropic(
- prompt: str,
- model: str,
- images: list[Image.Image] = [],
- system_prompt: str = SYSTEM_PROMPT,
-) -> str:
- """Call Anthropic API for summarization."""
- import anthropic
-
- client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
- try:
- # Prepare the message content
- content: Any = [{"type": "text", "text": prompt}]
- if images:
- # Add images if provided
- for image in images:
- encoded_image = encode_image(image)
- content.append(
- { # type: ignore
- "type": "image",
- "source": {
- "type": "base64",
- "media_type": "image/jpeg",
- "data": encoded_image,
- },
- }
- )
-
- response = client.messages.create(
- model=model.split("/")[1],
- messages=[{"role": "user", "content": content}], # type: ignore
- system=system_prompt,
- temperature=0.3,
- max_tokens=2048,
- )
- return response.content[0].text
- except Exception as e:
- logger.error(f"Anthropic API error: {e}")
- raise
-
-
-def call(
- prompt: str,
- model: str,
- images: list[Image.Image] = [],
- system_prompt: str = SYSTEM_PROMPT,
-) -> str:
- if model.startswith("anthropic"):
- return call_anthropic(prompt, model, images, system_prompt)
- return call_openai(prompt, model, images, system_prompt)
-
-
-def truncate(content: str, target_tokens: int) -> str:
- target_chars = target_tokens * tokens.CHARS_PER_TOKEN
- if len(content) > target_chars:
- return content[:target_chars].rsplit(" ", 1)[0] + "..."
- return content
diff --git a/src/memory/common/llms/__init__.py b/src/memory/common/llms/__init__.py
new file mode 100644
index 0000000..68b6bd9
--- /dev/null
+++ b/src/memory/common/llms/__init__.py
@@ -0,0 +1,79 @@
+"""LLM provider module for unified LLM access."""
+
+# Legacy imports for backwards compatibility
+import logging
+
+from PIL import Image
+
+
+# New provider system
+from memory.common.llms.base import (
+ BaseLLMProvider,
+ ImageContent,
+ LLMSettings,
+ Message,
+ MessageContent,
+ MessageRole,
+ StreamEvent,
+ TextContent,
+ ThinkingContent,
+ ToolDefinition,
+ ToolResultContent,
+ ToolUseContent,
+ create_provider,
+)
+from memory.common import tokens
+
+__all__ = [
+ "BaseLLMProvider",
+ "Message",
+ "MessageRole",
+ "MessageContent",
+ "TextContent",
+ "ImageContent",
+ "ToolUseContent",
+ "ToolResultContent",
+ "ThinkingContent",
+ "ToolDefinition",
+ "StreamEvent",
+ "LLMSettings",
+ "create_provider",
+]
+
+logger = logging.getLogger(__name__)
+
+
+def summarize(
+ prompt: str,
+ model: str,
+ images: list[Image.Image] = [],
+ system_prompt: str = "",
+) -> str:
+ provider = create_provider(model=model)
+ try:
+ # Build message content
+ content: list[MessageContent] = [TextContent(text=prompt)]
+ for image in images:
+ content.append(ImageContent(image=image))
+
+ messages = [Message(role=MessageRole.USER, content=content)]
+ settings_obj = LLMSettings(temperature=0.3, max_tokens=2048)
+
+ res = provider.run_with_tools(
+ messages=messages,
+ system_prompt=system_prompt
+ or "You are a helpful assistant that creates concise summaries and identifies key topics.",
+ settings=settings_obj,
+ tools={},
+ )
+ return res.response or ""
+ except Exception as e:
+ logger.error(f"Anthropic API error: {e}")
+ raise
+
+
+def truncate(content: str, target_tokens: int) -> str:
+ target_chars = target_tokens * tokens.CHARS_PER_TOKEN
+ if len(content) > target_chars:
+ return content[:target_chars].rsplit(" ", 1)[0] + "..."
+ return content
diff --git a/src/memory/common/llms/anthropic_provider.py b/src/memory/common/llms/anthropic_provider.py
new file mode 100644
index 0000000..a46236f
--- /dev/null
+++ b/src/memory/common/llms/anthropic_provider.py
@@ -0,0 +1,451 @@
+"""Anthropic LLM provider implementation."""
+
+import json
+import logging
+from typing import Any, AsyncIterator, Iterator, Optional
+
+import anthropic
+
+from memory.common.llms.base import (
+ BaseLLMProvider,
+ ImageContent,
+ LLMSettings,
+ Message,
+ MessageRole,
+ StreamEvent,
+ ToolDefinition,
+ ToolUseContent,
+ ThinkingContent,
+ TextContent,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class AnthropicProvider(BaseLLMProvider):
+ """Anthropic LLM provider with streaming, tool support, and extended thinking."""
+
+ # Models that support extended thinking
+ THINKING_MODELS = {
+ "claude-opus-4",
+ "claude-opus-4-1",
+ "claude-sonnet-4-0",
+ "claude-sonnet-3-7",
+ "claude-sonnet-4-5",
+ }
+
+ def __init__(self, api_key: str, model: str, enable_thinking: bool = False):
+ """
+ Initialize the Anthropic provider.
+
+ Args:
+ api_key: Anthropic API key
+ model: Model identifier
+ enable_thinking: Enable extended thinking for supported models
+ """
+ super().__init__(api_key, model)
+ self.enable_thinking = enable_thinking
+ self._async_client: Optional[anthropic.AsyncAnthropic] = None
+
+ def _initialize_client(self) -> anthropic.Anthropic:
+ """Initialize the Anthropic client."""
+ return anthropic.Anthropic(api_key=self.api_key)
+
+ @property
+ def async_client(self) -> anthropic.AsyncAnthropic:
+ """Lazy-load the async client."""
+ if self._async_client is None:
+ self._async_client = anthropic.AsyncAnthropic(api_key=self.api_key)
+ return self._async_client
+
+ def _convert_image_content(self, content: ImageContent) -> dict[str, Any]:
+ """Convert ImageContent to Anthropic's base64 source format."""
+ encoded_image = self.encode_image(content.image)
+ return {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/jpeg",
+ "data": encoded_image,
+ },
+ }
+
+ def _convert_message(self, message: Message) -> dict[str, Any]:
+ converted = message.to_dict()
+ if converted["role"] == MessageRole.ASSISTANT and isinstance(
+ converted["content"], list
+ ):
+ content = sorted(
+ converted["content"], key=lambda x: x["type"] != "thinking"
+ )
+ return converted | {"content": content}
+ return converted
+
+ def _should_include_message(self, message: Message) -> bool:
+ """Filter out system messages (handled separately in Anthropic)."""
+ return message.role != MessageRole.SYSTEM
+
+ def _supports_thinking(self) -> bool:
+ """Check if the current model supports extended thinking."""
+ model_lower = self.model.lower()
+ return any(supported in model_lower for supported in self.THINKING_MODELS)
+
+ def _build_request_kwargs(
+ self,
+ messages: list[Message],
+ system_prompt: str | None,
+ tools: list[ToolDefinition] | None,
+ settings: LLMSettings,
+ ) -> dict[str, Any]:
+ """Build common request kwargs for API calls."""
+ anthropic_messages = self._convert_messages(messages)
+
+ kwargs: dict[str, Any] = {
+ "model": self.model,
+ "messages": anthropic_messages,
+ "temperature": settings.temperature,
+ "max_tokens": settings.max_tokens,
+ }
+
+ # Only include top_p if explicitly set
+ if settings.top_p is not None:
+ kwargs["top_p"] = settings.top_p
+
+ if system_prompt:
+ kwargs["system"] = system_prompt
+
+ if settings.stop_sequences:
+ kwargs["stop_sequences"] = settings.stop_sequences
+
+ if tools:
+ kwargs["tools"] = self._convert_tools(tools)
+
+ # Enable extended thinking if requested and model supports it
+ if self.enable_thinking and self._supports_thinking():
+ thinking_budget = min(10000, settings.max_tokens - 1024)
+ if thinking_budget >= 1024:
+ kwargs["thinking"] = {
+ "type": "enabled",
+ "budget_tokens": thinking_budget,
+ }
+ # When thinking is enabled: temperature must be 1, can't use top_p
+ kwargs["temperature"] = 1.0
+ kwargs.pop("top_p", None)
+
+ return kwargs
+
+ def _handle_stream_event(
+ self, event: Any, current_tool_use: dict[str, Any] | None
+ ) -> tuple[StreamEvent | None, dict[str, Any] | None]:
+ """
+ Handle a streaming event and return StreamEvent and updated tool state.
+
+ Returns:
+ Tuple of (StreamEvent or None, updated current_tool_use or None)
+ """
+ event_type = getattr(event, "type", None)
+
+ # Handle error events
+ if event_type == "error":
+ error = getattr(event, "error", None)
+ error_msg = str(error) if error else "Unknown error"
+ return StreamEvent(type="error", data=error_msg), current_tool_use
+
+ if event_type == "content_block_start":
+ block = getattr(event, "content_block", None)
+ if not block:
+ return None, current_tool_use
+
+ block_type = getattr(block, "type", None)
+
+ # Handle various tool types (tool_use, mcp_tool_use, server_tool_use)
+ if block_type in ("tool_use", "mcp_tool_use", "server_tool_use"):
+ # In content_block_start, input may already be present (empty dict)
+ block_input = getattr(block, "input", None)
+ current_tool_use = {
+ "id": getattr(block, "id", ""),
+ "name": getattr(block, "name", ""),
+ "input": block_input if block_input is not None else "",
+ "server_name": getattr(block, "server_name", None),
+ "is_server_call": block_type != "tool_use",
+ }
+
+ # Handle tool result blocks
+ elif hasattr(block, "tool_use_id"):
+ tool_result = {
+ "id": getattr(block, "tool_use_id", ""),
+ "result": getattr(block, "content", ""),
+ }
+ return StreamEvent(
+ type="tool_result", data=tool_result
+ ), current_tool_use
+
+ # For non-tool blocks (text, thinking), we don't need to track state
+ return None, current_tool_use
+
+ elif event_type == "content_block_delta":
+ delta = getattr(event, "delta", None)
+ if not delta:
+ return None, current_tool_use
+
+ delta_type = getattr(delta, "type", None)
+
+ if delta_type == "text_delta":
+ text = getattr(delta, "text", "")
+ return StreamEvent(type="text", data=text), current_tool_use
+
+ elif delta_type == "thinking_delta":
+ thinking = getattr(delta, "thinking", "")
+ return StreamEvent(type="thinking", data=thinking), current_tool_use
+
+ elif delta_type == "signature_delta":
+ # Handle thinking signature for extended thinking
+ signature = getattr(delta, "signature", "")
+ return StreamEvent(
+ type="thinking", signature=signature
+ ), current_tool_use
+
+ elif delta_type == "input_json_delta":
+ if current_tool_use is None:
+ # Edge case: received input_json_delta without tool_use start
+ logger.warning("Received input_json_delta without tool_use context")
+ return None, None
+
+ # Only accumulate if input is still a string (being built up)
+ if isinstance(current_tool_use.get("input"), str):
+ partial_json = getattr(delta, "partial_json", "")
+ current_tool_use["input"] += partial_json
+ # else: input was already set as a dict in content_block_start
+
+ return None, current_tool_use
+
+ elif event_type == "content_block_stop":
+ if current_tool_use:
+ # Use the parsed input from the content block if available
+ # This handles empty inputs {} more reliably than parsing
+ content_block = getattr(event, "content_block", None)
+ if content_block and hasattr(content_block, "input"):
+ current_tool_use["input"] = content_block.input
+ else:
+ # Fallback: parse accumulated JSON string
+ input_str = current_tool_use.get("input", "")
+ if isinstance(input_str, str):
+ # Need to parse the accumulated string
+ if not input_str or input_str.isspace():
+ # Empty or whitespace-only input
+ current_tool_use["input"] = {}
+ else:
+ try:
+ current_tool_use["input"] = json.loads(input_str)
+ except json.JSONDecodeError as e:
+ logger.warning(
+ f"Failed to parse tool input '{input_str}': {e}"
+ )
+ current_tool_use["input"] = {}
+ # else: input is already parsed
+
+ tool_data = {
+ "id": current_tool_use.get("id", ""),
+ "name": current_tool_use.get("name", ""),
+ "input": current_tool_use.get("input", {}),
+ }
+ # Include server info if present
+ if current_tool_use.get("server_name"):
+ tool_data["server_name"] = current_tool_use["server_name"]
+ if current_tool_use.get("is_server_call"):
+ tool_data["is_server_call"] = current_tool_use["is_server_call"]
+
+ return StreamEvent(type="tool_use", data=tool_data), None
+
+ elif event_type == "message_delta":
+ delta = getattr(event, "delta", None)
+ if delta:
+ stop_reason = getattr(delta, "stop_reason", None)
+ if stop_reason == "max_tokens":
+ return StreamEvent(
+ type="error", data="Max tokens reached"
+ ), current_tool_use
+
+ # Handle token usage information
+ usage = getattr(event, "usage", None)
+ if usage:
+ usage_data = {
+ "input_tokens": getattr(usage, "input_tokens", 0),
+ "output_tokens": getattr(usage, "output_tokens", 0),
+ "cache_creation_input_tokens": getattr(
+ usage, "cache_creation_input_tokens", None
+ ),
+ "cache_read_input_tokens": getattr(
+ usage, "cache_read_input_tokens", None
+ ),
+ }
+ # Could emit this as a separate event type if needed
+ logger.debug(f"Token usage: {usage_data}")
+
+ return None, current_tool_use
+
+ elif event_type == "message_stop":
+ # Final event - clean up any pending state
+ if current_tool_use:
+ logger.warning(
+ f"Message ended with incomplete tool use: {current_tool_use}"
+ )
+ return StreamEvent(type="done"), None
+
+ # Unknown event type - log but don't fail
+ if event_type and event_type not in (
+ "message_start",
+ "message_delta",
+ "content_block_start",
+ "content_block_delta",
+ "content_block_stop",
+ "message_stop",
+ ):
+ logger.debug(f"Unknown event type: {event_type}")
+
+ return None, current_tool_use
+
+ def generate(
+ self,
+ messages: list[Message],
+ system_prompt: str | None = None,
+ tools: list[ToolDefinition] | None = None,
+ settings: LLMSettings | None = None,
+ ) -> str:
+ """Generate a non-streaming response."""
+ settings = settings or LLMSettings()
+ kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
+
+ try:
+ response = self.client.messages.create(**kwargs)
+ return "".join(
+ block.text for block in response.content if block.type == "text"
+ )
+ except Exception as e:
+ logger.error(f"Anthropic API error: {e}")
+ raise
+
+ def stream(
+ self,
+ messages: list[Message],
+ system_prompt: str | None = None,
+ tools: list[ToolDefinition] | None = None,
+ settings: LLMSettings | None = None,
+ ) -> Iterator[StreamEvent]:
+ """Generate a streaming response."""
+ settings = settings or LLMSettings()
+ kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
+
+ try:
+ with self.client.messages.stream(**kwargs) as stream:
+ current_tool_use: dict[str, Any] | None = None
+
+ for event in stream:
+ stream_event, current_tool_use = self._handle_stream_event(
+ event, current_tool_use
+ )
+ if stream_event:
+ yield stream_event
+
+ except Exception as e:
+ logger.error(f"Anthropic streaming error: {e}")
+ yield StreamEvent(type="error", data=str(e))
+
+ async def agenerate(
+ self,
+ messages: list[Message],
+ system_prompt: str | None = None,
+ tools: list[ToolDefinition] | None = None,
+ settings: LLMSettings | None = None,
+ ) -> str:
+ """Generate a non-streaming response asynchronously."""
+ settings = settings or LLMSettings()
+ kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
+
+ try:
+ response = await self.async_client.messages.create(**kwargs)
+ return "".join(
+ block.text for block in response.content if block.type == "text"
+ )
+ except Exception as e:
+ logger.error(f"Anthropic API error: {e}")
+ raise
+
+ async def astream(
+ self,
+ messages: list[Message],
+ system_prompt: str | None = None,
+ tools: list[ToolDefinition] | None = None,
+ settings: LLMSettings | None = None,
+ ) -> AsyncIterator[StreamEvent]:
+ """Generate a streaming response asynchronously."""
+ settings = settings or LLMSettings()
+ kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
+
+ try:
+ async with self.async_client.messages.stream(**kwargs) as stream:
+ current_tool_use: dict[str, Any] | None = None
+
+ async for event in stream:
+ stream_event, current_tool_use = self._handle_stream_event(
+ event, current_tool_use
+ )
+ if stream_event:
+ yield stream_event
+
+ except Exception as e:
+ logger.error(f"Anthropic streaming error: {e}")
+ yield StreamEvent(type="error", data=str(e))
+
+ def stream_with_tools(
+ self,
+ messages: list[Message],
+ tools: dict[str, ToolDefinition],
+ settings: LLMSettings | None = None,
+ system_prompt: str | None = None,
+ max_iterations: int = 10,
+ ) -> Iterator[StreamEvent]:
+ if max_iterations <= 0:
+ return
+
+ response = TextContent(text="")
+ thinking = ThinkingContent(thinking="", signature="")
+
+ for event in self.stream(
+ messages=messages,
+ system_prompt=system_prompt,
+ tools=list(tools.values()),
+ settings=settings,
+ ):
+ if event.type == "text":
+ response.text += event.data
+ yield event
+ elif event.type == "thinking" and event.signature:
+ thinking.signature = event.signature
+ elif event.type == "thinking":
+ thinking.thinking += event.data
+ yield event
+ elif event.type == "tool_use":
+ yield event
+ tool_result = self.execute_tool(event.data, tools)
+ yield StreamEvent(type="tool_result", data=tool_result.to_dict())
+ messages.append(
+ Message.assistant(
+ response,
+ thinking,
+ ToolUseContent(
+ id=event.data["id"],
+ name=event.data["name"],
+ input=event.data["input"],
+ ),
+ )
+ )
+ messages.append(Message.user(tool_result=tool_result))
+ yield from self.stream_with_tools(
+ messages, tools, settings, system_prompt, max_iterations - 1
+ )
+ elif event.type == "tool_result":
+ yield event
+ elif event.type == "error":
+ logger.error(f"LLM error: {event.data}")
+ raise RuntimeError(f"LLM error: {event.data}")
diff --git a/src/memory/common/llms/base.py b/src/memory/common/llms/base.py
new file mode 100644
index 0000000..aceb7f8
--- /dev/null
+++ b/src/memory/common/llms/base.py
@@ -0,0 +1,561 @@
+"""Base classes and types for LLM providers."""
+
+import base64
+import io
+import logging
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, AsyncIterator, Iterator, Literal, Optional, Union
+
+from PIL import Image
+
+from memory.common import settings
+from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
+
+logger = logging.getLogger(__name__)
+
+
+class MessageRole(str, Enum):
+ """Message roles for chat history."""
+
+ SYSTEM = "system"
+ USER = "user"
+ ASSISTANT = "assistant"
+ TOOL = "tool"
+
+
+@dataclass
+class TextContent:
+ """Text content in a message."""
+
+ type: Literal["text"] = "text"
+ text: str = ""
+
+ def to_dict(self) -> dict[str, Any]:
+ """Convert to dictionary format."""
+ return {"type": "text", "text": self.text}
+
+
+@dataclass
+class ImageContent:
+ """Image content in a message."""
+
+ type: Literal["image"] = "image"
+ image: Image.Image = None # type: ignore
+ detail: Optional[str] = None # For OpenAI: "low", "high", "auto"
+
+ def to_dict(self) -> dict[str, Any]:
+ """Convert to dictionary format."""
+ # Note: Image will be encoded by provider-specific implementation
+ return {"type": "image", "image": self.image}
+
+
+@dataclass
+class ToolUseContent:
+ """Tool use request from the assistant."""
+
+ type: Literal["tool_use"] = "tool_use"
+ id: str = ""
+ name: str = ""
+ input: dict[str, Any] = None # type: ignore
+
+ def to_dict(self) -> dict[str, Any]:
+ """Convert to dictionary format."""
+ return {
+ "type": "tool_use",
+ "id": self.id,
+ "name": self.name,
+ "input": self.input,
+ }
+
+
+@dataclass
+class ToolResultContent:
+ """Tool result from tool execution."""
+
+ type: Literal["tool_result"] = "tool_result"
+ tool_use_id: str = ""
+ content: str = ""
+ is_error: bool = False
+
+ def to_dict(self) -> dict[str, Any]:
+ """Convert to dictionary format."""
+ return {
+ "type": "tool_result",
+ "tool_use_id": self.tool_use_id,
+ "content": self.content,
+ "is_error": self.is_error,
+ }
+
+
+@dataclass
+class ThinkingContent:
+ """Thinking/reasoning content from the assistant (extended thinking)."""
+
+ type: Literal["thinking"] = "thinking"
+ thinking: str = ""
+ signature: str | None = None
+
+ def to_dict(self) -> dict[str, Any]:
+ """Convert to dictionary format."""
+ return {
+ "type": "thinking",
+ "thinking": self.thinking,
+ "signature": self.signature,
+ }
+
+
+MessageContent = Union[
+ TextContent, ImageContent, ToolUseContent, ToolResultContent, ThinkingContent
+]
+
+
+@dataclass
+class Turn:
+ """A turn in the conversation."""
+
+ response: str | None
+ thinking: str | None
+ tool_calls: dict[str, ToolResult] | None
+
+
+@dataclass
+class Message:
+ """A message in the conversation history."""
+
+ role: MessageRole
+ content: Union[str, list[MessageContent]]
+
+ def to_dict(self) -> dict[str, Any]:
+ """Convert message to dictionary format."""
+ if isinstance(self.content, str):
+ return {"role": self.role.value, "content": self.content}
+ content_list = [item.to_dict() for item in self.content]
+ return {"role": self.role.value, "content": content_list}
+
+ @staticmethod
+ def assistant(
+ text: TextContent | None = None,
+ thinking: ThinkingContent | None = None,
+ tool_use: ToolUseContent | None = None,
+ ) -> "Message":
+ parts = []
+ if text:
+ parts.append(text)
+ if thinking:
+ parts.append(thinking)
+ if tool_use:
+ parts.append(tool_use)
+ return Message(role=MessageRole.ASSISTANT, content=parts)
+
+ @staticmethod
+ def user(
+ text: str | None = None, tool_result: ToolResultContent | None = None
+ ) -> "Message":
+ parts = []
+ if text:
+ parts.append(TextContent(text=text))
+ if tool_result:
+ parts.append(tool_result)
+ return Message(role=MessageRole.USER, content=parts)
+
+
+@dataclass
+class StreamEvent:
+ """An event from the streaming response."""
+
+ type: Literal["text", "tool_use", "tool_result", "thinking", "error", "done"]
+ data: Any = None
+ signature: str | None = None
+
+
+@dataclass
+class LLMSettings:
+ """Settings for LLM API calls."""
+
+ temperature: float = 0.7
+ max_tokens: int = 2048
+ # Don't set by default - some models don't allow both temp and top_p
+ top_p: float | None = None
+ stop_sequences: list[str] | None = None
+ stream: bool = False
+
+
+class BaseLLMProvider(ABC):
+ """Base class for LLM providers."""
+
+ def __init__(self, api_key: str, model: str):
+ """
+ Initialize the LLM provider.
+
+ Args:
+ api_key: API key for the provider
+ model: Model identifier
+ """
+ self.api_key = api_key
+ self.model = model
+ self._client: Any = None
+
+ @abstractmethod
+ def _initialize_client(self) -> Any:
+ """Initialize the provider-specific client."""
+ pass
+
+ @property
+ def client(self) -> Any:
+ """Lazy-load the client."""
+ if self._client is None:
+ self._client = self._initialize_client()
+ return self._client
+
+ def execute_tool(
+ self,
+ tool_call: ToolCall,
+ tool_handlers: dict[str, ToolDefinition],
+ ) -> ToolResultContent:
+ """
+ Execute a tool call.
+
+ Args:
+ tool_call: Tool call
+ tool_handlers: Dict mapping tool names to handler functions
+
+ Returns:
+ ToolResultContent with result or error
+ """
+ name = tool_call.get("name")
+ tool_use_id = tool_call.get("id")
+ input = tool_call.get("input")
+
+ if not name:
+ return ToolResultContent(
+ tool_use_id=tool_use_id,
+ content="Tool name missing",
+ is_error=True,
+ )
+
+ if not (tool := tool_handlers.get(name)):
+ return ToolResultContent(
+ tool_use_id=tool_use_id,
+ content=f"Tool '{name}' not found",
+ is_error=True,
+ )
+
+ try:
+ return ToolResultContent(
+ tool_use_id=tool_use_id,
+ content=tool(input),
+ is_error=False,
+ )
+ except Exception as e:
+ logger.error(f"Tool '{name}' failed: {e}", exc_info=True)
+ return ToolResultContent(
+ tool_use_id=tool_use_id,
+ content=str(e),
+ is_error=True,
+ )
+
+ def encode_image(self, image: Image.Image) -> str:
+ """
+ Encode PIL Image to base64 string.
+
+ Args:
+ image: PIL Image to encode
+
+ Returns:
+ Base64 encoded string
+ """
+ buffer = io.BytesIO()
+ # Convert to RGB if necessary (for RGBA, etc.)
+ if image.mode != "RGB":
+ image = image.convert("RGB")
+ image.save(buffer, format="JPEG")
+ return base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+ def _convert_text_content(self, content: TextContent) -> dict[str, Any]:
+ """Convert TextContent to provider format. Override for custom format."""
+ return content.to_dict()
+
+ def _convert_image_content(self, content: ImageContent) -> dict[str, Any]:
+ """Convert ImageContent to provider format. Override for custom format."""
+ return content.to_dict()
+
+ def _convert_tool_use_content(self, content: ToolUseContent) -> dict[str, Any]:
+ """Convert ToolUseContent to provider format. Override for custom format."""
+ return content.to_dict()
+
+ def _convert_tool_result_content(
+ self, content: ToolResultContent
+ ) -> dict[str, Any]:
+ """Convert ToolResultContent to provider format. Override for custom format."""
+ return content.to_dict()
+
+ def _convert_thinking_content(self, content: ThinkingContent) -> dict[str, Any]:
+ """Convert ThinkingContent to provider format. Override for custom format."""
+ return content.to_dict()
+
+ def _convert_message_content(self, content: MessageContent) -> dict[str, Any]:
+ """
+ Convert a MessageContent item to provider format.
+
+ Dispatches to type-specific converters that can be overridden.
+ """
+ if isinstance(content, TextContent):
+ return self._convert_text_content(content)
+ elif isinstance(content, ImageContent):
+ return self._convert_image_content(content)
+ elif isinstance(content, ToolUseContent):
+ return self._convert_tool_use_content(content)
+ elif isinstance(content, ToolResultContent):
+ return self._convert_tool_result_content(content)
+ elif isinstance(content, ThinkingContent):
+ return self._convert_thinking_content(content)
+ else:
+ raise ValueError(f"Unknown content type: {type(content)}")
+
+ def _convert_message(self, message: Message) -> dict[str, Any]:
+ """
+ Convert a Message to provider format.
+
+ Can be overridden for provider-specific handling (e.g., filtering system messages).
+ """
+ return message.to_dict()
+
+ def _should_include_message(self, message: Message) -> bool:
+ """
+ Determine if a message should be included in the request.
+
+ Override to filter messages (e.g., Anthropic filters SYSTEM messages).
+
+ Args:
+ message: Message to check
+
+ Returns:
+ True if message should be included
+ """
+ return True
+
+ def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
+ """
+ Convert a list of messages to provider format.
+
+ Uses _should_include_message for filtering and _convert_message for conversion.
+ """
+ return [
+ self._convert_message(msg)
+ for msg in messages
+ if self._should_include_message(msg)
+ ]
+
+ def _convert_tool(self, tool: ToolDefinition) -> dict[str, Any]:
+ """
+ Convert a single ToolDefinition to provider format.
+
+ Default format matches Anthropic. Override for other providers (e.g., OpenAI uses functions).
+ """
+ return {
+ "name": tool.name,
+ "description": tool.description,
+ "input_schema": tool.input_schema,
+ }
+
+ def _convert_tools(
+ self, tools: list[ToolDefinition] | None
+ ) -> Optional[list[dict[str, Any]]]:
+ """Convert tool definitions to provider format."""
+ if not tools:
+ return None
+ return [self._convert_tool(tool) for tool in tools]
+
+ @abstractmethod
+ def generate(
+ self,
+ messages: list[Message],
+ system_prompt: str | None = None,
+ tools: list[ToolDefinition] | None = None,
+ settings: LLMSettings | None = None,
+ ) -> str:
+ """
+ Generate a non-streaming response.
+
+ Args:
+ messages: Conversation history
+ system_prompt: Optional system prompt
+ tools: Optional list of tools the LLM can use
+ settings: Optional settings for the generation
+
+ Returns:
+ Generated text response
+ """
+ pass
+
+ @abstractmethod
+ def stream(
+ self,
+ messages: list[Message],
+ system_prompt: str | None = None,
+ tools: list[ToolDefinition] | None = None,
+ settings: LLMSettings | None = None,
+ ) -> Iterator[StreamEvent]:
+ """
+ Generate a streaming response.
+
+ Args:
+ messages: Conversation history
+ system_prompt: Optional system prompt
+ tools: Optional list of tools the LLM can use
+ settings: Optional settings for the generation
+
+ Yields:
+ StreamEvent objects containing text chunks, tool uses, or errors
+ """
+ pass
+
+ @abstractmethod
+ async def agenerate(
+ self,
+ messages: list[Message],
+ system_prompt: str | None = None,
+ tools: list[ToolDefinition] | None = None,
+ settings: LLMSettings | None = None,
+ ) -> str:
+ """
+ Generate a non-streaming response asynchronously.
+
+ Args:
+ messages: Conversation history
+ system_prompt: Optional system prompt
+ tools: Optional list of tools the LLM can use
+ settings: Optional settings for the generation
+
+ Returns:
+ Generated text response
+ """
+ pass
+
+ @abstractmethod
+ async def astream(
+ self,
+ messages: list[Message],
+ system_prompt: str | None = None,
+ tools: list[ToolDefinition] | None = None,
+ settings: LLMSettings | None = None,
+ ) -> AsyncIterator[StreamEvent]:
+ """
+ Generate a streaming response asynchronously.
+
+ Args:
+ messages: Conversation history
+ system_prompt: Optional system prompt
+ tools: Optional list of tools the LLM can use
+ settings: Optional settings for the generation
+
+ Yields:
+ StreamEvent objects containing text chunks, tool uses, or errors
+ """
+ pass
+
+ @abstractmethod
+ def stream_with_tools(
+ self,
+ messages: list[Message],
+ tools: dict[str, ToolDefinition],
+ settings: LLMSettings | None = None,
+ system_prompt: str | None = None,
+ max_iterations: int = 10,
+ ) -> Iterator[StreamEvent]:
+ pass
+
+ def run_with_tools(
+ self,
+ messages: list[Message],
+ tools: dict[str, ToolDefinition],
+ settings: LLMSettings | None = None,
+ system_prompt: str | None = None,
+ max_iterations: int = 10,
+ ) -> Turn:
+ thinking, response, tool_calls = "", "", {}
+ for event in self.stream_with_tools(
+ messages=messages,
+ tools=tools,
+ settings=settings,
+ system_prompt=system_prompt,
+ max_iterations=max_iterations,
+ ):
+ if event.type == "thinking":
+ thinking += event.data
+ elif event.type == "tool_use":
+ tool_calls[event.data["id"]] = {
+ "name": event.data["name"],
+ "input": event.data["input"],
+ "output": "",
+ }
+ elif event.type == "text":
+ response += event.data
+ elif event.type == "tool_result":
+ current = tool_calls.get(event.data["tool_use_id"]) or {}
+ tool_calls[event.data["tool_use_id"]] = {
+ "name": event.data.get("name") or current.get("name"),
+ "input": event.data.get("input") or current.get("input"),
+ "output": event.data.get("content"),
+ }
+ return Turn(
+ thinking=thinking or None,
+ response=response or None,
+ tool_calls=tool_calls or None,
+ )
+
+
+def create_provider(
+ model: str | None = None,
+ api_key: str | None = None,
+ enable_thinking: bool = False,
+) -> BaseLLMProvider:
+ """
+ Create an LLM provider based on the model name.
+
+ Args:
+ model: Model identifier (e.g., "claude-3-opus-20240229", "gpt-4").
+ If not provided, uses SUMMARIZER_MODEL from settings.
+ api_key: Optional API key. If not provided, uses keys from settings.
+ enable_thinking: Enable extended thinking for supported models (Claude Opus 4+, Sonnet 4+, Sonnet 3.7)
+
+ Returns:
+ An initialized LLM provider
+
+ Raises:
+ ValueError: If the provider cannot be determined from the model name
+ """
+ # Use default model from settings if not provided
+ if model is None:
+ model = settings.SUMMARIZER_MODEL
+
+ provider, model = model.split("/", 1)
+
+ if provider == "anthropic":
+ # Anthropic models
+ if api_key is None:
+ api_key = settings.ANTHROPIC_API_KEY
+
+ if not api_key:
+ raise ValueError(
+ "ANTHROPIC_API_KEY not found in settings. "
+ "Please set it in your .env file."
+ )
+
+ from memory.common.llms.anthropic_provider import AnthropicProvider
+
+ return AnthropicProvider(
+ api_key=api_key, model=model, enable_thinking=enable_thinking
+ )
+
+ # Could add OpenAI support here in the future
+ # elif "gpt" in model_lower or model.startswith("openai"):
+ # ...
+
+ else:
+ raise ValueError(
+ f"Unknown provider for model: {model}. "
+ f"Supported providers: Anthropic (claude-*)"
+ )
diff --git a/src/memory/common/llms/openai_provider.py b/src/memory/common/llms/openai_provider.py
new file mode 100644
index 0000000..2f230fb
--- /dev/null
+++ b/src/memory/common/llms/openai_provider.py
@@ -0,0 +1,388 @@
+"""OpenAI LLM provider implementation."""
+
+import logging
+from typing import Any, AsyncIterator, Iterator, Optional
+
+import openai
+
+from memory.common.llms.base import (
+ BaseLLMProvider,
+ ImageContent,
+ LLMSettings,
+ Message,
+ MessageContent,
+ MessageRole,
+ StreamEvent,
+ TextContent,
+ ThinkingContent,
+ ToolDefinition,
+ ToolResultContent,
+ ToolUseContent,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIProvider(BaseLLMProvider):
+ """OpenAI LLM provider with streaming and tool support."""
+
+ def _initialize_client(self) -> openai.OpenAI:
+ """Initialize the OpenAI client."""
+ return openai.OpenAI(api_key=self.api_key)
+
+ def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
+ """
+ Convert our Message format to OpenAI format.
+
+ Args:
+ messages: List of messages in our format
+
+ Returns:
+ List of messages in OpenAI format
+ """
+ openai_messages = []
+
+ for msg in messages:
+ if isinstance(msg.content, str):
+ openai_messages.append({"role": msg.role.value, "content": msg.content})
+ else:
+ # Handle multi-part content
+ content_parts = []
+ for item in msg.content:
+ if isinstance(item, TextContent):
+ content_parts.append({"type": "text", "text": item.text})
+ elif isinstance(item, ImageContent):
+ encoded_image = self.encode_image(item.image)
+ image_part: dict[str, Any] = {
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:image/jpeg;base64,{encoded_image}"
+ },
+ }
+ if item.detail:
+ image_part["image_url"]["detail"] = item.detail
+ content_parts.append(image_part)
+ elif isinstance(item, ToolUseContent):
+ # OpenAI doesn't have tool_use in content, it's a separate field
+ # We'll handle this by adding a tool_calls field to the message
+ pass
+ elif isinstance(item, ToolResultContent):
+ # OpenAI handles tool results as separate "tool" role messages
+ openai_messages.append(
+ {
+ "role": "tool",
+ "tool_call_id": item.tool_use_id,
+ "content": item.content,
+ }
+ )
+ continue
+ elif isinstance(item, ThinkingContent):
+ # OpenAI doesn't have native thinking support in most models
+ # We can add it as text with a special marker
+ content_parts.append(
+ {
+ "type": "text",
+ "text": f"[Thinking: {item.thinking}]",
+ }
+ )
+
+ # Check if this message has tool calls
+ tool_calls = [
+ item for item in msg.content if isinstance(item, ToolUseContent)
+ ]
+
+ message_dict: dict[str, Any] = {"role": msg.role.value}
+
+ if content_parts:
+ message_dict["content"] = content_parts
+
+ if tool_calls:
+ message_dict["tool_calls"] = [
+ {
+ "id": tc.id,
+ "type": "function",
+ "function": {"name": tc.name, "arguments": str(tc.input)},
+ }
+ for tc in tool_calls
+ ]
+
+ openai_messages.append(message_dict)
+
+ return openai_messages
+
+ def _convert_tools(
+ self, tools: Optional[list[ToolDefinition]]
+ ) -> Optional[list[dict[str, Any]]]:
+ """
+ Convert our tool definitions to OpenAI format.
+
+ Args:
+ tools: List of tool definitions
+
+ Returns:
+ List of tools in OpenAI format
+ """
+ if not tools:
+ return None
+
+ return [
+ {
+ "type": "function",
+ "function": {
+ "name": tool.name,
+ "description": tool.description,
+ "parameters": tool.input_schema,
+ },
+ }
+ for tool in tools
+ ]
+
+ def generate(
+ self,
+ messages: list[Message],
+ system_prompt: Optional[str] = None,
+ tools: Optional[list[ToolDefinition]] = None,
+ settings: Optional[LLMSettings] = None,
+ ) -> str:
+ """Generate a non-streaming response."""
+ settings = settings or LLMSettings()
+
+ openai_messages = self._convert_messages(messages)
+
+ # Add system prompt as first message if provided
+ if system_prompt:
+ openai_messages.insert(
+ 0, {"role": "system", "content": system_prompt}
+ )
+
+ kwargs: dict[str, Any] = {
+ "model": self.model,
+ "messages": openai_messages,
+ "temperature": settings.temperature,
+ "max_tokens": settings.max_tokens,
+ "top_p": settings.top_p,
+ }
+
+ if settings.stop_sequences:
+ kwargs["stop"] = settings.stop_sequences
+
+ if tools:
+ kwargs["tools"] = self._convert_tools(tools)
+ kwargs["tool_choice"] = "auto"
+
+ try:
+ response = self.client.chat.completions.create(**kwargs)
+ return response.choices[0].message.content or ""
+ except Exception as e:
+ logger.error(f"OpenAI API error: {e}")
+ raise
+
+ def stream(
+ self,
+ messages: list[Message],
+ system_prompt: Optional[str] = None,
+ tools: Optional[list[ToolDefinition]] = None,
+ settings: Optional[LLMSettings] = None,
+ ) -> Iterator[StreamEvent]:
+ """Generate a streaming response."""
+ settings = settings or LLMSettings()
+
+ openai_messages = self._convert_messages(messages)
+
+ # Add system prompt as first message if provided
+ if system_prompt:
+ openai_messages.insert(
+ 0, {"role": "system", "content": system_prompt}
+ )
+
+ kwargs: dict[str, Any] = {
+ "model": self.model,
+ "messages": openai_messages,
+ "temperature": settings.temperature,
+ "max_tokens": settings.max_tokens,
+ "top_p": settings.top_p,
+ "stream": True,
+ }
+
+ if settings.stop_sequences:
+ kwargs["stop"] = settings.stop_sequences
+
+ if tools:
+ kwargs["tools"] = self._convert_tools(tools)
+ kwargs["tool_choice"] = "auto"
+
+ try:
+ stream = self.client.chat.completions.create(**kwargs)
+
+ current_tool_call: Optional[dict[str, Any]] = None
+
+ for chunk in stream:
+ if not chunk.choices:
+ continue
+
+ delta = chunk.choices[0].delta
+
+ # Handle text content
+ if delta.content:
+ yield StreamEvent(type="text", data=delta.content)
+
+ # Handle tool calls
+ if delta.tool_calls:
+ for tool_call in delta.tool_calls:
+ if tool_call.id:
+ # New tool call starting
+ if current_tool_call:
+ # Yield the previous one
+ yield StreamEvent(
+ type="tool_use", data=current_tool_call
+ )
+ current_tool_call = {
+ "id": tool_call.id,
+ "name": tool_call.function.name or "",
+ "arguments": tool_call.function.arguments or "",
+ }
+ elif current_tool_call and tool_call.function.arguments:
+ # Continue building the current tool call
+ current_tool_call["arguments"] += (
+ tool_call.function.arguments
+ )
+
+ # Check if stream is finished
+ if chunk.choices[0].finish_reason:
+ if current_tool_call:
+ yield StreamEvent(type="tool_use", data=current_tool_call)
+ current_tool_call = None
+
+ yield StreamEvent(type="done")
+
+ except Exception as e:
+ logger.error(f"OpenAI streaming error: {e}")
+ yield StreamEvent(type="error", data=str(e))
+
+ async def agenerate(
+ self,
+ messages: list[Message],
+ system_prompt: Optional[str] = None,
+ tools: Optional[list[ToolDefinition]] = None,
+ settings: Optional[LLMSettings] = None,
+ ) -> str:
+ """Generate a non-streaming response asynchronously."""
+ settings = settings or LLMSettings()
+
+ # Use async client
+ async_client = openai.AsyncOpenAI(api_key=self.api_key)
+
+ openai_messages = self._convert_messages(messages)
+
+ # Add system prompt as first message if provided
+ if system_prompt:
+ openai_messages.insert(
+ 0, {"role": "system", "content": system_prompt}
+ )
+
+ kwargs: dict[str, Any] = {
+ "model": self.model,
+ "messages": openai_messages,
+ "temperature": settings.temperature,
+ "max_tokens": settings.max_tokens,
+ "top_p": settings.top_p,
+ }
+
+ if settings.stop_sequences:
+ kwargs["stop"] = settings.stop_sequences
+
+ if tools:
+ kwargs["tools"] = self._convert_tools(tools)
+ kwargs["tool_choice"] = "auto"
+
+ try:
+ response = await async_client.chat.completions.create(**kwargs)
+ return response.choices[0].message.content or ""
+ except Exception as e:
+ logger.error(f"OpenAI API error: {e}")
+ raise
+
+ async def astream(
+ self,
+ messages: list[Message],
+ system_prompt: Optional[str] = None,
+ tools: Optional[list[ToolDefinition]] = None,
+ settings: Optional[LLMSettings] = None,
+ ) -> AsyncIterator[StreamEvent]:
+ """Generate a streaming response asynchronously."""
+ settings = settings or LLMSettings()
+
+ # Use async client
+ async_client = openai.AsyncOpenAI(api_key=self.api_key)
+
+ openai_messages = self._convert_messages(messages)
+
+ # Add system prompt as first message if provided
+ if system_prompt:
+ openai_messages.insert(
+ 0, {"role": "system", "content": system_prompt}
+ )
+
+ kwargs: dict[str, Any] = {
+ "model": self.model,
+ "messages": openai_messages,
+ "temperature": settings.temperature,
+ "max_tokens": settings.max_tokens,
+ "top_p": settings.top_p,
+ "stream": True,
+ }
+
+ if settings.stop_sequences:
+ kwargs["stop"] = settings.stop_sequences
+
+ if tools:
+ kwargs["tools"] = self._convert_tools(tools)
+ kwargs["tool_choice"] = "auto"
+
+ try:
+ stream = await async_client.chat.completions.create(**kwargs)
+
+ current_tool_call: Optional[dict[str, Any]] = None
+
+ async for chunk in stream:
+ if not chunk.choices:
+ continue
+
+ delta = chunk.choices[0].delta
+
+ # Handle text content
+ if delta.content:
+ yield StreamEvent(type="text", data=delta.content)
+
+ # Handle tool calls
+ if delta.tool_calls:
+ for tool_call in delta.tool_calls:
+ if tool_call.id:
+ # New tool call starting
+ if current_tool_call:
+ # Yield the previous one
+ yield StreamEvent(
+ type="tool_use", data=current_tool_call
+ )
+ current_tool_call = {
+ "id": tool_call.id,
+ "name": tool_call.function.name or "",
+ "arguments": tool_call.function.arguments or "",
+ }
+ elif current_tool_call and tool_call.function.arguments:
+ # Continue building the current tool call
+ current_tool_call["arguments"] += (
+ tool_call.function.arguments
+ )
+
+ # Check if stream is finished
+ if chunk.choices[0].finish_reason:
+ if current_tool_call:
+ yield StreamEvent(type="tool_use", data=current_tool_call)
+ current_tool_call = None
+
+ yield StreamEvent(type="done")
+
+ except Exception as e:
+ logger.error(f"OpenAI streaming error: {e}")
+ yield StreamEvent(type="error", data=str(e))
diff --git a/src/memory/common/llms/tools/__init__.py b/src/memory/common/llms/tools/__init__.py
new file mode 100644
index 0000000..7a4e1c0
--- /dev/null
+++ b/src/memory/common/llms/tools/__init__.py
@@ -0,0 +1,36 @@
+from dataclasses import dataclass
+from typing import Any, Callable, TypedDict
+
+
+ToolInput = str | dict[str, Any] | None
+ToolHandler = Callable[[ToolInput], str]
+
+
+class ToolCall(TypedDict):
+ """A call to a tool."""
+
+ name: str
+ id: str
+ input: ToolInput
+
+
+class ToolResult(TypedDict):
+ """A result from a tool call."""
+
+ id: str
+ name: str
+ input: ToolInput
+ output: str
+
+
+@dataclass
+class ToolDefinition:
+ """Definition of a tool that can be called by the LLM."""
+
+ name: str
+ description: str
+ input_schema: dict[str, Any] # JSON Schema for the tool's parameters
+ function: ToolHandler
+
+ def __call__(self, input: ToolInput) -> str:
+ return self.function(input)
diff --git a/src/memory/common/llms/tools/ping.py b/src/memory/common/llms/tools/ping.py
new file mode 100644
index 0000000..5ec2ea6
--- /dev/null
+++ b/src/memory/common/llms/tools/ping.py
@@ -0,0 +1,42 @@
+"""Ping tool for testing LLM tool integration."""
+
+from memory.common.llms.tools import ToolDefinition, ToolInput
+
+
+def handle_ping_call(message: ToolInput = None) -> str:
+ """
+ Handle a ping tool call.
+
+ Args:
+ message: Optional message to include in response
+
+ Returns:
+ Response string
+ """
+ if message:
+ return f"pong: {message}"
+ return "pong"
+
+
+def get_ping_tool() -> ToolDefinition:
+ """
+ Get a ping tool definition for testing tool calls.
+
+ Returns a simple tool that takes no required parameters and can be used
+ to verify that tool calling is working correctly.
+ """
+ return ToolDefinition(
+ name="ping",
+ description="A simple test tool that returns 'pong'. Use this to verify tool calling is working.",
+ input_schema={
+ "type": "object",
+ "properties": {
+ "message": {
+ "type": "string",
+ "description": "Optional message to echo back",
+ }
+ },
+ "required": [],
+ },
+ function=handle_ping_call,
+ )
diff --git a/src/memory/common/messages.py b/src/memory/common/messages.py
new file mode 100644
index 0000000..1244de3
--- /dev/null
+++ b/src/memory/common/messages.py
@@ -0,0 +1,9 @@
+def process_message(
+ msg: str,
+ history: list[str],
+ model: str | None = None,
+ system_prompt: str | None = None,
+ allowed_tools: list[str] | None = None,
+ disallowed_tools: list[str] | None = None,
+) -> str:
+ return "asd"
diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py
index 793e698..daf98bb 100644
--- a/src/memory/common/settings.py
+++ b/src/memory/common/settings.py
@@ -172,6 +172,7 @@ DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat")
DISCORD_NOTIFICATIONS_ENABLED = bool(
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
)
+DISCORD_PROCESS_MESSAGES = boolean_env("DISCORD_PROCESS_MESSAGES", True)
# Discord collector settings
DISCORD_COLLECTOR_ENABLED = boolean_env("DISCORD_COLLECTOR_ENABLED", True)
diff --git a/src/memory/common/summarizer.py b/src/memory/common/summarizer.py
index c2016f3..b2d2dbc 100644
--- a/src/memory/common/summarizer.py
+++ b/src/memory/common/summarizer.py
@@ -105,7 +105,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
prompt = llms.truncate(prompt, MAX_TOKENS - 20)
try:
- response = llms.call(prompt, settings.SUMMARIZER_MODEL)
+ response = llms.summarize(prompt, settings.SUMMARIZER_MODEL)
result = parse_response(response)
summary = result.get("summary", "")
diff --git a/src/memory/discord/collector.py b/src/memory/discord/collector.py
index f42ff76..f60a22f 100644
--- a/src/memory/discord/collector.py
+++ b/src/memory/discord/collector.py
@@ -160,7 +160,7 @@ def should_track_message(
return False
if channel.channel_type in ("dm", "group_dm"):
- return bool(user.allow_dm_tracking)
+ return bool(user.track_messages)
# Default: track the message
return True
diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py
index 764d4f2..093acd1 100644
--- a/src/memory/workers/tasks/discord.py
+++ b/src/memory/workers/tasks/discord.py
@@ -16,7 +16,11 @@ from memory.workers.tasks.content_processing import (
create_task_result,
process_content_item,
)
-from memory.common.celery_app import ADD_DISCORD_MESSAGE, EDIT_DISCORD_MESSAGE
+from memory.common.celery_app import (
+ ADD_DISCORD_MESSAGE,
+ EDIT_DISCORD_MESSAGE,
+ PROCESS_DISCORD_MESSAGE,
+)
from memory.common import settings
from sqlalchemy.orm import Session, scoped_session
@@ -40,6 +44,41 @@ def get_prev(
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
+def should_process(message: DiscordMessage) -> bool:
+ return (
+ settings.DISCORD_PROCESS_MESSAGES
+ and settings.DISCORD_NOTIFICATIONS_ENABLED
+ and not (
+ (message.server and message.server.ignore_messages)
+ or (message.channel and message.channel.ignore_messages)
+ or (message.discord_user and message.discord_user.ignore_messages)
+ )
+ )
+
+
+@app.task(name=PROCESS_DISCORD_MESSAGE)
+@safe_task_execution
+def process_discord_message(message_id: int) -> dict[str, Any]:
+ logger.info(f"Processing Discord message {message_id}")
+
+ with make_session() as session:
+ discord_message = session.query(DiscordMessage).get(message_id)
+ if not discord_message:
+ logger.info(f"Discord message not found: {message_id}")
+ return {
+ "status": "error",
+ "error": "Message not found",
+ "message_id": message_id,
+ }
+
+ print("Processing message", discord_message.id, discord_message.content)
+
+ return {
+ "status": "processed",
+ "message_id": message_id,
+ }
+
+
@app.task(name=ADD_DISCORD_MESSAGE)
@safe_task_execution
def add_discord_message(
@@ -87,11 +126,8 @@ def add_discord_message(
discord_message.messages_before = get_prev(session, channel_id, sent_at_dt)
result = process_content_item(discord_message, session)
-
- logger.info(
- f"Discord message ID after process_content_item: {discord_message.id}"
- )
- logger.info(f"Process result: {result}")
+ if should_process(discord_message):
+ process_discord_message.delay(discord_message.id)
return result
diff --git a/src/memory/workers/tasks/scheduled_calls.py b/src/memory/workers/tasks/scheduled_calls.py
index 200cd2a..3248285 100644
--- a/src/memory/workers/tasks/scheduled_calls.py
+++ b/src/memory/workers/tasks/scheduled_calls.py
@@ -88,11 +88,10 @@ def execute_scheduled_call(self, scheduled_call_id: str):
# Make the LLM call
if scheduled_call.model:
- response = llms.call(
+ response = llms.summarize(
prompt=cast(str, scheduled_call.message),
model=cast(str, scheduled_call.model),
- system_prompt=cast(str, scheduled_call.system_prompt)
- or llms.SYSTEM_PROMPT,
+ system_prompt=cast(str, scheduled_call.system_prompt),
)
else:
response = cast(str, scheduled_call.message)
diff --git a/tests/conftest.py b/tests/conftest.py
index 92166cb..175eedf 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -273,6 +273,27 @@ def mock_anthropic_client():
with patch.object(anthropic, "Anthropic", autospec=True) as mock_client:
client = mock_client()
client.messages = Mock()
+
+ # Mock stream as a context manager
+ mock_stream = Mock()
+ mock_stream.__enter__ = Mock(
+ return_value=Mock(
+ __iter__=lambda self: iter(
+ [
+ Mock(
+ type="content_block_delta",
+ delta=Mock(
+ type="text_delta",
+ text="test summarytag1tag2",
+ ),
+ )
+ ]
+ )
+ )
+ )
+ mock_stream.__exit__ = Mock(return_value=False)
+ client.messages.stream = Mock(return_value=mock_stream)
+
client.messages.create = Mock(
return_value=Mock(
content=[
diff --git a/tests/memory/common/test_discord.py b/tests/memory/common/test_discord.py
index 6624b6d..3d46a0d 100644
--- a/tests/memory/common/test_discord.py
+++ b/tests/memory/common/test_discord.py
@@ -2,318 +2,250 @@ import pytest
from unittest.mock import Mock, patch
import requests
-from memory.common import discord, settings
+from memory.common import discord
@pytest.fixture
-def mock_session_request():
- with patch("requests.Session.request") as mock:
- yield mock
+def mock_api_url():
+ """Mock the API URL to avoid using actual settings"""
+ with patch(
+ "memory.common.discord.get_api_url", return_value="http://localhost:8000"
+ ):
+ yield
-@pytest.fixture
-def mock_get_channels_response():
- return [
- {"name": "memory-errors", "id": "error_channel_id"},
- {"name": "memory-activity", "id": "activity_channel_id"},
- {"name": "memory-discoveries", "id": "discovery_channel_id"},
- {"name": "memory-chat", "id": "chat_channel_id"},
- ]
+@patch("memory.common.settings.DISCORD_COLLECTOR_SERVER_URL", "testhost")
+@patch("memory.common.settings.DISCORD_COLLECTOR_PORT", 9999)
+def test_get_api_url():
+ """Test API URL construction"""
+ assert discord.get_api_url() == "http://testhost:9999"
-def test_discord_server_init(mock_session_request, mock_get_channels_response):
- # Mock the channels API call
+@patch("requests.post")
+def test_send_dm_success(mock_post, mock_api_url):
+ """Test successful DM sending"""
mock_response = Mock()
- mock_response.json.return_value = mock_get_channels_response
+ mock_response.json.return_value = {"success": True}
mock_response.raise_for_status.return_value = None
- mock_session_request.return_value = mock_response
+ mock_post.return_value = mock_response
- server = discord.DiscordServer("server123", "Test Server")
+ result = discord.send_dm("user123", "Hello!")
- assert server.server_id == "server123"
- assert server.server_name == "Test Server"
- assert hasattr(server, "channels")
-
-
-@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "memory-errors")
-@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "memory-activity")
-@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "memory-discoveries")
-@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "memory-chat")
-def test_setup_channels_existing(mock_session_request, mock_get_channels_response):
- # Mock the channels API call
- mock_response = Mock()
- mock_response.json.return_value = mock_get_channels_response
- mock_response.raise_for_status.return_value = None
- mock_session_request.return_value = mock_response
-
- server = discord.DiscordServer("server123", "Test Server")
-
- assert server.channels[discord.ERROR_CHANNEL] == "error_channel_id"
- assert server.channels[discord.ACTIVITY_CHANNEL] == "activity_channel_id"
- assert server.channels[discord.DISCOVERY_CHANNEL] == "discovery_channel_id"
- assert server.channels[discord.CHAT_CHANNEL] == "chat_channel_id"
-
-
-@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "new-error-channel")
-def test_setup_channels_create_missing(mock_session_request):
- # Mock get channels (empty) and create channel calls
- get_response = Mock()
- get_response.json.return_value = []
- get_response.raise_for_status.return_value = None
-
- create_response = Mock()
- create_response.json.return_value = {"id": "new_channel_id"}
- create_response.raise_for_status.return_value = None
-
- mock_session_request.side_effect = [
- get_response,
- create_response,
- create_response,
- create_response,
- create_response,
- ]
-
- server = discord.DiscordServer("server123", "Test Server")
-
- assert server.channels[discord.ERROR_CHANNEL] == "new_channel_id"
-
-
-def test_channel_properties():
- server = discord.DiscordServer.__new__(discord.DiscordServer)
- server.channels = {
- discord.ERROR_CHANNEL: "error_id",
- discord.ACTIVITY_CHANNEL: "activity_id",
- discord.DISCOVERY_CHANNEL: "discovery_id",
- discord.CHAT_CHANNEL: "chat_id",
- }
-
- assert server.error_channel == "error_id"
- assert server.activity_channel == "activity_id"
- assert server.discovery_channel == "discovery_id"
- assert server.chat_channel == "chat_id"
-
-
-def test_channel_id_exists():
- server = discord.DiscordServer.__new__(discord.DiscordServer)
- server.channels = {"test-channel": "channel123"}
-
- assert server.channel_id("test-channel") == "channel123"
-
-
-def test_channel_id_not_found():
- server = discord.DiscordServer.__new__(discord.DiscordServer)
- server.channels = {}
-
- with pytest.raises(ValueError, match="Channel nonexistent not found"):
- server.channel_id("nonexistent")
-
-
-def test_send_message(mock_session_request):
- mock_response = Mock()
- mock_response.raise_for_status.return_value = None
- mock_session_request.return_value = mock_response
-
- server = discord.DiscordServer.__new__(discord.DiscordServer)
-
- server.send_message("channel123", "Hello World")
-
- mock_session_request.assert_called_with(
- "POST",
- "https://discord.com/api/v10/channels/channel123/messages",
- data=None,
- json={"content": "Hello World"},
- headers={
- "Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
- "Content-Type": "application/json",
- },
+ assert result is True
+ mock_post.assert_called_once_with(
+ "http://localhost:8000/send_dm",
+ json={"user_identifier": "user123", "message": "Hello!"},
+ timeout=10,
)
-def test_create_channel(mock_session_request):
+@patch("requests.post")
+def test_send_dm_api_failure(mock_post, mock_api_url):
+ """Test DM sending when API returns failure"""
mock_response = Mock()
- mock_response.json.return_value = {"id": "new_channel_id"}
+ mock_response.json.return_value = {"success": False}
mock_response.raise_for_status.return_value = None
- mock_session_request.return_value = mock_response
+ mock_post.return_value = mock_response
- server = discord.DiscordServer.__new__(discord.DiscordServer)
- server.server_id = "server123"
+ result = discord.send_dm("user123", "Hello!")
- channel_id = server.create_channel("new-channel")
-
- assert channel_id == "new_channel_id"
- mock_session_request.assert_called_with(
- "POST",
- "https://discord.com/api/v10/guilds/server123/channels",
- data=None,
- json={"name": "new-channel", "type": 0},
- headers={
- "Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
- "Content-Type": "application/json",
- },
- )
+ assert result is False
-def test_create_channel_custom_type(mock_session_request):
+@patch("requests.post")
+def test_send_dm_request_exception(mock_post, mock_api_url):
+ """Test DM sending when request raises exception"""
+ mock_post.side_effect = requests.RequestException("Network error")
+
+ result = discord.send_dm("user123", "Hello!")
+
+ assert result is False
+
+
+@patch("requests.post")
+def test_send_dm_http_error(mock_post, mock_api_url):
+ """Test DM sending when HTTP error occurs"""
mock_response = Mock()
- mock_response.json.return_value = {"id": "voice_channel_id"}
+ mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
+ mock_post.return_value = mock_response
+
+ result = discord.send_dm("user123", "Hello!")
+
+ assert result is False
+
+
+@patch("requests.post")
+def test_broadcast_message_success(mock_post, mock_api_url):
+ """Test successful channel message broadcast"""
+ mock_response = Mock()
+ mock_response.json.return_value = {"success": True}
mock_response.raise_for_status.return_value = None
- mock_session_request.return_value = mock_response
+ mock_post.return_value = mock_response
- server = discord.DiscordServer.__new__(discord.DiscordServer)
- server.server_id = "server123"
+ result = discord.broadcast_message("general", "Announcement!")
- channel_id = server.create_channel("voice-channel", channel_type=2)
-
- assert channel_id == "voice_channel_id"
- mock_session_request.assert_called_with(
- "POST",
- "https://discord.com/api/v10/guilds/server123/channels",
- data=None,
- json={"name": "voice-channel", "type": 2},
- headers={
- "Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
- "Content-Type": "application/json",
- },
+ assert result is True
+ mock_post.assert_called_once_with(
+ "http://localhost:8000/send_channel",
+ json={"channel_name": "general", "message": "Announcement!"},
+ timeout=10,
)
-def test_str_representation():
- server = discord.DiscordServer.__new__(discord.DiscordServer)
- server.server_id = "server123"
- server.server_name = "Test Server"
+@patch("requests.post")
+def test_broadcast_message_failure(mock_post, mock_api_url):
+ """Test channel message broadcast failure"""
+ mock_response = Mock()
+ mock_response.json.return_value = {"success": False}
+ mock_response.raise_for_status.return_value = None
+ mock_post.return_value = mock_response
- assert str(server) == "DiscordServer(server_id=server123, server_name=Test Server)"
+ result = discord.broadcast_message("general", "Announcement!")
+
+ assert result is False
-@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token_123")
-def test_request_adds_headers(mock_session_request):
- server = discord.DiscordServer.__new__(discord.DiscordServer)
+@patch("requests.post")
+def test_broadcast_message_exception(mock_post, mock_api_url):
+ """Test channel message broadcast with exception"""
+ mock_post.side_effect = requests.Timeout("Request timeout")
- server.request("GET", "https://example.com", headers={"Custom": "header"})
+ result = discord.broadcast_message("general", "Announcement!")
- expected_headers = {
- "Custom": "header",
- "Authorization": "Bot test_token_123",
- "Content-Type": "application/json",
- }
- mock_session_request.assert_called_once_with(
- "GET", "https://example.com", headers=expected_headers
- )
+ assert result is False
-def test_channels_url():
- server = discord.DiscordServer.__new__(discord.DiscordServer)
- server.server_id = "server123"
-
- assert (
- server.channels_url == "https://discord.com/api/v10/guilds/server123/channels"
- )
-
-
-@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
@patch("requests.get")
-def test_get_bot_servers_success(mock_get):
+def test_is_collector_healthy_true(mock_get, mock_api_url):
+ """Test health check when collector is healthy"""
mock_response = Mock()
- mock_response.json.return_value = [
- {"id": "server1", "name": "Server 1"},
- {"id": "server2", "name": "Server 2"},
- ]
+ mock_response.json.return_value = {"status": "healthy"}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
- servers = discord.get_bot_servers()
+ result = discord.is_collector_healthy()
- assert len(servers) == 2
- assert servers[0] == {"id": "server1", "name": "Server 1"}
- mock_get.assert_called_once_with(
- "https://discord.com/api/v10/users/@me/guilds",
- headers={"Authorization": "Bot test_token"},
- )
+ assert result is True
+ mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
-@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
-def test_get_bot_servers_no_token():
- assert discord.get_bot_servers() == []
-
-
-@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
@patch("requests.get")
-def test_get_bot_servers_exception(mock_get):
- mock_get.side_effect = requests.RequestException("API Error")
+def test_is_collector_healthy_false_status(mock_get, mock_api_url):
+ """Test health check when collector returns unhealthy status"""
+ mock_response = Mock()
+ mock_response.json.return_value = {"status": "unhealthy"}
+ mock_response.raise_for_status.return_value = None
+ mock_get.return_value = mock_response
- servers = discord.get_bot_servers()
+ result = discord.is_collector_healthy()
- assert servers == []
+ assert result is False
-@patch("memory.common.discord.get_bot_servers")
-@patch("memory.common.discord.DiscordServer")
-def test_load_servers(mock_discord_server_class, mock_get_servers):
- mock_get_servers.return_value = [
- {"id": "server1", "name": "Server 1"},
- {"id": "server2", "name": "Server 2"},
- ]
+@patch("requests.get")
+def test_is_collector_healthy_exception(mock_get, mock_api_url):
+ """Test health check when request fails"""
+ mock_get.side_effect = requests.ConnectionError("Connection refused")
- discord.load_servers()
+ result = discord.is_collector_healthy()
- assert mock_discord_server_class.call_count == 2
- mock_discord_server_class.assert_any_call("server1", "Server 1")
- mock_discord_server_class.assert_any_call("server2", "Server 2")
+ assert result is False
-@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
-def test_broadcast_message():
- mock_server1 = Mock()
- mock_server2 = Mock()
- discord.servers = {"1": mock_server1, "2": mock_server2}
+@patch("requests.post")
+def test_refresh_discord_metadata_success(mock_post, mock_api_url):
+ """Test successful metadata refresh"""
+ mock_response = Mock()
+ mock_response.json.return_value = {
+ "servers": 5,
+ "channels": 20,
+ "users": 100,
+ }
+ mock_response.raise_for_status.return_value = None
+ mock_post.return_value = mock_response
- discord.broadcast_message("test-channel", "Hello")
+ result = discord.refresh_discord_metadata()
- mock_server1.send_message.assert_called_once_with(
- mock_server1.channel_id.return_value, "Hello"
- )
- mock_server2.send_message.assert_called_once_with(
- mock_server2.channel_id.return_value, "Hello"
+ assert result == {"servers": 5, "channels": 20, "users": 100}
+ mock_post.assert_called_once_with(
+ "http://localhost:8000/refresh_metadata", timeout=30
)
-@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
-def test_broadcast_message_disabled():
- mock_server = Mock()
- discord.servers = {"1": mock_server}
+@patch("requests.post")
+def test_refresh_discord_metadata_failure(mock_post, mock_api_url):
+ """Test metadata refresh failure"""
+ mock_post.side_effect = requests.RequestException("Failed to connect")
- discord.broadcast_message("test-channel", "Hello")
+ result = discord.refresh_discord_metadata()
- mock_server.send_message.assert_not_called()
+ assert result is None
+
+
+@patch("requests.post")
+def test_refresh_discord_metadata_http_error(mock_post, mock_api_url):
+ """Test metadata refresh with HTTP error"""
+ mock_response = Mock()
+ mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error")
+ mock_post.return_value = mock_response
+
+ result = discord.refresh_discord_metadata()
+
+ assert result is None
@patch("memory.common.discord.broadcast_message")
+@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "errors")
def test_send_error_message(mock_broadcast):
- discord.send_error_message("Error occurred")
- mock_broadcast.assert_called_once_with(discord.ERROR_CHANNEL, "Error occurred")
+ """Test sending error message to error channel"""
+ mock_broadcast.return_value = True
+
+ result = discord.send_error_message("Something broke")
+
+ assert result is True
+ mock_broadcast.assert_called_once_with("errors", "Something broke")
@patch("memory.common.discord.broadcast_message")
+@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "activity")
def test_send_activity_message(mock_broadcast):
- discord.send_activity_message("Activity update")
- mock_broadcast.assert_called_once_with(discord.ACTIVITY_CHANNEL, "Activity update")
+ """Test sending activity message to activity channel"""
+ mock_broadcast.return_value = True
+
+ result = discord.send_activity_message("User logged in")
+
+ assert result is True
+ mock_broadcast.assert_called_once_with("activity", "User logged in")
@patch("memory.common.discord.broadcast_message")
+@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "discoveries")
def test_send_discovery_message(mock_broadcast):
- discord.send_discovery_message("Discovery made")
- mock_broadcast.assert_called_once_with(discord.DISCOVERY_CHANNEL, "Discovery made")
+ """Test sending discovery message to discovery channel"""
+ mock_broadcast.return_value = True
+
+ result = discord.send_discovery_message("Found interesting pattern")
+
+ assert result is True
+ mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
@patch("memory.common.discord.broadcast_message")
+@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "chat")
def test_send_chat_message(mock_broadcast):
- discord.send_chat_message("Chat message")
- mock_broadcast.assert_called_once_with(discord.CHAT_CHANNEL, "Chat message")
+ """Test sending chat message to chat channel"""
+ mock_broadcast.return_value = True
+
+ result = discord.send_chat_message("Hello from bot")
+
+ assert result is True
+ mock_broadcast.assert_called_once_with("chat", "Hello from bot")
-@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_basic(mock_send_error):
+ """Test basic task failure notification"""
discord.notify_task_failure("test_task", "Something went wrong")
mock_send_error.assert_called_once()
@@ -323,69 +255,181 @@ def test_notify_task_failure_basic(mock_send_error):
assert "**Error:** Something went wrong" in message
-@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_with_args(mock_send_error):
+ """Test task failure notification with arguments"""
discord.notify_task_failure(
"test_task",
- "Error message",
- task_args=("arg1", "arg2"),
- task_kwargs={"key": "value"},
+ "Error occurred",
+ task_args=("arg1", 42),
+ task_kwargs={"key": "value", "number": 123},
)
message = mock_send_error.call_args[0][0]
- assert "**Args:** `('arg1', 'arg2')`" in message
- assert "**Kwargs:** `{'key': 'value'}`" in message
+ assert "**Args:** `('arg1', 42)" in message
+ assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
-@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_with_traceback(mock_send_error):
- traceback = "Traceback (most recent call last):\n File ...\nError: Something"
+ """Test task failure notification with traceback"""
+ traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
- discord.notify_task_failure("test_task", "Error message", traceback_str=traceback)
+ discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
message = mock_send_error.call_args[0][0]
+
assert "**Traceback:**" in message
- assert traceback in message
+ assert "Exception: test" in message
-@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_error(mock_send_error):
- long_error = "x" * 600 # Longer than 500 char limit
+ """Test that long error messages are truncated"""
+ long_error = "x" * 600
discord.notify_task_failure("test_task", long_error)
message = mock_send_error.call_args[0][0]
- assert long_error[:500] in message
+
+ # Error should be truncated to 500 chars - check that the full 600 char string is not there
+ assert "**Error:** " + long_error[:500] in message
+ # The full 600-char error should not be present
+ error_section = message.split("**Error:** ")[1].split("\n")[0]
+ assert len(error_section) == 500
-@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_traceback(mock_send_error):
- long_traceback = "x" * 1000 # Longer than 800 char limit
+ """Test that long tracebacks are truncated"""
+ long_traceback = "x" * 1000
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
message = mock_send_error.call_args[0][0]
+
+ # Traceback should show last 800 chars
assert long_traceback[-800:] in message
+ # The full 1000-char traceback should not be present
+ traceback_section = message.split("**Traceback:**\n```\n")[1].split("\n```")[0]
+ assert len(traceback_section) == 800
-@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
@patch("memory.common.discord.send_error_message")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+def test_notify_task_failure_truncates_long_args(mock_send_error):
+ """Test that long task arguments are truncated"""
+ long_args = ("x" * 300,)
+
+ discord.notify_task_failure("test_task", "Error", task_args=long_args)
+
+ message = mock_send_error.call_args[0][0]
+
+ # Args should be truncated to 200 chars
+ assert (
+ len(message.split("**Args:**")[1].split("\n")[0]) <= 210
+ ) # Some buffer for formatting
+
+
+@patch("memory.common.discord.send_error_message")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
+ """Test that long task kwargs are truncated"""
+ long_kwargs = {"key": "x" * 300}
+
+ discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
+
+ message = mock_send_error.call_args[0][0]
+
+ # Kwargs should be truncated to 200 chars
+ assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
+
+
+@patch("memory.common.discord.send_error_message")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_notify_task_failure_disabled(mock_send_error):
- discord.notify_task_failure("test_task", "Error message")
+ """Test that notifications are not sent when disabled"""
+ discord.notify_task_failure("test_task", "Error occurred")
+
mock_send_error.assert_not_called()
-@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
-def test_notify_task_failure_send_fails(mock_send_error):
- mock_send_error.side_effect = Exception("Discord API error")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+def test_notify_task_failure_send_error_exception(mock_send_error):
+ """Test that exceptions in send_error_message don't propagate"""
+ mock_send_error.side_effect = Exception("Failed to send")
- # Should not raise, just log the error
- discord.notify_task_failure("test_task", "Error message")
+ # Should not raise
+ discord.notify_task_failure("test_task", "Error occurred")
mock_send_error.assert_called_once()
+
+
+@pytest.mark.parametrize(
+ "function,channel_setting,message",
+ [
+ (discord.send_error_message, "DISCORD_ERROR_CHANNEL", "Error!"),
+ (discord.send_activity_message, "DISCORD_ACTIVITY_CHANNEL", "Activity!"),
+ (discord.send_discovery_message, "DISCORD_DISCOVERY_CHANNEL", "Discovery!"),
+ (discord.send_chat_message, "DISCORD_CHAT_CHANNEL", "Chat!"),
+ ],
+)
+@patch("memory.common.discord.broadcast_message")
+def test_convenience_functions_use_correct_channels(
+ mock_broadcast, function, channel_setting, message
+):
+ """Test that convenience functions use the correct channel settings"""
+ with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
+ function(message)
+ mock_broadcast.assert_called_once_with("test-channel", message)
+
+
+@patch("requests.post")
+def test_send_dm_with_special_characters(mock_post, mock_api_url):
+ """Test sending DM with special characters"""
+ mock_response = Mock()
+ mock_response.json.return_value = {"success": True}
+ mock_response.raise_for_status.return_value = None
+ mock_post.return_value = mock_response
+
+ message_with_special_chars = "Hello! 🎉 <@123> #general"
+ result = discord.send_dm("user123", message_with_special_chars)
+
+ assert result is True
+ call_args = mock_post.call_args
+ assert call_args[1]["json"]["message"] == message_with_special_chars
+
+
+@patch("requests.post")
+def test_broadcast_message_with_long_message(mock_post, mock_api_url):
+ """Test broadcasting a long message"""
+ mock_response = Mock()
+ mock_response.json.return_value = {"success": True}
+ mock_response.raise_for_status.return_value = None
+ mock_post.return_value = mock_response
+
+ long_message = "A" * 2000
+ result = discord.broadcast_message("general", long_message)
+
+ assert result is True
+ call_args = mock_post.call_args
+ assert call_args[1]["json"]["message"] == long_message
+
+
+@patch("requests.get")
+def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
+ """Test health check when response doesn't have status key"""
+ mock_response = Mock()
+ mock_response.json.return_value = {}
+ mock_response.raise_for_status.return_value = None
+ mock_get.return_value = mock_response
+
+ result = discord.is_collector_healthy()
+
+ assert result is False
diff --git a/tests/memory/common/test_discord_integration.py b/tests/memory/common/test_discord_integration.py
new file mode 100644
index 0000000..3d46a0d
--- /dev/null
+++ b/tests/memory/common/test_discord_integration.py
@@ -0,0 +1,435 @@
+import pytest
+from unittest.mock import Mock, patch
+import requests
+
+from memory.common import discord
+
+
+@pytest.fixture
+def mock_api_url():
+ """Mock the API URL to avoid using actual settings"""
+ with patch(
+ "memory.common.discord.get_api_url", return_value="http://localhost:8000"
+ ):
+ yield
+
+
+@patch("memory.common.settings.DISCORD_COLLECTOR_SERVER_URL", "testhost")
+@patch("memory.common.settings.DISCORD_COLLECTOR_PORT", 9999)
+def test_get_api_url():
+ """Test API URL construction"""
+ assert discord.get_api_url() == "http://testhost:9999"
+
+
+@patch("requests.post")
+def test_send_dm_success(mock_post, mock_api_url):
+ """Test successful DM sending"""
+ mock_response = Mock()
+ mock_response.json.return_value = {"success": True}
+ mock_response.raise_for_status.return_value = None
+ mock_post.return_value = mock_response
+
+ result = discord.send_dm("user123", "Hello!")
+
+ assert result is True
+ mock_post.assert_called_once_with(
+ "http://localhost:8000/send_dm",
+ json={"user_identifier": "user123", "message": "Hello!"},
+ timeout=10,
+ )
+
+
+@patch("requests.post")
+def test_send_dm_api_failure(mock_post, mock_api_url):
+ """Test DM sending when API returns failure"""
+ mock_response = Mock()
+ mock_response.json.return_value = {"success": False}
+ mock_response.raise_for_status.return_value = None
+ mock_post.return_value = mock_response
+
+ result = discord.send_dm("user123", "Hello!")
+
+ assert result is False
+
+
+@patch("requests.post")
+def test_send_dm_request_exception(mock_post, mock_api_url):
+ """Test DM sending when request raises exception"""
+ mock_post.side_effect = requests.RequestException("Network error")
+
+ result = discord.send_dm("user123", "Hello!")
+
+ assert result is False
+
+
+@patch("requests.post")
+def test_send_dm_http_error(mock_post, mock_api_url):
+ """Test DM sending when HTTP error occurs"""
+ mock_response = Mock()
+ mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
+ mock_post.return_value = mock_response
+
+ result = discord.send_dm("user123", "Hello!")
+
+ assert result is False
+
+
+@patch("requests.post")
+def test_broadcast_message_success(mock_post, mock_api_url):
+ """Test successful channel message broadcast"""
+ mock_response = Mock()
+ mock_response.json.return_value = {"success": True}
+ mock_response.raise_for_status.return_value = None
+ mock_post.return_value = mock_response
+
+ result = discord.broadcast_message("general", "Announcement!")
+
+ assert result is True
+ mock_post.assert_called_once_with(
+ "http://localhost:8000/send_channel",
+ json={"channel_name": "general", "message": "Announcement!"},
+ timeout=10,
+ )
+
+
+@patch("requests.post")
+def test_broadcast_message_failure(mock_post, mock_api_url):
+ """Test channel message broadcast failure"""
+ mock_response = Mock()
+ mock_response.json.return_value = {"success": False}
+ mock_response.raise_for_status.return_value = None
+ mock_post.return_value = mock_response
+
+ result = discord.broadcast_message("general", "Announcement!")
+
+ assert result is False
+
+
+@patch("requests.post")
+def test_broadcast_message_exception(mock_post, mock_api_url):
+ """Test channel message broadcast with exception"""
+ mock_post.side_effect = requests.Timeout("Request timeout")
+
+ result = discord.broadcast_message("general", "Announcement!")
+
+ assert result is False
+
+
+@patch("requests.get")
+def test_is_collector_healthy_true(mock_get, mock_api_url):
+ """Test health check when collector is healthy"""
+ mock_response = Mock()
+ mock_response.json.return_value = {"status": "healthy"}
+ mock_response.raise_for_status.return_value = None
+ mock_get.return_value = mock_response
+
+ result = discord.is_collector_healthy()
+
+ assert result is True
+ mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
+
+
+@patch("requests.get")
+def test_is_collector_healthy_false_status(mock_get, mock_api_url):
+ """Test health check when collector returns unhealthy status"""
+ mock_response = Mock()
+ mock_response.json.return_value = {"status": "unhealthy"}
+ mock_response.raise_for_status.return_value = None
+ mock_get.return_value = mock_response
+
+ result = discord.is_collector_healthy()
+
+ assert result is False
+
+
+@patch("requests.get")
+def test_is_collector_healthy_exception(mock_get, mock_api_url):
+ """Test health check when request fails"""
+ mock_get.side_effect = requests.ConnectionError("Connection refused")
+
+ result = discord.is_collector_healthy()
+
+ assert result is False
+
+
+@patch("requests.post")
+def test_refresh_discord_metadata_success(mock_post, mock_api_url):
+ """Test successful metadata refresh"""
+ mock_response = Mock()
+ mock_response.json.return_value = {
+ "servers": 5,
+ "channels": 20,
+ "users": 100,
+ }
+ mock_response.raise_for_status.return_value = None
+ mock_post.return_value = mock_response
+
+ result = discord.refresh_discord_metadata()
+
+ assert result == {"servers": 5, "channels": 20, "users": 100}
+ mock_post.assert_called_once_with(
+ "http://localhost:8000/refresh_metadata", timeout=30
+ )
+
+
+@patch("requests.post")
+def test_refresh_discord_metadata_failure(mock_post, mock_api_url):
+ """Test metadata refresh failure"""
+ mock_post.side_effect = requests.RequestException("Failed to connect")
+
+ result = discord.refresh_discord_metadata()
+
+ assert result is None
+
+
+@patch("requests.post")
+def test_refresh_discord_metadata_http_error(mock_post, mock_api_url):
+ """Test metadata refresh with HTTP error"""
+ mock_response = Mock()
+ mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error")
+ mock_post.return_value = mock_response
+
+ result = discord.refresh_discord_metadata()
+
+ assert result is None
+
+
+@patch("memory.common.discord.broadcast_message")
+@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "errors")
+def test_send_error_message(mock_broadcast):
+ """Test sending error message to error channel"""
+ mock_broadcast.return_value = True
+
+ result = discord.send_error_message("Something broke")
+
+ assert result is True
+ mock_broadcast.assert_called_once_with("errors", "Something broke")
+
+
+@patch("memory.common.discord.broadcast_message")
+@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "activity")
+def test_send_activity_message(mock_broadcast):
+ """Test sending activity message to activity channel"""
+ mock_broadcast.return_value = True
+
+ result = discord.send_activity_message("User logged in")
+
+ assert result is True
+ mock_broadcast.assert_called_once_with("activity", "User logged in")
+
+
+@patch("memory.common.discord.broadcast_message")
+@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "discoveries")
+def test_send_discovery_message(mock_broadcast):
+ """Test sending discovery message to discovery channel"""
+ mock_broadcast.return_value = True
+
+ result = discord.send_discovery_message("Found interesting pattern")
+
+ assert result is True
+ mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
+
+
+@patch("memory.common.discord.broadcast_message")
+@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "chat")
+def test_send_chat_message(mock_broadcast):
+ """Test sending chat message to chat channel"""
+ mock_broadcast.return_value = True
+
+ result = discord.send_chat_message("Hello from bot")
+
+ assert result is True
+ mock_broadcast.assert_called_once_with("chat", "Hello from bot")
+
+
+@patch("memory.common.discord.send_error_message")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+def test_notify_task_failure_basic(mock_send_error):
+ """Test basic task failure notification"""
+ discord.notify_task_failure("test_task", "Something went wrong")
+
+ mock_send_error.assert_called_once()
+ message = mock_send_error.call_args[0][0]
+
+ assert "🚨 **Task Failed: test_task**" in message
+ assert "**Error:** Something went wrong" in message
+
+
+@patch("memory.common.discord.send_error_message")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+def test_notify_task_failure_with_args(mock_send_error):
+ """Test task failure notification with arguments"""
+ discord.notify_task_failure(
+ "test_task",
+ "Error occurred",
+ task_args=("arg1", 42),
+ task_kwargs={"key": "value", "number": 123},
+ )
+
+ message = mock_send_error.call_args[0][0]
+
+ assert "**Args:** `('arg1', 42)" in message
+ assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
+
+
+@patch("memory.common.discord.send_error_message")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+def test_notify_task_failure_with_traceback(mock_send_error):
+ """Test task failure notification with traceback"""
+ traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
+
+ discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
+
+ message = mock_send_error.call_args[0][0]
+
+ assert "**Traceback:**" in message
+ assert "Exception: test" in message
+
+
+@patch("memory.common.discord.send_error_message")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+def test_notify_task_failure_truncates_long_error(mock_send_error):
+ """Test that long error messages are truncated"""
+ long_error = "x" * 600
+
+ discord.notify_task_failure("test_task", long_error)
+
+ message = mock_send_error.call_args[0][0]
+
+ # Error should be truncated to 500 chars - check that the full 600 char string is not there
+ assert "**Error:** " + long_error[:500] in message
+ # The full 600-char error should not be present
+ error_section = message.split("**Error:** ")[1].split("\n")[0]
+ assert len(error_section) == 500
+
+
+@patch("memory.common.discord.send_error_message")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+def test_notify_task_failure_truncates_long_traceback(mock_send_error):
+ """Test that long tracebacks are truncated"""
+ long_traceback = "x" * 1000
+
+ discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
+
+ message = mock_send_error.call_args[0][0]
+
+ # Traceback should show last 800 chars
+ assert long_traceback[-800:] in message
+ # The full 1000-char traceback should not be present
+ traceback_section = message.split("**Traceback:**\n```\n")[1].split("\n```")[0]
+ assert len(traceback_section) == 800
+
+
+@patch("memory.common.discord.send_error_message")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+def test_notify_task_failure_truncates_long_args(mock_send_error):
+ """Test that long task arguments are truncated"""
+ long_args = ("x" * 300,)
+
+ discord.notify_task_failure("test_task", "Error", task_args=long_args)
+
+ message = mock_send_error.call_args[0][0]
+
+ # Args should be truncated to 200 chars
+ assert (
+ len(message.split("**Args:**")[1].split("\n")[0]) <= 210
+ ) # Some buffer for formatting
+
+
+@patch("memory.common.discord.send_error_message")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
+ """Test that long task kwargs are truncated"""
+ long_kwargs = {"key": "x" * 300}
+
+ discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
+
+ message = mock_send_error.call_args[0][0]
+
+ # Kwargs should be truncated to 200 chars
+ assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
+
+
+@patch("memory.common.discord.send_error_message")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
+def test_notify_task_failure_disabled(mock_send_error):
+ """Test that notifications are not sent when disabled"""
+ discord.notify_task_failure("test_task", "Error occurred")
+
+ mock_send_error.assert_not_called()
+
+
+@patch("memory.common.discord.send_error_message")
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+def test_notify_task_failure_send_error_exception(mock_send_error):
+ """Test that exceptions in send_error_message don't propagate"""
+ mock_send_error.side_effect = Exception("Failed to send")
+
+ # Should not raise
+ discord.notify_task_failure("test_task", "Error occurred")
+
+ mock_send_error.assert_called_once()
+
+
+@pytest.mark.parametrize(
+ "function,channel_setting,message",
+ [
+ (discord.send_error_message, "DISCORD_ERROR_CHANNEL", "Error!"),
+ (discord.send_activity_message, "DISCORD_ACTIVITY_CHANNEL", "Activity!"),
+ (discord.send_discovery_message, "DISCORD_DISCOVERY_CHANNEL", "Discovery!"),
+ (discord.send_chat_message, "DISCORD_CHAT_CHANNEL", "Chat!"),
+ ],
+)
+@patch("memory.common.discord.broadcast_message")
+def test_convenience_functions_use_correct_channels(
+ mock_broadcast, function, channel_setting, message
+):
+ """Test that convenience functions use the correct channel settings"""
+ with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
+ function(message)
+ mock_broadcast.assert_called_once_with("test-channel", message)
+
+
+@patch("requests.post")
+def test_send_dm_with_special_characters(mock_post, mock_api_url):
+ """Test sending DM with special characters"""
+ mock_response = Mock()
+ mock_response.json.return_value = {"success": True}
+ mock_response.raise_for_status.return_value = None
+ mock_post.return_value = mock_response
+
+ message_with_special_chars = "Hello! 🎉 <@123> #general"
+ result = discord.send_dm("user123", message_with_special_chars)
+
+ assert result is True
+ call_args = mock_post.call_args
+ assert call_args[1]["json"]["message"] == message_with_special_chars
+
+
+@patch("requests.post")
+def test_broadcast_message_with_long_message(mock_post, mock_api_url):
+ """Test broadcasting a long message"""
+ mock_response = Mock()
+ mock_response.json.return_value = {"success": True}
+ mock_response.raise_for_status.return_value = None
+ mock_post.return_value = mock_response
+
+ long_message = "A" * 2000
+ result = discord.broadcast_message("general", long_message)
+
+ assert result is True
+ call_args = mock_post.call_args
+ assert call_args[1]["json"]["message"] == long_message
+
+
+@patch("requests.get")
+def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
+ """Test health check when response doesn't have status key"""
+ mock_response = Mock()
+ mock_response.json.return_value = {}
+ mock_response.raise_for_status.return_value = None
+ mock_get.return_value = mock_response
+
+ result = discord.is_collector_healthy()
+
+ assert result is False
diff --git a/tests/memory/discord/test_collector.py b/tests/memory/discord/test_collector.py
new file mode 100644
index 0000000..b3624e8
--- /dev/null
+++ b/tests/memory/discord/test_collector.py
@@ -0,0 +1,840 @@
+import pytest
+from datetime import datetime, timezone
+from unittest.mock import Mock, patch, AsyncMock, MagicMock
+
+import discord
+
+from memory.discord.collector import (
+ create_or_update_server,
+ determine_channel_metadata,
+ create_or_update_channel,
+ create_or_update_user,
+ determine_message_metadata,
+ should_track_message,
+ should_collect_bot_message,
+ sync_guild_metadata,
+ MessageCollector,
+)
+from memory.common.db.models.sources import (
+ DiscordServer,
+ DiscordChannel,
+ DiscordUser,
+)
+
+
+# Fixtures for Discord objects
+@pytest.fixture
+def mock_guild():
+ """Mock Discord Guild object"""
+ guild = Mock(spec=discord.Guild)
+ guild.id = 123456789
+ guild.name = "Test Server"
+ guild.description = "A test server"
+ guild.member_count = 42
+ return guild
+
+
+@pytest.fixture
+def mock_text_channel():
+ """Mock Discord TextChannel object"""
+ channel = Mock(spec=discord.TextChannel)
+ channel.id = 987654321
+ channel.name = "general"
+ guild = Mock()
+ guild.id = 123456789
+ channel.guild = guild
+ return channel
+
+
+@pytest.fixture
+def mock_dm_channel():
+ """Mock Discord DMChannel object"""
+ channel = Mock(spec=discord.DMChannel)
+ channel.id = 111222333
+ recipient = Mock()
+ recipient.name = "TestUser"
+ channel.recipient = recipient
+ return channel
+
+
+@pytest.fixture
+def mock_user():
+ """Mock Discord User object"""
+ user = Mock(spec=discord.User)
+ user.id = 444555666
+ user.name = "testuser"
+ user.display_name = "Test User"
+ user.bot = False
+ return user
+
+
+@pytest.fixture
+def mock_message(mock_text_channel, mock_user):
+ """Mock Discord Message object"""
+ message = Mock(spec=discord.Message)
+ message.id = 777888999
+ message.channel = mock_text_channel
+ message.author = mock_user
+ message.guild = mock_text_channel.guild
+ message.content = "Test message"
+ message.created_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
+ message.reference = None
+ return message
+
+
+# Tests for create_or_update_server
+def test_create_or_update_server_creates_new(db_session, mock_guild):
+ """Test creating a new server record"""
+ result = create_or_update_server(db_session, mock_guild)
+
+ assert result is not None
+ assert result.id == mock_guild.id
+ assert result.name == mock_guild.name
+ assert result.description == mock_guild.description
+ assert result.member_count == mock_guild.member_count
+
+
+def test_create_or_update_server_updates_existing(db_session, mock_guild):
+ """Test updating an existing server record"""
+ # Create initial server
+ server = DiscordServer(
+ id=mock_guild.id,
+ name="Old Name",
+ description="Old Description",
+ member_count=10,
+ )
+ db_session.add(server)
+ db_session.commit()
+
+ # Update with new data
+ mock_guild.name = "New Name"
+ mock_guild.description = "New Description"
+ mock_guild.member_count = 50
+
+ result = create_or_update_server(db_session, mock_guild)
+
+ assert result.name == "New Name"
+ assert result.description == "New Description"
+ assert result.member_count == 50
+ assert result.last_sync_at is not None
+
+
+def test_create_or_update_server_none_guild(db_session):
+ """Test with None guild"""
+ result = create_or_update_server(db_session, None)
+ assert result is None
+
+
+# Tests for determine_channel_metadata
+def test_determine_channel_metadata_dm():
+ """Test metadata for DM channel"""
+ channel = Mock(spec=discord.DMChannel)
+ channel.recipient = Mock()
+ channel.recipient.name = "TestUser"
+
+ channel_type, server_id, name = determine_channel_metadata(channel)
+
+ assert channel_type == "dm"
+ assert server_id is None
+ assert "DM with TestUser" in name
+
+
+def test_determine_channel_metadata_dm_no_recipient():
+ """Test metadata for DM channel without recipient"""
+ channel = Mock(spec=discord.DMChannel)
+ channel.recipient = None
+
+ channel_type, server_id, name = determine_channel_metadata(channel)
+
+ assert channel_type == "dm"
+ assert name == "Unknown DM"
+
+
+def test_determine_channel_metadata_group_dm():
+ """Test metadata for group DM channel"""
+ channel = Mock(spec=discord.GroupChannel)
+ channel.name = "Group Chat"
+
+ channel_type, server_id, name = determine_channel_metadata(channel)
+
+ assert channel_type == "group_dm"
+ assert server_id is None
+ assert name == "Group Chat"
+
+
+def test_determine_channel_metadata_group_dm_no_name():
+ """Test metadata for group DM without name"""
+ channel = Mock(spec=discord.GroupChannel)
+ channel.name = None
+
+ channel_type, server_id, name = determine_channel_metadata(channel)
+
+ assert name == "Group DM"
+
+
+def test_determine_channel_metadata_text_channel():
+ """Test metadata for text channel"""
+ channel = Mock(spec=discord.TextChannel)
+ channel.name = "general"
+ channel.guild = Mock()
+ channel.guild.id = 123
+
+ channel_type, server_id, name = determine_channel_metadata(channel)
+
+ assert channel_type == "text"
+ assert server_id == 123
+ assert name == "general"
+
+
+def test_determine_channel_metadata_voice_channel():
+ """Test metadata for voice channel"""
+ channel = Mock(spec=discord.VoiceChannel)
+ channel.name = "voice-chat"
+ channel.guild = Mock()
+ channel.guild.id = 456
+
+ channel_type, server_id, name = determine_channel_metadata(channel)
+
+ assert channel_type == "voice"
+ assert server_id == 456
+ assert name == "voice-chat"
+
+
+def test_determine_channel_metadata_thread():
+ """Test metadata for thread"""
+ channel = Mock(spec=discord.Thread)
+ channel.name = "thread-1"
+ channel.guild = Mock()
+ channel.guild.id = 789
+
+ channel_type, server_id, name = determine_channel_metadata(channel)
+
+ assert channel_type == "thread"
+ assert server_id == 789
+ assert name == "thread-1"
+
+
+def test_determine_channel_metadata_unknown():
+ """Test metadata for unknown channel type"""
+ channel = Mock()
+ channel.id = 999
+ # Ensure the mock doesn't have a 'name' attribute
+ del channel.name
+
+ channel_type, server_id, name = determine_channel_metadata(channel)
+
+ assert channel_type == "unknown"
+ assert name == "Unknown-999"
+
+
+# Tests for create_or_update_channel
+def test_create_or_update_channel_creates_new(
+ db_session, mock_text_channel, mock_guild
+):
+ """Test creating a new channel record"""
+ # Create the server first to satisfy foreign key constraint
+ create_or_update_server(db_session, mock_guild)
+
+ result = create_or_update_channel(db_session, mock_text_channel)
+
+ assert result is not None
+ assert result.id == mock_text_channel.id
+ assert result.name == mock_text_channel.name
+ assert result.channel_type == "text"
+
+
+def test_create_or_update_channel_updates_existing(db_session, mock_text_channel):
+ """Test updating an existing channel record"""
+ # Create initial channel
+ channel = DiscordChannel(
+ id=mock_text_channel.id,
+ name="old-name",
+ channel_type="text",
+ )
+ db_session.add(channel)
+ db_session.commit()
+
+ # Update with new name
+ mock_text_channel.name = "new-name"
+
+ result = create_or_update_channel(db_session, mock_text_channel)
+
+ assert result.name == "new-name"
+
+
+def test_create_or_update_channel_none_channel(db_session):
+ """Test with None channel"""
+ result = create_or_update_channel(db_session, None)
+ assert result is None
+
+
+# Tests for create_or_update_user
+def test_create_or_update_user_creates_new(db_session, mock_user):
+ """Test creating a new user record"""
+ result = create_or_update_user(db_session, mock_user)
+
+ assert result is not None
+ assert result.id == mock_user.id
+ assert result.username == mock_user.name
+ assert result.display_name == mock_user.display_name
+
+
+def test_create_or_update_user_updates_existing(db_session, mock_user):
+ """Test updating an existing user record"""
+ # Create initial user
+ user = DiscordUser(
+ id=mock_user.id,
+ username="oldname",
+ display_name="Old Display Name",
+ )
+ db_session.add(user)
+ db_session.commit()
+
+ # Update with new data
+ mock_user.name = "newname"
+ mock_user.display_name = "New Display Name"
+
+ result = create_or_update_user(db_session, mock_user)
+
+ assert result.username == "newname"
+ assert result.display_name == "New Display Name"
+
+
+def test_create_or_update_user_none_user(db_session):
+ """Test with None user"""
+ result = create_or_update_user(db_session, None)
+ assert result is None
+
+
+# Tests for determine_message_metadata
+def test_determine_message_metadata_default():
+ """Test metadata for default message"""
+ message = Mock()
+ message.reference = None
+ message.channel = Mock()
+ # Ensure channel doesn't have parent attribute
+ del message.channel.parent
+
+ message_type, reply_to_id, thread_id = determine_message_metadata(message)
+
+ assert message_type == "default"
+ assert reply_to_id is None
+ assert thread_id is None
+
+
+def test_determine_message_metadata_reply():
+ """Test metadata for reply message"""
+ message = Mock()
+ message.reference = Mock()
+ message.reference.message_id = 123456
+ message.channel = Mock()
+
+ message_type, reply_to_id, thread_id = determine_message_metadata(message)
+
+ assert message_type == "reply"
+ assert reply_to_id == 123456
+
+
+def test_determine_message_metadata_thread():
+ """Test metadata for message in thread"""
+ message = Mock()
+ message.reference = None
+ message.channel = Mock()
+ message.channel.id = 999
+ message.channel.parent = Mock() # Has parent means it's a thread
+
+ message_type, reply_to_id, thread_id = determine_message_metadata(message)
+
+ assert thread_id == 999
+
+
+# Tests for should_track_message
+def test_should_track_message_server_disabled(db_session):
+ """Test when server has tracking disabled"""
+ server = DiscordServer(id=1, name="Server", track_messages=False)
+ channel = DiscordChannel(id=2, name="Channel", channel_type="text")
+ user = DiscordUser(id=3, username="User")
+
+ result = should_track_message(server, channel, user)
+
+ assert result is False
+
+
+def test_should_track_message_channel_disabled(db_session):
+ """Test when channel has tracking disabled"""
+ server = DiscordServer(id=1, name="Server", track_messages=True)
+ channel = DiscordChannel(
+ id=2, name="Channel", channel_type="text", track_messages=False
+ )
+ user = DiscordUser(id=3, username="User")
+
+ result = should_track_message(server, channel, user)
+
+ assert result is False
+
+
+def test_should_track_message_dm_allowed(db_session):
+ """Test DM tracking when user allows it"""
+ channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True)
+ user = DiscordUser(id=3, username="User", track_messages=True)
+
+ result = should_track_message(None, channel, user)
+
+ assert result is True
+
+
+def test_should_track_message_dm_not_allowed(db_session):
+ """Test DM tracking when user doesn't allow it"""
+ channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True)
+ user = DiscordUser(id=3, username="User", track_messages=False)
+
+ result = should_track_message(None, channel, user)
+
+ assert result is False
+
+
+def test_should_track_message_default_true(db_session):
+ """Test default tracking behavior"""
+ server = DiscordServer(id=1, name="Server", track_messages=True)
+ channel = DiscordChannel(
+ id=2, name="Channel", channel_type="text", track_messages=True
+ )
+ user = DiscordUser(id=3, username="User")
+
+ result = should_track_message(server, channel, user)
+
+ assert result is True
+
+
+# Tests for should_collect_bot_message
+@patch("memory.common.settings.DISCORD_COLLECT_BOTS", False)
+def test_should_collect_bot_message_bot_not_allowed():
+ """Test bot message collection when disabled"""
+ message = Mock()
+ message.author = Mock()
+ message.author.bot = True
+
+ result = should_collect_bot_message(message)
+
+ assert result is False
+
+
+@patch("memory.common.settings.DISCORD_COLLECT_BOTS", True)
+def test_should_collect_bot_message_bot_allowed():
+ """Test bot message collection when enabled"""
+ message = Mock()
+ message.author = Mock()
+ message.author.bot = True
+
+ result = should_collect_bot_message(message)
+
+ assert result is True
+
+
+def test_should_collect_bot_message_human():
+ """Test human message collection"""
+ message = Mock()
+ message.author = Mock()
+ message.author.bot = False
+
+ result = should_collect_bot_message(message)
+
+ assert result is True
+
+
+# Tests for sync_guild_metadata
+@patch("memory.discord.collector.make_session")
+def test_sync_guild_metadata(mock_make_session, mock_guild):
+ """Test syncing guild metadata"""
+ mock_session = Mock()
+ mock_make_session.return_value.__enter__ = Mock(return_value=mock_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=None)
+
+ # Mock session.query().get() to return None (new server)
+ mock_session.query.return_value.get.return_value = None
+
+ # Mock channels
+ text_channel = Mock(spec=discord.TextChannel)
+ text_channel.id = 1
+ text_channel.name = "general"
+ text_channel.guild = mock_guild
+
+ voice_channel = Mock(spec=discord.VoiceChannel)
+ voice_channel.id = 2
+ voice_channel.name = "voice"
+ voice_channel.guild = mock_guild
+
+ mock_guild.channels = [text_channel, voice_channel]
+
+ sync_guild_metadata(mock_guild)
+
+ # Verify session.commit was called
+ mock_session.commit.assert_called_once()
+
+
+# Tests for MessageCollector class
+def test_message_collector_init():
+ """Test MessageCollector initialization"""
+ collector = MessageCollector()
+
+ assert collector.command_prefix == "!memory_"
+ assert collector.help_command is None
+ assert collector.intents.message_content is True
+ assert collector.intents.guilds is True
+ assert collector.intents.members is True
+ assert collector.intents.dm_messages is True
+
+
+@pytest.mark.asyncio
+async def test_on_ready():
+ """Test on_ready event handler"""
+ collector = MessageCollector()
+ collector.user = Mock()
+ collector.user.name = "TestBot"
+ collector.guilds = [Mock(), Mock()]
+ collector.sync_servers_and_channels = AsyncMock()
+
+ await collector.on_ready()
+
+ collector.sync_servers_and_channels.assert_called_once()
+
+
+@pytest.mark.asyncio
+@patch("memory.discord.collector.make_session")
+@patch("memory.discord.collector.add_discord_message")
+async def test_on_message_success(mock_add_task, mock_make_session, mock_message):
+ """Test successful message handling"""
+ mock_session = Mock()
+ mock_make_session.return_value.__enter__ = Mock(return_value=mock_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=None)
+ mock_session.query.return_value.get.return_value = None # New entities
+
+ collector = MessageCollector()
+ await collector.on_message(mock_message)
+
+ # Verify task was queued
+ mock_add_task.delay.assert_called_once()
+ call_kwargs = mock_add_task.delay.call_args[1]
+ assert call_kwargs["message_id"] == mock_message.id
+ assert call_kwargs["channel_id"] == mock_message.channel.id
+ assert call_kwargs["author_id"] == mock_message.author.id
+ assert call_kwargs["content"] == mock_message.content
+
+
+@pytest.mark.asyncio
+@patch("memory.discord.collector.make_session")
+async def test_on_message_bot_message_filtered(mock_make_session, mock_message):
+ """Test bot message filtering"""
+ mock_message.author.bot = True
+
+ with patch(
+ "memory.discord.collector.should_collect_bot_message", return_value=False
+ ):
+ collector = MessageCollector()
+ await collector.on_message(mock_message)
+
+ # Should not create session or queue task
+ mock_make_session.assert_not_called()
+
+
+@pytest.mark.asyncio
+@patch("memory.discord.collector.make_session")
+async def test_on_message_error_handling(mock_make_session, mock_message):
+ """Test error handling in on_message"""
+ mock_make_session.side_effect = Exception("Database error")
+
+ collector = MessageCollector()
+ # Should not raise
+ await collector.on_message(mock_message)
+
+
+@pytest.mark.asyncio
+@patch("memory.discord.collector.edit_discord_message")
+async def test_on_message_edit(mock_edit_task):
+ """Test message edit handler"""
+ before = Mock()
+ after = Mock()
+ after.id = 123
+ after.content = "Edited content"
+ after.edited_at = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc)
+
+ collector = MessageCollector()
+ await collector.on_message_edit(before, after)
+
+ mock_edit_task.delay.assert_called_once()
+ call_kwargs = mock_edit_task.delay.call_args[1]
+ assert call_kwargs["message_id"] == 123
+ assert call_kwargs["content"] == "Edited content"
+
+
+@pytest.mark.asyncio
+async def test_on_message_edit_error_handling():
+ """Test error handling in on_message_edit"""
+ before = Mock()
+ after = Mock()
+ after.id = 123
+ after.content = "Edited"
+ after.edited_at = None # Will trigger datetime.now
+
+ with patch("memory.discord.collector.edit_discord_message") as mock_edit:
+ mock_edit.delay.side_effect = Exception("Task error")
+
+ collector = MessageCollector()
+ # Should not raise
+ await collector.on_message_edit(before, after)
+
+
+@pytest.mark.asyncio
+async def test_sync_servers_and_channels():
+ """Test syncing servers and channels"""
+ guild1 = Mock()
+ guild2 = Mock()
+
+ collector = MessageCollector()
+ collector.guilds = [guild1, guild2]
+
+ with patch("memory.discord.collector.sync_guild_metadata") as mock_sync:
+ await collector.sync_servers_and_channels()
+
+ assert mock_sync.call_count == 2
+ mock_sync.assert_any_call(guild1)
+ mock_sync.assert_any_call(guild2)
+
+
+@pytest.mark.asyncio
+@patch("memory.discord.collector.make_session")
+async def test_refresh_metadata(mock_make_session):
+ """Test metadata refresh"""
+ mock_session = Mock()
+ mock_make_session.return_value.__enter__ = Mock(return_value=mock_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=None)
+ mock_session.query.return_value.get.return_value = None
+
+ guild = Mock()
+ guild.id = 123
+ guild.name = "Test"
+ guild.channels = []
+ guild.members = []
+
+ collector = MessageCollector()
+ collector.guilds = [guild]
+ collector.intents = Mock()
+ collector.intents.members = False
+
+ result = await collector.refresh_metadata()
+
+ assert result["servers_updated"] == 1
+ assert result["channels_updated"] == 0
+ assert result["users_updated"] == 0
+
+
+@pytest.mark.asyncio
+async def test_get_user_by_id():
+ """Test getting user by ID"""
+ user = Mock()
+ user.id = 123
+
+ collector = MessageCollector()
+ collector.get_user = Mock(return_value=user)
+
+ result = await collector.get_user(123)
+
+ assert result == user
+
+
+@pytest.mark.asyncio
+async def test_get_user_by_username():
+ """Test getting user by username"""
+ member = Mock()
+ member.name = "testuser"
+ member.display_name = "Test User"
+ member.discriminator = "1234"
+
+ guild = Mock()
+ guild.members = [member]
+
+ collector = MessageCollector()
+ collector.guilds = [guild]
+
+ result = await collector.get_user("testuser")
+
+ assert result == member
+
+
+@pytest.mark.asyncio
+async def test_get_user_not_found():
+ """Test getting non-existent user"""
+ collector = MessageCollector()
+ collector.guilds = []
+
+ with patch.object(collector, "get_user", return_value=None):
+ with patch.object(
+ collector, "fetch_user", side_effect=discord.NotFound(Mock(), Mock())
+ ):
+ result = await collector.get_user(999)
+ assert result is None
+
+
+@pytest.mark.asyncio
+async def test_get_channel_by_name():
+ """Test getting channel by name"""
+ channel = Mock(spec=discord.TextChannel)
+ channel.name = "general"
+
+ guild = Mock()
+ guild.channels = [channel]
+
+ collector = MessageCollector()
+ collector.guilds = [guild]
+
+ result = await collector.get_channel_by_name("general")
+
+ assert result == channel
+
+
+@pytest.mark.asyncio
+async def test_get_channel_by_name_not_found():
+ """Test getting non-existent channel"""
+ guild = Mock()
+ guild.channels = []
+
+ collector = MessageCollector()
+ collector.guilds = [guild]
+
+ result = await collector.get_channel_by_name("nonexistent")
+
+ assert result is None
+
+
+@pytest.mark.asyncio
+async def test_create_channel():
+ """Test creating a channel"""
+ guild = Mock()
+ guild.name = "Test Server"
+ new_channel = Mock()
+ guild.create_text_channel = AsyncMock(return_value=new_channel)
+
+ collector = MessageCollector()
+ collector.get_guild = Mock(return_value=guild)
+
+ result = await collector.create_channel("new-channel", guild_id=123)
+
+ assert result == new_channel
+ guild.create_text_channel.assert_called_once_with("new-channel")
+
+
+@pytest.mark.asyncio
+async def test_create_channel_no_guild():
+ """Test creating channel when no guild available"""
+ collector = MessageCollector()
+ collector.get_guild = Mock(return_value=None)
+ collector.guilds = []
+
+ result = await collector.create_channel("new-channel")
+
+ assert result is None
+
+
+@pytest.mark.asyncio
+async def test_send_dm_success():
+ """Test sending DM successfully"""
+ user = Mock()
+ user.send = AsyncMock()
+
+ collector = MessageCollector()
+ collector.get_user = AsyncMock(return_value=user)
+
+ result = await collector.send_dm(123, "Hello!")
+
+ assert result is True
+ user.send.assert_called_once_with("Hello!")
+
+
+@pytest.mark.asyncio
+async def test_send_dm_user_not_found():
+ """Test sending DM when user not found"""
+ collector = MessageCollector()
+ collector.get_user = AsyncMock(return_value=None)
+
+ result = await collector.send_dm(123, "Hello!")
+
+ assert result is False
+
+
+@pytest.mark.asyncio
+async def test_send_dm_exception():
+ """Test sending DM with exception"""
+ user = Mock()
+ user.send = AsyncMock(side_effect=Exception("Send failed"))
+
+ collector = MessageCollector()
+ collector.get_user = AsyncMock(return_value=user)
+
+ result = await collector.send_dm(123, "Hello!")
+
+ assert result is False
+
+
+@pytest.mark.asyncio
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+async def test_send_to_channel_success():
+ """Test sending to channel successfully"""
+ channel = Mock()
+ channel.send = AsyncMock()
+
+ collector = MessageCollector()
+ collector.get_channel_by_name = AsyncMock(return_value=channel)
+
+ result = await collector.send_to_channel("general", "Announcement!")
+
+ assert result is True
+ channel.send.assert_called_once_with("Announcement!")
+
+
+@pytest.mark.asyncio
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
+async def test_send_to_channel_notifications_disabled():
+ """Test sending to channel when notifications disabled"""
+ collector = MessageCollector()
+
+ result = await collector.send_to_channel("general", "Announcement!")
+
+ assert result is False
+
+
+@pytest.mark.asyncio
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+async def test_send_to_channel_not_found():
+ """Test sending to non-existent channel"""
+ collector = MessageCollector()
+ collector.get_channel_by_name = AsyncMock(return_value=None)
+
+ result = await collector.send_to_channel("nonexistent", "Message")
+
+ assert result is False
+
+
+@pytest.mark.asyncio
+@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
+async def test_run_collector():
+ """Test running the collector"""
+ from memory.discord.collector import run_collector
+
+ with patch("memory.discord.collector.MessageCollector") as mock_collector_class:
+ mock_collector = Mock()
+ mock_collector.start = AsyncMock()
+ mock_collector_class.return_value = mock_collector
+
+ await run_collector()
+
+ mock_collector.start.assert_called_once_with("test_token")
+
+
+@pytest.mark.asyncio
+@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
+async def test_run_collector_no_token():
+ """Test running collector without token"""
+ from memory.discord.collector import run_collector
+
+ # Should return early without raising
+ await run_collector()
diff --git a/tests/memory/workers/tasks/test_discord_tasks.py b/tests/memory/workers/tasks/test_discord_tasks.py
new file mode 100644
index 0000000..5ce5765
--- /dev/null
+++ b/tests/memory/workers/tasks/test_discord_tasks.py
@@ -0,0 +1,607 @@
+import pytest
+from datetime import datetime, timezone
+from unittest.mock import Mock, patch
+
+from memory.common.db.models import (
+ DiscordMessage,
+ DiscordUser,
+ DiscordServer,
+ DiscordChannel,
+)
+from memory.workers.tasks import discord
+
+
+@pytest.fixture
+def mock_discord_user(db_session):
+ """Create a Discord user for testing."""
+ user = DiscordUser(
+ id=123456789,
+ username="testuser",
+ ignore_messages=False,
+ )
+ db_session.add(user)
+ db_session.commit()
+ return user
+
+
+@pytest.fixture
+def mock_discord_server(db_session):
+ """Create a Discord server for testing."""
+ server = DiscordServer(
+ id=987654321,
+ name="Test Server",
+ ignore_messages=False,
+ )
+ db_session.add(server)
+ db_session.commit()
+ return server
+
+
+@pytest.fixture
+def mock_discord_channel(db_session, mock_discord_server):
+ """Create a Discord channel for testing."""
+ channel = DiscordChannel(
+ id=111222333,
+ name="test-channel",
+ channel_type="text",
+ server_id=mock_discord_server.id,
+ ignore_messages=False,
+ )
+ db_session.add(channel)
+ db_session.commit()
+ return channel
+
+
+@pytest.fixture
+def sample_message_data(mock_discord_user, mock_discord_channel):
+ """Sample message data for testing."""
+ return {
+ "message_id": 999888777,
+ "channel_id": mock_discord_channel.id,
+ "author_id": mock_discord_user.id,
+ "content": "This is a test Discord message with enough content to be processed.",
+ "sent_at": "2024-01-01T12:00:00Z",
+ "server_id": None,
+ "message_reference_id": None,
+ }
+
+
+def test_get_prev_returns_previous_messages(
+ db_session, mock_discord_user, mock_discord_channel
+):
+ """Test that get_prev returns previous messages in order."""
+ # Create previous messages
+ msg1 = DiscordMessage(
+ message_id=1,
+ channel_id=mock_discord_channel.id,
+ discord_user_id=mock_discord_user.id,
+ content="First message",
+ sent_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
+ modality="text",
+ sha256=b"hash1" + bytes(26),
+ )
+ msg2 = DiscordMessage(
+ message_id=2,
+ channel_id=mock_discord_channel.id,
+ discord_user_id=mock_discord_user.id,
+ content="Second message",
+ sent_at=datetime(2024, 1, 1, 10, 5, 0, tzinfo=timezone.utc),
+ modality="text",
+ sha256=b"hash2" + bytes(26),
+ )
+ msg3 = DiscordMessage(
+ message_id=3,
+ channel_id=mock_discord_channel.id,
+ discord_user_id=mock_discord_user.id,
+ content="Third message",
+ sent_at=datetime(2024, 1, 1, 10, 10, 0, tzinfo=timezone.utc),
+ modality="text",
+ sha256=b"hash3" + bytes(26),
+ )
+ db_session.add_all([msg1, msg2, msg3])
+ db_session.commit()
+
+ # Get previous messages before 10:15
+ result = discord.get_prev(
+ db_session,
+ mock_discord_channel.id,
+ datetime(2024, 1, 1, 10, 15, 0, tzinfo=timezone.utc),
+ )
+
+ assert len(result) == 3
+ assert result[0] == "testuser: First message"
+ assert result[1] == "testuser: Second message"
+ assert result[2] == "testuser: Third message"
+
+
+def test_get_prev_limits_context_window(
+ db_session, mock_discord_user, mock_discord_channel
+):
+ """Test that get_prev respects DISCORD_CONTEXT_WINDOW setting."""
+ # Create 15 messages (more than the default context window of 10)
+ for i in range(15):
+ msg = DiscordMessage(
+ message_id=i,
+ channel_id=mock_discord_channel.id,
+ discord_user_id=mock_discord_user.id,
+ content=f"Message {i}",
+ sent_at=datetime(2024, 1, 1, 10, i, 0, tzinfo=timezone.utc),
+ modality="text",
+ sha256=f"hash{i}".encode() + bytes(27),
+ )
+ db_session.add(msg)
+ db_session.commit()
+
+ result = discord.get_prev(
+ db_session,
+ mock_discord_channel.id,
+ datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
+ )
+
+ # Should only return last 10 messages
+ assert len(result) == 10
+ assert result[0] == "testuser: Message 5" # Oldest in window
+ assert result[-1] == "testuser: Message 14" # Most recent
+
+
+def test_get_prev_empty_channel(db_session, mock_discord_channel):
+ """Test get_prev with no previous messages."""
+ result = discord.get_prev(
+ db_session,
+ mock_discord_channel.id,
+ datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
+ )
+
+ assert result == []
+
+
+@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+def test_should_process_normal_message(
+ db_session, mock_discord_user, mock_discord_server, mock_discord_channel
+):
+ """Test should_process returns True for normal messages."""
+ message = DiscordMessage(
+ message_id=1,
+ channel_id=mock_discord_channel.id,
+ discord_user_id=mock_discord_user.id,
+ server_id=mock_discord_server.id,
+ content="Test",
+ sent_at=datetime.now(timezone.utc),
+ modality="text",
+ sha256=b"hash" + bytes(27),
+ )
+ db_session.add(message)
+ db_session.commit()
+ db_session.refresh(message)
+
+ assert discord.should_process(message) is True
+
+
+@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
+def test_should_process_disabled():
+ """Test should_process returns False when processing is disabled."""
+ message = Mock()
+ assert discord.should_process(message) is False
+
+
+@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
+def test_should_process_notifications_disabled():
+ """Test should_process returns False when notifications are disabled."""
+ message = Mock()
+ assert discord.should_process(message) is False
+
+
+@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+def test_should_process_server_ignored(
+ db_session, mock_discord_user, mock_discord_channel
+):
+ """Test should_process returns False when server has ignore_messages=True."""
+ server = DiscordServer(
+ id=123,
+ name="Ignored Server",
+ ignore_messages=True,
+ )
+ db_session.add(server)
+ db_session.commit()
+
+ message = DiscordMessage(
+ message_id=1,
+ channel_id=mock_discord_channel.id,
+ discord_user_id=mock_discord_user.id,
+ server_id=server.id,
+ content="Test",
+ sent_at=datetime.now(timezone.utc),
+ modality="text",
+ sha256=b"hash" + bytes(27),
+ )
+ db_session.add(message)
+ db_session.commit()
+ db_session.refresh(message)
+
+ assert discord.should_process(message) is False
+
+
+@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+def test_should_process_channel_ignored(
+ db_session, mock_discord_user, mock_discord_server
+):
+ """Test should_process returns False when channel has ignore_messages=True."""
+ channel = DiscordChannel(
+ id=456,
+ name="ignored-channel",
+ channel_type="text",
+ server_id=mock_discord_server.id,
+ ignore_messages=True,
+ )
+ db_session.add(channel)
+ db_session.commit()
+
+ message = DiscordMessage(
+ message_id=1,
+ channel_id=channel.id,
+ discord_user_id=mock_discord_user.id,
+ server_id=mock_discord_server.id,
+ content="Test",
+ sent_at=datetime.now(timezone.utc),
+ modality="text",
+ sha256=b"hash" + bytes(27),
+ )
+ db_session.add(message)
+ db_session.commit()
+ db_session.refresh(message)
+
+ assert discord.should_process(message) is False
+
+
+@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+def test_should_process_user_ignored(
+ db_session, mock_discord_server, mock_discord_channel
+):
+ """Test should_process returns False when user has ignore_messages=True."""
+ user = DiscordUser(
+ id=789,
+ username="ignoreduser",
+ ignore_messages=True,
+ )
+ db_session.add(user)
+ db_session.commit()
+
+ message = DiscordMessage(
+ message_id=1,
+ channel_id=mock_discord_channel.id,
+ discord_user_id=user.id,
+ server_id=mock_discord_server.id,
+ content="Test",
+ sent_at=datetime.now(timezone.utc),
+ modality="text",
+ sha256=b"hash" + bytes(27),
+ )
+ db_session.add(message)
+ db_session.commit()
+ db_session.refresh(message)
+
+ assert discord.should_process(message) is False
+
+
+def test_add_discord_message_success(db_session, sample_message_data, qdrant):
+ """Test successful Discord message addition."""
+ result = discord.add_discord_message(**sample_message_data)
+
+ assert result["status"] == "processed"
+ assert "discordmessage_id" in result
+
+ # Verify the message was created in the database
+ message = (
+ db_session.query(DiscordMessage)
+ .filter_by(message_id=sample_message_data["message_id"])
+ .first()
+ )
+ assert message is not None
+ assert message.content == sample_message_data["content"]
+ assert message.message_type == "default"
+ assert message.reply_to_message_id is None
+
+
+def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant):
+ """Test adding a Discord message that is a reply."""
+ sample_message_data["message_reference_id"] = 111222333
+
+ discord.add_discord_message(**sample_message_data)
+
+ message = (
+ db_session.query(DiscordMessage)
+ .filter_by(message_id=sample_message_data["message_id"])
+ .first()
+ )
+ assert message.message_type == "reply"
+ assert message.reply_to_message_id == 111222333
+
+
+def test_add_discord_message_already_exists(db_session, sample_message_data, qdrant):
+ """Test adding a message that already exists."""
+ # Add the message once
+ discord.add_discord_message(**sample_message_data)
+
+ # Try to add it again
+ result = discord.add_discord_message(**sample_message_data)
+
+ assert result["status"] == "already_exists"
+ assert result["message_id"] == sample_message_data["message_id"]
+
+ # Verify no duplicate was created
+ messages = (
+ db_session.query(DiscordMessage)
+ .filter_by(message_id=sample_message_data["message_id"])
+ .all()
+ )
+ assert len(messages) == 1
+
+
+def test_add_discord_message_with_context(
+ db_session, sample_message_data, mock_discord_user, qdrant
+):
+ """Test that message is added successfully when previous messages exist."""
+ # Add a previous message
+ prev_msg = DiscordMessage(
+ message_id=111111,
+ channel_id=sample_message_data["channel_id"],
+ discord_user_id=mock_discord_user.id,
+ content="Previous message",
+ sent_at=datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
+ modality="text",
+ sha256=b"prev" + bytes(28),
+ )
+ db_session.add(prev_msg)
+ db_session.commit()
+
+ result = discord.add_discord_message(**sample_message_data)
+
+ message = (
+ db_session.query(DiscordMessage)
+ .filter_by(message_id=sample_message_data["message_id"])
+ .first()
+ )
+ assert message is not None
+ assert result["status"] == "processed"
+
+
+@patch("memory.workers.tasks.discord.process_discord_message")
+@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
+@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+def test_add_discord_message_triggers_processing(
+ mock_process,
+ db_session,
+ sample_message_data,
+ mock_discord_server,
+ mock_discord_channel,
+ qdrant,
+):
+ """Test that add_discord_message triggers process_discord_message when conditions are met."""
+ mock_process.delay = Mock()
+ sample_message_data["server_id"] = mock_discord_server.id
+
+ discord.add_discord_message(**sample_message_data)
+
+ # Verify process_discord_message.delay was called
+ mock_process.delay.assert_called_once()
+
+
+@patch("memory.workers.tasks.discord.process_discord_message")
+@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
+def test_add_discord_message_no_processing_when_disabled(
+ mock_process, db_session, sample_message_data, qdrant
+):
+ """Test that process_discord_message is not called when processing is disabled."""
+ mock_process.delay = Mock()
+
+ discord.add_discord_message(**sample_message_data)
+
+ mock_process.delay.assert_not_called()
+
+
+def test_edit_discord_message_success(db_session, sample_message_data, qdrant):
+ """Test successful Discord message edit."""
+ # First add the message
+ discord.add_discord_message(**sample_message_data)
+
+ # Edit it
+ new_content = (
+ "This is the edited content with enough text to be meaningful and processed."
+ )
+ edited_at = "2024-01-01T13:00:00Z"
+
+ result = discord.edit_discord_message(
+ sample_message_data["message_id"],
+ new_content,
+ edited_at,
+ )
+
+ assert result["status"] == "processed"
+
+ # Verify the message was updated
+ message = (
+ db_session.query(DiscordMessage)
+ .filter_by(message_id=sample_message_data["message_id"])
+ .first()
+ )
+ assert message.content == new_content
+ assert message.edited_at is not None
+
+
+def test_edit_discord_message_not_found(db_session):
+ """Test editing a message that doesn't exist."""
+ result = discord.edit_discord_message(
+ 999999,
+ "New content",
+ "2024-01-01T13:00:00Z",
+ )
+
+ assert result["status"] == "error"
+ assert result["error"] == "Message not found"
+ assert result["message_id"] == 999999
+
+
+def test_edit_discord_message_updates_context(
+ db_session, sample_message_data, mock_discord_user, qdrant
+):
+ """Test that editing a message works correctly."""
+ # Add previous message and the message to be edited
+ prev_msg = DiscordMessage(
+ message_id=111111,
+ channel_id=sample_message_data["channel_id"],
+ discord_user_id=mock_discord_user.id,
+ content="Previous message",
+ sent_at=datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
+ modality="text",
+ sha256=b"prev" + bytes(28),
+ )
+ db_session.add(prev_msg)
+ db_session.commit()
+
+ discord.add_discord_message(**sample_message_data)
+
+ # Edit the message
+ result = discord.edit_discord_message(
+ sample_message_data["message_id"],
+ "Edited content that should have context updated properly.",
+ "2024-01-01T13:00:00Z",
+ )
+
+ # Verify message was updated
+ message = (
+ db_session.query(DiscordMessage)
+ .filter_by(message_id=sample_message_data["message_id"])
+ .first()
+ )
+ assert (
+ message.content == "Edited content that should have context updated properly."
+ )
+ assert result["status"] == "processed"
+
+
+def test_process_discord_message_success(db_session, sample_message_data, qdrant):
+ """Test processing a Discord message."""
+ # Add a message first
+ add_result = discord.add_discord_message(**sample_message_data)
+ message_id = add_result["discordmessage_id"]
+
+ # Process it
+ result = discord.process_discord_message(message_id)
+
+ assert result["status"] == "processed"
+ assert result["message_id"] == message_id
+
+
+def test_process_discord_message_not_found(db_session):
+ """Test processing a message that doesn't exist."""
+ result = discord.process_discord_message(999999)
+
+ assert result["status"] == "error"
+ assert result["error"] == "Message not found"
+ assert result["message_id"] == 999999
+
+
+@pytest.mark.parametrize(
+ "sent_at_str,expected_hour",
+ [
+ ("2024-01-01T12:00:00Z", 12),
+ ("2024-01-01T00:00:00+00:00", 0),
+ ("2024-01-01T23:59:59Z", 23),
+ ],
+)
+def test_add_discord_message_datetime_parsing(
+ db_session, sample_message_data, sent_at_str, expected_hour, qdrant
+):
+ """Test that various datetime formats are parsed correctly."""
+ sample_message_data["sent_at"] = sent_at_str
+
+ discord.add_discord_message(**sample_message_data)
+
+ message = (
+ db_session.query(DiscordMessage)
+ .filter_by(message_id=sample_message_data["message_id"])
+ .first()
+ )
+ assert message.sent_at.hour == expected_hour
+
+
+def test_add_discord_message_unique_hash(db_session, sample_message_data, qdrant):
+ """Test that message hash includes message_id for uniqueness."""
+ # Add first message
+ discord.add_discord_message(**sample_message_data)
+
+ # Try to add another message with same content but different message_id
+ sample_message_data["message_id"] = 888777666
+
+ result = discord.add_discord_message(**sample_message_data)
+
+ # Should succeed because hash includes message_id
+ assert result["status"] == "processed"
+
+ # Verify both messages exist
+ messages = (
+ db_session.query(DiscordMessage)
+ .filter_by(content=sample_message_data["content"])
+ .all()
+ )
+ assert len(messages) == 2
+
+
+def test_get_prev_only_returns_messages_from_same_channel(
+ db_session, mock_discord_user, mock_discord_server
+):
+ """Test that get_prev only returns messages from the specified channel."""
+ # Create two channels
+ channel1 = DiscordChannel(
+ id=111,
+ name="channel-1",
+ channel_type="text",
+ server_id=mock_discord_server.id,
+ )
+ channel2 = DiscordChannel(
+ id=222,
+ name="channel-2",
+ channel_type="text",
+ server_id=mock_discord_server.id,
+ )
+ db_session.add_all([channel1, channel2])
+ db_session.commit()
+
+ # Add messages to both channels
+ msg1 = DiscordMessage(
+ message_id=1,
+ channel_id=channel1.id,
+ discord_user_id=mock_discord_user.id,
+ content="Message in channel 1",
+ sent_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
+ modality="text",
+ sha256=b"hash1" + bytes(26),
+ )
+ msg2 = DiscordMessage(
+ message_id=2,
+ channel_id=channel2.id,
+ discord_user_id=mock_discord_user.id,
+ content="Message in channel 2",
+ sent_at=datetime(2024, 1, 1, 10, 5, 0, tzinfo=timezone.utc),
+ modality="text",
+ sha256=b"hash2" + bytes(26),
+ )
+ db_session.add_all([msg1, msg2])
+ db_session.commit()
+
+ # Get previous messages for channel 1
+ result = discord.get_prev(
+ db_session,
+ channel1.id, # type: ignore
+ datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
+ )
+
+ # Should only return message from channel 1
+ assert len(result) == 1
+ assert "Message in channel 1" in result[0]
+ assert "Message in channel 2" not in result[0]