move to general LLM providers

This commit is contained in:
Daniel O'Connell 2025-10-13 03:23:20 +02:00
parent 08d17c28dd
commit 99d3843f47
23 changed files with 3844 additions and 400 deletions

View File

@ -26,9 +26,12 @@ def upgrade() -> None:
sa.Column("name", sa.Text(), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("member_count", sa.Integer(), nullable=True),
sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
sa.Column(
"track_messages", sa.Boolean(), server_default="true", nullable=False
"ignore_messages", sa.Boolean(), server_default="false", nullable=True
),
sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
@ -56,7 +59,12 @@ def upgrade() -> None:
sa.Column("server_id", sa.BigInteger(), nullable=True),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("channel_type", sa.Text(), nullable=False),
sa.Column("track_messages", sa.Boolean(), nullable=True),
sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
sa.Column(
"ignore_messages", sa.Boolean(), server_default="false", nullable=True
),
sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
@ -84,9 +92,12 @@ def upgrade() -> None:
sa.Column("username", sa.Text(), nullable=False),
sa.Column("display_name", sa.Text(), nullable=True),
sa.Column("system_user_id", sa.Integer(), nullable=True),
sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
sa.Column(
"allow_dm_tracking", sa.Boolean(), server_default="true", nullable=False
"ignore_messages", sa.Boolean(), server_default="false", nullable=True
),
sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),

View File

@ -5,7 +5,7 @@ alembic==1.13.1
dotenv==0.9.9
voyageai==0.3.2
qdrant-client==1.9.0
anthropic==0.18.1
anthropic==0.69.0
openai==1.25.0
# Pin the httpx version, as newer versions break the anthropic client
httpx==0.27.0

View File

@ -40,7 +40,7 @@ async def score_chunk(query: str, chunk: Chunk) -> Chunk:
prompt = SCORE_CHUNK_PROMPT.format(query=query, chunk=chunk_text)
try:
response = await asyncio.to_thread(
llms.call,
llms.summarize,
prompt,
settings.RANKER_MODEL,
images=images,

View File

@ -15,6 +15,7 @@ SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
DISCORD_ROOT = "memory.workers.tasks.discord"
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
SYNC_NOTES = f"{NOTES_ROOT}.sync_notes"
SYNC_NOTE = f"{NOTES_ROOT}.sync_note"

View File

@ -127,7 +127,15 @@ class EmailAccount(Base):
)
class DiscordServer(Base):
class MessageProcessor:
    """Mixin declaring the message-handling columns shared by the Discord models.

    It is mixed into DiscordServer, DiscordChannel and DiscordUser so the same
    four columns are declared on each of those tables.
    """

    track_messages = Column(Boolean, nullable=False, server_default="true")
    # Keep the ORM-side default, and also declare the database-side default so
    # the model matches the migration (which creates the column with
    # server_default="false") and rows inserted outside the ORM get a value.
    ignore_messages = Column(
        Boolean, nullable=True, default=False, server_default="false"
    )
    allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
    disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
class DiscordServer(Base, MessageProcessor):
"""Discord server configuration and metadata"""
__tablename__ = "discord_servers"
@ -138,7 +146,6 @@ class DiscordServer(Base):
member_count = Column(Integer)
# Collection settings
track_messages = Column(Boolean, nullable=False, server_default="true")
last_sync_at = Column(DateTime(timezone=True))
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
@ -152,7 +159,7 @@ class DiscordServer(Base):
)
class DiscordChannel(Base):
class DiscordChannel(Base, MessageProcessor):
"""Discord channel metadata and configuration"""
__tablename__ = "discord_channels"
@ -163,7 +170,6 @@ class DiscordChannel(Base):
channel_type = Column(Text, nullable=False) # "text", "voice", "dm", "group_dm"
# Collection settings (null = inherit from server)
track_messages = Column(Boolean, nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
@ -171,7 +177,7 @@ class DiscordChannel(Base):
__table_args__ = (Index("discord_channels_server_idx", "server_id"),)
class DiscordUser(Base):
class DiscordUser(Base, MessageProcessor):
"""Discord user metadata and preferences"""
__tablename__ = "discord_users"
@ -184,7 +190,6 @@ class DiscordUser(Base):
system_user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
# Basic DM settings
allow_dm_tracking = Column(Boolean, nullable=False, server_default="true")
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())

View File

@ -1,122 +0,0 @@
import logging
import base64
import io
from typing import Any
from PIL import Image
from memory.common import settings, tokens
logger = logging.getLogger(__name__)
SYSTEM_PROMPT = """
You are a helpful assistant that creates concise summaries and identifies key topics.
"""
def encode_image(image: Image.Image) -> str:
    """Serialize *image* as JPEG and return the bytes base64-encoded as text."""
    # The image is saved as JPEG from RGB data, so any other mode
    # (RGBA, etc.) is converted first.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    jpeg_buffer = io.BytesIO()
    rgb_image.save(jpeg_buffer, format="JPEG")
    raw_bytes = jpeg_buffer.getvalue()
    return base64.b64encode(raw_bytes).decode("utf-8")
def call_openai(
    prompt: str,
    model: str,
    images: list[Image.Image] | None = None,
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Call the OpenAI chat completions API and return the reply text.

    Args:
        prompt: User text sent as the first content part.
        model: Either "provider/model" or a bare model name; only the model
            part is sent to the API.
        images: Optional images, each attached as a base64 JPEG data URL.
        system_prompt: Content of the system message.

    Returns:
        The first choice's message content, or "" when it is empty.

    Raises:
        Exception: Any error raised by the OpenAI client is logged and re-raised.
    """
    import openai

    # A bare model name (no "/") previously raised IndexError here.
    model_parts = model.split("/")
    model_name = model_parts[1] if len(model_parts) > 1 else model

    user_content: Any = [{"type": "text", "text": prompt}]
    user_content.extend(
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{encode_image(image)}"},
        }
        for image in images or []
    )

    client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=[
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {"role": "user", "content": user_content},
            ],
            temperature=0.3,
            max_tokens=2048,
        )
        return response.choices[0].message.content or ""
    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        raise
def call_anthropic(
    prompt: str,
    model: str,
    images: list[Image.Image] | None = None,
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Call the Anthropic messages API and return the reply text.

    Args:
        prompt: User text sent as the first content block.
        model: Either "provider/model" or a bare model name; only the model
            part is sent to the API.
        images: Optional images, each attached as a base64 JPEG block.
        system_prompt: System prompt passed to the API.

    Returns:
        The concatenated text of all text blocks in the response ("" if none).

    Raises:
        Exception: Any error raised by the Anthropic client is logged and re-raised.
    """
    import anthropic

    # A bare model name (no "/") previously raised IndexError here.
    model_parts = model.split("/")
    model_name = model_parts[1] if len(model_parts) > 1 else model

    content: Any = [{"type": "text", "text": prompt}]
    content.extend(
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/jpeg",
                "data": encode_image(image),
            },
        }
        for image in images or []
    )

    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
    try:
        response = client.messages.create(
            model=model_name,
            messages=[{"role": "user", "content": content}],  # type: ignore
            system=system_prompt,
            temperature=0.3,
            max_tokens=2048,
        )
        # Indexing content[0].text fails on an empty response or when the first
        # block is not a text block, so join whatever text blocks exist instead.
        return "".join(
            block.text
            for block in response.content
            if getattr(block, "type", None) == "text"
        )
    except Exception as e:
        logger.error(f"Anthropic API error: {e}")
        raise
def call(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Dispatch to the Anthropic or OpenAI caller based on the model prefix.

    Models whose identifier starts with "anthropic" go to call_anthropic;
    everything else goes to call_openai.
    """
    backend = call_anthropic if model.startswith("anthropic") else call_openai
    return backend(prompt, model, images, system_prompt)
def truncate(content: str, target_tokens: int) -> str:
    """Shorten *content* to roughly *target_tokens* tokens.

    The character budget is target_tokens * tokens.CHARS_PER_TOKEN. Text within
    the budget is returned unchanged; longer text is cut at the budget, the
    trailing fragment after the last space is dropped, and "..." is appended.
    """
    char_budget = target_tokens * tokens.CHARS_PER_TOKEN
    if len(content) <= char_budget:
        return content
    clipped = content[:char_budget]
    head = clipped.rsplit(" ", 1)[0]
    return head + "..."

View File

@ -0,0 +1,79 @@
"""LLM provider module for unified LLM access."""
# Legacy imports for backwards compatibility
import logging
from PIL import Image
# New provider system
from memory.common.llms.base import (
BaseLLMProvider,
ImageContent,
LLMSettings,
Message,
MessageContent,
MessageRole,
StreamEvent,
TextContent,
ThinkingContent,
ToolDefinition,
ToolResultContent,
ToolUseContent,
create_provider,
)
from memory.common import tokens
__all__ = [
"BaseLLMProvider",
"Message",
"MessageRole",
"MessageContent",
"TextContent",
"ImageContent",
"ToolUseContent",
"ToolResultContent",
"ThinkingContent",
"ToolDefinition",
"StreamEvent",
"LLMSettings",
"create_provider",
]
logger = logging.getLogger(__name__)
def summarize(
    prompt: str,
    model: str,
    images: list[Image.Image] | None = None,
    system_prompt: str = "",
) -> str:
    """Run a single prompt (plus optional images) through the provider for *model*.

    Args:
        prompt: User text for the request.
        model: Model identifier handed to create_provider.
        images: Optional images appended after the text as ImageContent items.
        system_prompt: System prompt; when empty, a default summarization
            prompt is used.

    Returns:
        The text response of the turn, or "" when the turn has no text.

    Raises:
        Exception: Any provider error is logged and re-raised.
    """
    provider = create_provider(model=model)

    content: list[MessageContent] = [TextContent(text=prompt)]
    content.extend(ImageContent(image=image) for image in images or [])
    messages = [Message(role=MessageRole.USER, content=content)]
    settings_obj = LLMSettings(temperature=0.3, max_tokens=2048)

    try:
        res = provider.run_with_tools(
            messages=messages,
            system_prompt=system_prompt
            or "You are a helpful assistant that creates concise summaries and identifies key topics.",
            settings=settings_obj,
            tools={},
        )
        return res.response or ""
    except Exception as e:
        # This function is provider-agnostic, so the log line must not claim
        # that the failure came from Anthropic.
        logger.error(f"LLM API error (model={model}): {e}")
        raise
def truncate(content: str, target_tokens: int) -> str:
    """Cut *content* down to about *target_tokens* tokens, ending with "...".

    Content that already fits in target_tokens * tokens.CHARS_PER_TOKEN
    characters is returned as is. Otherwise it is sliced to that many
    characters and everything after the last space in the slice is discarded
    before "..." is appended.
    """
    max_chars = target_tokens * tokens.CHARS_PER_TOKEN
    fits = len(content) <= max_chars
    if fits:
        return content
    kept, *_ = content[:max_chars].rsplit(" ", 1)
    return kept + "..."

View File

@ -0,0 +1,451 @@
"""Anthropic LLM provider implementation."""
import json
import logging
from typing import Any, AsyncIterator, Iterator, Optional
import anthropic
from memory.common.llms.base import (
BaseLLMProvider,
ImageContent,
LLMSettings,
Message,
MessageRole,
StreamEvent,
ToolDefinition,
ToolUseContent,
ThinkingContent,
TextContent,
)
logger = logging.getLogger(__name__)
class AnthropicProvider(BaseLLMProvider):
"""Anthropic LLM provider with streaming, tool support, and extended thinking."""
# Substrings identifying models that support extended thinking.
# _supports_thinking does a substring match against the lower-cased model id,
# so each entry must be spelled the way it appears inside real identifiers:
# dated ids look like "claude-3-7-sonnet-20250219" and
# "claude-sonnet-4-20250514", which the "claude-sonnet-3-7" and
# "claude-sonnet-4-0" spellings alone never match. The original entries are
# kept so existing alias-style names continue to match.
THINKING_MODELS = {
    "claude-opus-4",
    "claude-opus-4-1",
    "claude-sonnet-4",
    "claude-sonnet-4-0",
    "claude-sonnet-4-5",
    "claude-3-7-sonnet",
    "claude-sonnet-3-7",
}
def __init__(self, api_key: str, model: str, enable_thinking: bool = False):
    """
    Store the provider configuration; the async client is not created here.

    Args:
        api_key: Anthropic API key
        model: Model identifier
        enable_thinking: Request extended thinking when the model supports it
    """
    super().__init__(api_key, model)
    # Filled in on first access of the `async_client` property.
    self._async_client: Optional[anthropic.AsyncAnthropic] = None
    self.enable_thinking = enable_thinking
def _initialize_client(self) -> anthropic.Anthropic:
    """Create the synchronous Anthropic client from the stored API key."""
    sync_client = anthropic.Anthropic(api_key=self.api_key)
    return sync_client
@property
def async_client(self) -> anthropic.AsyncAnthropic:
    """Async Anthropic client, created on first access and reused afterwards."""
    existing = self._async_client
    if existing is not None:
        return existing
    self._async_client = anthropic.AsyncAnthropic(api_key=self.api_key)
    return self._async_client
def _convert_image_content(self, content: ImageContent) -> dict[str, Any]:
"""Convert ImageContent to Anthropic's base64 source format."""
encoded_image = self.encode_image(content.image)
return {
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": encoded_image,
},
}
def _convert_message(self, message: Message) -> dict[str, Any]:
    """Convert a Message into the dict sent to the Anthropic API.

    String content is passed through unchanged. List content is converted item
    by item through `_convert_message_content`, so provider-specific
    converters are applied. Calling `message.to_dict()` directly left raw PIL
    images in the payload, because `ImageContent.to_dict` does not encode.

    For assistant messages, thinking blocks are moved ahead of all other
    blocks; the sort is stable, so relative order is otherwise preserved.
    """
    role = message.role.value
    if isinstance(message.content, str):
        return {"role": role, "content": message.content}

    blocks = [self._convert_message_content(item) for item in message.content]
    if message.role == MessageRole.ASSISTANT:
        blocks.sort(key=lambda block: block["type"] != "thinking")
    return {"role": role, "content": blocks}
def _should_include_message(self, message: Message) -> bool:
    """Keep every message except SYSTEM ones (the system prompt is sent separately)."""
    is_system = message.role == MessageRole.SYSTEM
    return not is_system
def _supports_thinking(self) -> bool:
"""Check if the current model supports extended thinking."""
model_lower = self.model.lower()
return any(supported in model_lower for supported in self.THINKING_MODELS)
def _build_request_kwargs(
self,
messages: list[Message],
system_prompt: str | None,
tools: list[ToolDefinition] | None,
settings: LLMSettings,
) -> dict[str, Any]:
"""Build common request kwargs for API calls."""
anthropic_messages = self._convert_messages(messages)
kwargs: dict[str, Any] = {
"model": self.model,
"messages": anthropic_messages,
"temperature": settings.temperature,
"max_tokens": settings.max_tokens,
}
# Only include top_p if explicitly set
if settings.top_p is not None:
kwargs["top_p"] = settings.top_p
if system_prompt:
kwargs["system"] = system_prompt
if settings.stop_sequences:
kwargs["stop_sequences"] = settings.stop_sequences
if tools:
kwargs["tools"] = self._convert_tools(tools)
# Enable extended thinking if requested and model supports it
if self.enable_thinking and self._supports_thinking():
thinking_budget = min(10000, settings.max_tokens - 1024)
if thinking_budget >= 1024:
kwargs["thinking"] = {
"type": "enabled",
"budget_tokens": thinking_budget,
}
# When thinking is enabled: temperature must be 1, can't use top_p
kwargs["temperature"] = 1.0
kwargs.pop("top_p", None)
return kwargs
def _handle_stream_event(
    self, event: Any, current_tool_use: dict[str, Any] | None
) -> tuple[StreamEvent | None, dict[str, Any] | None]:
    """
    Translate one raw streaming event into a StreamEvent and updated tool state.

    Args:
        event: Raw event object from the Anthropic stream; only attributes
            read via getattr are required.
        current_tool_use: State of the tool call currently being streamed
            (id, name, input, server_name, is_server_call), or None.

    Returns:
        Tuple of (StreamEvent or None, updated current_tool_use or None)
    """
    event_type = getattr(event, "type", None)

    if event_type == "error":
        error = getattr(event, "error", None)
        error_msg = str(error) if error else "Unknown error"
        return StreamEvent(type="error", data=error_msg), current_tool_use

    if event_type == "content_block_start":
        block = getattr(event, "content_block", None)
        if not block:
            return None, current_tool_use

        block_type = getattr(block, "type", None)

        # Handle various tool types (tool_use, mcp_tool_use, server_tool_use)
        if block_type in ("tool_use", "mcp_tool_use", "server_tool_use"):
            block_input = getattr(block, "input", None)
            current_tool_use = {
                "id": getattr(block, "id", ""),
                "name": getattr(block, "name", ""),
                # An empty input at block start is only a placeholder. Keep a
                # string buffer in that case so the input_json_delta chunks
                # that follow are accumulated instead of being dropped.
                "input": block_input if block_input else "",
                "server_name": getattr(block, "server_name", None),
                "is_server_call": block_type != "tool_use",
            }
        # Handle tool result blocks
        elif hasattr(block, "tool_use_id"):
            tool_use_id = getattr(block, "tool_use_id", "")
            result = getattr(block, "content", "")
            tool_result = {
                "id": tool_use_id,
                "result": result,
                # Same information under the key names used by
                # ToolResultContent.to_dict(), so consumers can treat
                # locally-executed and streamed results alike.
                "tool_use_id": tool_use_id,
                "content": result,
            }
            return StreamEvent(
                type="tool_result", data=tool_result
            ), current_tool_use

        # For non-tool blocks (text, thinking), we don't need to track state
        return None, current_tool_use

    elif event_type == "content_block_delta":
        delta = getattr(event, "delta", None)
        if not delta:
            return None, current_tool_use

        delta_type = getattr(delta, "type", None)

        if delta_type == "text_delta":
            text = getattr(delta, "text", "")
            return StreamEvent(type="text", data=text), current_tool_use

        elif delta_type == "thinking_delta":
            thinking = getattr(delta, "thinking", "")
            return StreamEvent(type="thinking", data=thinking), current_tool_use

        elif delta_type == "signature_delta":
            # The signature travels on a "thinking" event whose data is None.
            signature = getattr(delta, "signature", "")
            return StreamEvent(
                type="thinking", signature=signature
            ), current_tool_use

        elif delta_type == "input_json_delta":
            if current_tool_use is None:
                # Edge case: received input_json_delta without tool_use start
                logger.warning("Received input_json_delta without tool_use context")
                return None, None

            # Only accumulate if input is still a string (being built up);
            # a non-empty dict from content_block_start is left untouched.
            if isinstance(current_tool_use.get("input"), str):
                partial_json = getattr(delta, "partial_json", "")
                current_tool_use["input"] += partial_json

            return None, current_tool_use

    elif event_type == "content_block_stop":
        if current_tool_use:
            # Prefer the parsed input carried by the stop event when present.
            content_block = getattr(event, "content_block", None)
            if content_block and hasattr(content_block, "input"):
                current_tool_use["input"] = content_block.input
            else:
                # Fallback: parse the accumulated JSON string.
                input_str = current_tool_use.get("input", "")
                if isinstance(input_str, str):
                    if not input_str or input_str.isspace():
                        # Empty or whitespace-only input
                        current_tool_use["input"] = {}
                    else:
                        try:
                            current_tool_use["input"] = json.loads(input_str)
                        except json.JSONDecodeError as e:
                            logger.warning(
                                f"Failed to parse tool input '{input_str}': {e}"
                            )
                            current_tool_use["input"] = {}
                # else: input is already parsed

            tool_data = {
                "id": current_tool_use.get("id", ""),
                "name": current_tool_use.get("name", ""),
                "input": current_tool_use.get("input", {}),
            }
            # Include server info if present
            if current_tool_use.get("server_name"):
                tool_data["server_name"] = current_tool_use["server_name"]
            if current_tool_use.get("is_server_call"):
                tool_data["is_server_call"] = current_tool_use["is_server_call"]

            return StreamEvent(type="tool_use", data=tool_data), None

    elif event_type == "message_delta":
        delta = getattr(event, "delta", None)
        if delta:
            stop_reason = getattr(delta, "stop_reason", None)
            if stop_reason == "max_tokens":
                return StreamEvent(
                    type="error", data="Max tokens reached"
                ), current_tool_use

        # Token usage is only logged, not emitted as an event.
        usage = getattr(event, "usage", None)
        if usage:
            usage_data = {
                "input_tokens": getattr(usage, "input_tokens", 0),
                "output_tokens": getattr(usage, "output_tokens", 0),
                "cache_creation_input_tokens": getattr(
                    usage, "cache_creation_input_tokens", None
                ),
                "cache_read_input_tokens": getattr(
                    usage, "cache_read_input_tokens", None
                ),
            }
            logger.debug(f"Token usage: {usage_data}")

        return None, current_tool_use

    elif event_type == "message_stop":
        # Final event - any pending tool state is reported and discarded.
        if current_tool_use:
            logger.warning(
                f"Message ended with incomplete tool use: {current_tool_use}"
            )
        return StreamEvent(type="done"), None

    # Unknown event type - log but don't fail
    if event_type and event_type not in (
        "message_start",
        "message_delta",
        "content_block_start",
        "content_block_delta",
        "content_block_stop",
        "message_stop",
    ):
        logger.debug(f"Unknown event type: {event_type}")

    return None, current_tool_use
def generate(
    self,
    messages: list[Message],
    system_prompt: str | None = None,
    tools: list[ToolDefinition] | None = None,
    settings: LLMSettings | None = None,
) -> str:
    """Make one blocking API call and return the text of the response.

    Only blocks of type "text" contribute to the result; they are joined in
    order. API errors are logged and re-raised.
    """
    request = self._build_request_kwargs(
        messages, system_prompt, tools, settings or LLMSettings()
    )
    try:
        response = self.client.messages.create(**request)
        text_parts = [
            block.text for block in response.content if block.type == "text"
        ]
        return "".join(text_parts)
    except Exception as e:
        logger.error(f"Anthropic API error: {e}")
        raise
def stream(
    self,
    messages: list[Message],
    system_prompt: str | None = None,
    tools: list[ToolDefinition] | None = None,
    settings: LLMSettings | None = None,
) -> Iterator[StreamEvent]:
    """Stream a response, yielding the StreamEvents produced for each raw event.

    Raw events that translate to no StreamEvent are skipped. Exceptions are
    not propagated: they are logged and reported as a final "error" event.
    """
    request = self._build_request_kwargs(
        messages, system_prompt, tools, settings or LLMSettings()
    )
    try:
        with self.client.messages.stream(**request) as raw_events:
            # State of the tool call currently being streamed, if any.
            pending_tool: dict[str, Any] | None = None
            for raw in raw_events:
                parsed, pending_tool = self._handle_stream_event(raw, pending_tool)
                if parsed:
                    yield parsed
    except Exception as e:
        logger.error(f"Anthropic streaming error: {e}")
        yield StreamEvent(type="error", data=str(e))
async def agenerate(
    self,
    messages: list[Message],
    system_prompt: str | None = None,
    tools: list[ToolDefinition] | None = None,
    settings: LLMSettings | None = None,
) -> str:
    """Async counterpart of generate: one API call, text blocks joined in order.

    API errors are logged and re-raised.
    """
    request = self._build_request_kwargs(
        messages, system_prompt, tools, settings or LLMSettings()
    )
    try:
        response = await self.async_client.messages.create(**request)
        text_parts = [
            block.text for block in response.content if block.type == "text"
        ]
        return "".join(text_parts)
    except Exception as e:
        logger.error(f"Anthropic API error: {e}")
        raise
async def astream(
    self,
    messages: list[Message],
    system_prompt: str | None = None,
    tools: list[ToolDefinition] | None = None,
    settings: LLMSettings | None = None,
) -> AsyncIterator[StreamEvent]:
    """Async counterpart of stream.

    Yields the StreamEvents produced for each raw event and skips raw events
    that translate to nothing. Exceptions are logged and reported as a final
    "error" event instead of being raised.
    """
    request = self._build_request_kwargs(
        messages, system_prompt, tools, settings or LLMSettings()
    )
    try:
        async with self.async_client.messages.stream(**request) as raw_events:
            # State of the tool call currently being streamed, if any.
            pending_tool: dict[str, Any] | None = None
            async for raw in raw_events:
                parsed, pending_tool = self._handle_stream_event(raw, pending_tool)
                if parsed:
                    yield parsed
    except Exception as e:
        logger.error(f"Anthropic streaming error: {e}")
        yield StreamEvent(type="error", data=str(e))
def stream_with_tools(
    self,
    messages: list[Message],
    tools: dict[str, ToolDefinition],
    settings: LLMSettings | None = None,
    system_prompt: str | None = None,
    max_iterations: int = 10,
) -> Iterator[StreamEvent]:
    """Stream a response, executing requested tools and continuing the dialogue.

    Text, thinking, tool_use and tool_result events are yielded as they occur.
    Each tool_use is executed right away through `execute_tool` and its result
    is yielded as a "tool_result" event. Once the model's turn has been read
    completely, a single assistant message (thinking, text, every tool use)
    and a single user message (every tool result) are appended to `messages`,
    and the conversation continues with one recursive call.

    Args:
        messages: Conversation history; extended in place when tools are used.
        tools: Mapping of tool name to definition, used for execution.
        settings: Optional generation settings.
        system_prompt: Optional system prompt.
        max_iterations: Remaining model turns; nothing is yielded once it is <= 0.

    Raises:
        RuntimeError: When the underlying stream reports an "error" event.
    """
    if max_iterations <= 0:
        return

    response = TextContent(text="")
    thinking = ThinkingContent(thinking="", signature="")
    tool_uses: list[Any] = []
    tool_results: list[Any] = []

    for event in self.stream(
        messages=messages,
        system_prompt=system_prompt,
        tools=list(tools.values()),
        settings=settings,
    ):
        if event.type == "text":
            response.text += event.data or ""
            yield event
        elif event.type == "thinking" and event.signature:
            thinking.signature = event.signature
        elif event.type == "thinking":
            # Guard against thinking events that carry no data.
            thinking.thinking += event.data or ""
            yield event
        elif event.type == "tool_use":
            yield event
            # NOTE(review): server-side tool calls (is_server_call) are also
            # executed locally here, as before - confirm that is intended.
            tool_result = self.execute_tool(event.data, tools)
            yield StreamEvent(type="tool_result", data=tool_result.to_dict())
            tool_uses.append(
                ToolUseContent(
                    id=event.data["id"],
                    name=event.data["name"],
                    input=event.data["input"],
                )
            )
            tool_results.append(tool_result)
        elif event.type == "tool_result":
            yield event
        elif event.type == "error":
            logger.error(f"LLM error: {event.data}")
            raise RuntimeError(f"LLM error: {event.data}")

    if not tool_uses:
        return

    # Recursing in the middle of the stream (once per tool call) duplicated
    # the assistant text in the history and kept consuming the stale outer
    # stream afterwards. Build the follow-up turn once, after the stream ends.
    # Empty text/thinking blocks are left out rather than sent as empty blocks.
    assistant_parts: list[Any] = []
    if thinking.thinking:
        assistant_parts.append(thinking)
    if response.text:
        assistant_parts.append(response)
    assistant_parts.extend(tool_uses)

    messages.append(Message(role=MessageRole.ASSISTANT, content=assistant_parts))
    messages.append(Message(role=MessageRole.USER, content=tool_results))

    yield from self.stream_with_tools(
        messages, tools, settings, system_prompt, max_iterations - 1
    )

View File

@ -0,0 +1,561 @@
"""Base classes and types for LLM providers."""
import base64
import io
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, AsyncIterator, Iterator, Literal, Optional, Union
from PIL import Image
from memory.common import settings
from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
logger = logging.getLogger(__name__)
class MessageRole(str, Enum):
    """Message roles for chat history."""

    # Members also subclass ``str``, so each one compares equal to its plain
    # string value (e.g. MessageRole.USER == "user").
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"
@dataclass
class TextContent:
    """A plain-text block inside a message."""

    type: Literal["text"] = "text"
    text: str = ""

    def to_dict(self) -> dict[str, Any]:
        """Return the block as a ``{"type": "text", "text": ...}`` dictionary."""
        payload: dict[str, Any] = dict(type="text")
        payload["text"] = self.text
        return payload
@dataclass
class ImageContent:
    """An image block inside a message."""

    type: Literal["image"] = "image"
    image: Image.Image = None  # type: ignore
    detail: Optional[str] = None  # For OpenAI: "low", "high", "auto"

    def to_dict(self) -> dict[str, Any]:
        """Return a dictionary holding the raw image object.

        The image is NOT encoded here; that is left to provider-specific code.
        `detail` is not included in the result.
        """
        payload: dict[str, Any] = dict(type="image")
        payload["image"] = self.image
        return payload
@dataclass
class ToolUseContent:
    """A tool invocation requested by the assistant."""

    type: Literal["tool_use"] = "tool_use"
    id: str = ""
    name: str = ""
    input: dict[str, Any] = None  # type: ignore

    def to_dict(self) -> dict[str, Any]:
        """Return the block as a dictionary with type, id, name and input."""
        payload: dict[str, Any] = dict(type="tool_use", id=self.id)
        payload["name"] = self.name
        payload["input"] = self.input
        return payload
@dataclass
class ToolResultContent:
    """The outcome of executing a tool, addressed to a specific tool use."""

    type: Literal["tool_result"] = "tool_result"
    tool_use_id: str = ""
    content: str = ""
    is_error: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Return the block as a dictionary with type, tool_use_id, content and is_error."""
        payload: dict[str, Any] = dict(type="tool_result", tool_use_id=self.tool_use_id)
        payload["content"] = self.content
        payload["is_error"] = self.is_error
        return payload
@dataclass
class ThinkingContent:
    """Reasoning text produced by the assistant (extended thinking), with its signature."""

    type: Literal["thinking"] = "thinking"
    thinking: str = ""
    signature: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return the block as a dictionary with type, thinking and signature."""
        payload: dict[str, Any] = dict(type="thinking", thinking=self.thinking)
        payload["signature"] = self.signature
        return payload
MessageContent = Union[
TextContent, ImageContent, ToolUseContent, ToolResultContent, ThinkingContent
]
@dataclass
class Turn:
    """A turn in the conversation."""

    # Accumulated text output; BaseLLMProvider.run_with_tools stores None
    # instead of an empty string.
    response: str | None
    # Accumulated thinking text; None when there was none.
    thinking: str | None
    # Tool calls keyed by tool-use id; None when no tool was called.
    # NOTE(review): run_with_tools stores plain dicts with "name", "input" and
    # "output" keys here - confirm that matches the ToolResult type.
    tool_calls: dict[str, ToolResult] | None
@dataclass
class Message:
    """A message in the conversation history."""

    role: MessageRole
    # Either a plain string or a list of typed content blocks.
    content: Union[str, list[MessageContent]]

    def to_dict(self) -> dict[str, Any]:
        """Convert the message to a dictionary with "role" and "content".

        String content is kept as is; list content is converted by calling
        ``to_dict()`` on every item.
        """
        if isinstance(self.content, str):
            return {"role": self.role.value, "content": self.content}
        content_list = [item.to_dict() for item in self.content]
        return {"role": self.role.value, "content": content_list}

    @staticmethod
    def assistant(
        text: TextContent | None = None,
        thinking: ThinkingContent | None = None,
        tool_use: ToolUseContent | None = None,
    ) -> "Message":
        """Build an assistant message from the given blocks.

        Blocks are added in the order text, thinking, tool_use. A text or
        thinking block whose payload is empty is left out: dataclass instances
        are always truthy, so a plain ``if text:`` check let empty blocks
        through.
        """
        parts: list[MessageContent] = []
        if text is not None and text.text:
            parts.append(text)
        if thinking is not None and thinking.thinking:
            parts.append(thinking)
        if tool_use is not None:
            parts.append(tool_use)
        return Message(role=MessageRole.ASSISTANT, content=parts)

    @staticmethod
    def user(
        text: str | None = None, tool_result: ToolResultContent | None = None
    ) -> "Message":
        """Build a user message from optional text and an optional tool result.

        Empty or missing text is left out; the tool result, when given, follows
        the text.
        """
        parts: list[MessageContent] = []
        if text:
            parts.append(TextContent(text=text))
        if tool_result is not None:
            parts.append(tool_result)
        return Message(role=MessageRole.USER, content=parts)
@dataclass
class StreamEvent:
    """An event from the streaming response."""

    # Kind of event; determines how `data` should be interpreted.
    type: Literal["text", "tool_use", "tool_result", "thinking", "error", "done"]
    # Payload of the event; left at None for events that carry none.
    data: Any = None
    # Thinking signature, for "thinking" events that carry one.
    signature: str | None = None
@dataclass
class LLMSettings:
    """Settings for LLM API calls."""

    # Sampling temperature.
    temperature: float = 0.7
    # Upper bound on generated tokens.
    max_tokens: int = 2048
    # Don't set by default - some models don't allow both temp and top_p
    top_p: float | None = None
    # Optional sequences at which generation stops.
    stop_sequences: list[str] | None = None
    # NOTE(review): no code visible in this module reads `stream` - confirm
    # whether it is still used.
    stream: bool = False
class BaseLLMProvider(ABC):
"""Base class for LLM providers."""
def __init__(self, api_key: str, model: str):
"""
Initialize the LLM provider.
Args:
api_key: API key for the provider
model: Model identifier
"""
self.api_key = api_key
self.model = model
self._client: Any = None
@abstractmethod
def _initialize_client(self) -> Any:
    """Initialize the provider-specific client.

    Called by the `client` property the first time the client is needed; the
    returned object is cached on the instance.
    """
    pass
@property
def client(self) -> Any:
    """Provider client, built via `_initialize_client` on first access and cached."""
    if self._client is not None:
        return self._client
    self._client = self._initialize_client()
    return self._client
def execute_tool(
    self,
    tool_call: ToolCall,
    tool_handlers: dict[str, ToolDefinition],
) -> ToolResultContent:
    """
    Run the tool named in `tool_call` and wrap the outcome.

    Args:
        tool_call: Mapping read for its "name", "id" and "input" entries
        tool_handlers: Dict mapping tool names to the objects called with the input

    Returns:
        ToolResultContent with the tool's return value, or with an error
        message and is_error=True when the name is missing, the tool is
        unknown, or the tool raises.
    """
    name = tool_call.get("name")
    tool_use_id = tool_call.get("id")
    arguments = tool_call.get("input")

    def failure(message: str) -> ToolResultContent:
        return ToolResultContent(
            tool_use_id=tool_use_id,
            content=message,
            is_error=True,
        )

    if not name:
        return failure("Tool name missing")

    tool = tool_handlers.get(name)
    if not tool:
        return failure(f"Tool '{name}' not found")

    try:
        output = tool(arguments)
    except Exception as e:
        # Tool failures are reported back to the model rather than raised.
        logger.error(f"Tool '{name}' failed: {e}", exc_info=True)
        return failure(str(e))

    return ToolResultContent(
        tool_use_id=tool_use_id,
        content=output,
        is_error=False,
    )
def encode_image(self, image: Image.Image) -> str:
    """
    Serialize a PIL image as JPEG and base64-encode the bytes.

    Args:
        image: PIL Image to encode

    Returns:
        Base64 encoded string
    """
    # The image is saved as JPEG from RGB data, so any other mode
    # (RGBA, etc.) is converted first.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    jpeg_buffer = io.BytesIO()
    rgb_image.save(jpeg_buffer, format="JPEG")
    raw_bytes = jpeg_buffer.getvalue()
    return base64.b64encode(raw_bytes).decode("utf-8")
# The five converters below are per-content-type hooks selected by
# _convert_message_content. Each default implementation simply returns the
# item's own to_dict() output; providers override the ones whose wire format
# differs.
def _convert_text_content(self, content: TextContent) -> dict[str, Any]:
    """Convert TextContent to provider format. Override for custom format."""
    return content.to_dict()

def _convert_image_content(self, content: ImageContent) -> dict[str, Any]:
    """Convert ImageContent to provider format. Override for custom format.

    Note: the default returns the raw image object un-encoded.
    """
    return content.to_dict()

def _convert_tool_use_content(self, content: ToolUseContent) -> dict[str, Any]:
    """Convert ToolUseContent to provider format. Override for custom format."""
    return content.to_dict()

def _convert_tool_result_content(
    self, content: ToolResultContent
) -> dict[str, Any]:
    """Convert ToolResultContent to provider format. Override for custom format."""
    return content.to_dict()

def _convert_thinking_content(self, content: ThinkingContent) -> dict[str, Any]:
    """Convert ThinkingContent to provider format. Override for custom format."""
    return content.to_dict()
def _convert_message_content(self, content: MessageContent) -> dict[str, Any]:
    """
    Convert one content item to provider format.

    Picks the converter by the item's type (text, image, tool use, tool
    result, thinking - checked in that order) and raises ValueError for any
    other type.
    """
    converters = (
        (TextContent, self._convert_text_content),
        (ImageContent, self._convert_image_content),
        (ToolUseContent, self._convert_tool_use_content),
        (ToolResultContent, self._convert_tool_result_content),
        (ThinkingContent, self._convert_thinking_content),
    )
    for content_type, convert in converters:
        if isinstance(content, content_type):
            return convert(content)
    raise ValueError(f"Unknown content type: {type(content)}")
def _convert_message(self, message: Message) -> dict[str, Any]:
"""
Convert a Message to provider format.
Can be overridden for provider-specific handling (e.g., filtering system messages).
"""
return message.to_dict()
def _should_include_message(self, message: Message) -> bool:
    """
    Determine if a message should be included in the request.

    The base implementation keeps every message. Override to filter messages
    (e.g., Anthropic filters SYSTEM messages).

    Args:
        message: Message to check

    Returns:
        True if message should be included
    """
    return True
def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
"""
Convert a list of messages to provider format.
Uses _should_include_message for filtering and _convert_message for conversion.
"""
return [
self._convert_message(msg)
for msg in messages
if self._should_include_message(msg)
]
def _convert_tool(self, tool: ToolDefinition) -> dict[str, Any]:
"""
Convert a single ToolDefinition to provider format.
Default format matches Anthropic. Override for other providers (e.g., OpenAI uses functions).
"""
return {
"name": tool.name,
"description": tool.description,
"input_schema": tool.input_schema,
}
def _convert_tools(
self, tools: list[ToolDefinition] | None
) -> Optional[list[dict[str, Any]]]:
"""Convert tool definitions to provider format."""
if not tools:
return None
return [self._convert_tool(tool) for tool in tools]
@abstractmethod
def generate(
self,
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""
Generate a non-streaming response.
Args:
messages: Conversation history
system_prompt: Optional system prompt
tools: Optional list of tools the LLM can use
settings: Optional settings for the generation
Returns:
Generated text response
"""
pass
@abstractmethod
def stream(
self,
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
settings: LLMSettings | None = None,
) -> Iterator[StreamEvent]:
"""
Generate a streaming response.
Args:
messages: Conversation history
system_prompt: Optional system prompt
tools: Optional list of tools the LLM can use
settings: Optional settings for the generation
Yields:
StreamEvent objects containing text chunks, tool uses, or errors
"""
pass
@abstractmethod
async def agenerate(
self,
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""
Generate a non-streaming response asynchronously.
Args:
messages: Conversation history
system_prompt: Optional system prompt
tools: Optional list of tools the LLM can use
settings: Optional settings for the generation
Returns:
Generated text response
"""
pass
@abstractmethod
async def astream(
    self,
    messages: list[Message],
    system_prompt: str | None = None,
    tools: list[ToolDefinition] | None = None,
    settings: LLMSettings | None = None,
) -> AsyncIterator[StreamEvent]:
    """
    Asynchronous counterpart of stream(): produce a reply as events.

    Abstract: every concrete provider supplies the implementation.

    Args:
        messages: The conversation so far.
        system_prompt: System prompt to apply, if any.
        tools: Tool definitions the LLM may use, if any.
        settings: Generation settings, if any.

    Yields:
        StreamEvent objects carrying text chunks, tool uses, or errors.
    """
    ...
@abstractmethod
def stream_with_tools(
    self,
    messages: list[Message],
    tools: dict[str, ToolDefinition],
    settings: LLMSettings | None = None,
    system_prompt: str | None = None,
    max_iterations: int = 10,
) -> Iterator[StreamEvent]:
    """
    Stream a tool-enabled conversation as a sequence of events.

    Abstract: every concrete provider supplies the implementation.
    run_with_tools() consumes this stream and reacts to events whose type
    is "thinking", "text", "tool_use" or "tool_result".

    Args:
        messages: The conversation so far.
        tools: Tool definitions keyed by string (presumably the tool
            name -- confirm against concrete implementations).
        settings: Generation settings, if any.
        system_prompt: System prompt to apply, if any.
        max_iterations: NOTE(review): presumably caps the number of
            model/tool round-trips -- confirm against implementations.

    Yields:
        StreamEvent objects.
    """
    ...
def run_with_tools(
    self,
    messages: list[Message],
    tools: dict[str, ToolDefinition],
    settings: LLMSettings | None = None,
    system_prompt: str | None = None,
    max_iterations: int = 10,
) -> Turn:
    """
    Drain stream_with_tools() and fold its events into a single Turn.

    "thinking" and "text" payloads are concatenated in arrival order. A
    "tool_use" event registers the call under its id with an empty output.
    A "tool_result" event replaces the entry for its tool_use_id, falling
    back to the previously recorded name/input when the result event does
    not carry them. Accumulators that stay empty are reported as None.
    """
    thought_text = ""
    reply_text = ""
    calls = {}

    events = self.stream_with_tools(
        messages=messages,
        tools=tools,
        settings=settings,
        system_prompt=system_prompt,
        max_iterations=max_iterations,
    )
    for event in events:
        kind = event.type
        if kind == "thinking":
            thought_text += event.data
        elif kind == "tool_use":
            started = {
                "name": event.data["name"],
                "input": event.data["input"],
                "output": "",
            }
            calls[event.data["id"]] = started
        elif kind == "text":
            reply_text += event.data
        elif kind == "tool_result":
            call_id = event.data["tool_use_id"]
            previous = calls.get(call_id) or {}
            finished = {
                "name": event.data.get("name") or previous.get("name"),
                "input": event.data.get("input") or previous.get("input"),
                "output": event.data.get("content"),
            }
            calls[call_id] = finished

    return Turn(
        thinking=thought_text or None,
        response=reply_text or None,
        tool_calls=calls or None,
    )
def create_provider(
    model: str | None = None,
    api_key: str | None = None,
    enable_thinking: bool = False,
) -> BaseLLMProvider:
    """
    Create an LLM provider from a "<provider>/<model>" identifier.

    Args:
        model: Model identifier prefixed with its provider, e.g.
            "anthropic/claude-3-opus-20240229". If not provided, uses
            SUMMARIZER_MODEL from settings.
        api_key: Optional API key. If not provided, uses keys from settings.
        enable_thinking: Enable extended thinking for supported models
            (Claude Opus 4+, Sonnet 4+, Sonnet 3.7)

    Returns:
        An initialized LLM provider

    Raises:
        ValueError: If the identifier has no provider prefix, the provider
            is not supported, or no API key is available.
    """
    # Use default model from settings if not provided
    if model is None:
        model = settings.SUMMARIZER_MODEL

    # partition() never fails to unpack, so a missing "/" can be reported
    # with a clear message instead of "not enough values to unpack".
    provider, separator, model_name = model.partition("/")
    if not (separator and provider and model_name):
        raise ValueError(
            f"Invalid model identifier {model!r}: expected '<provider>/<model>' "
            "(e.g. 'anthropic/claude-3-opus-20240229')."
        )

    if provider == "anthropic":
        if api_key is None:
            api_key = settings.ANTHROPIC_API_KEY
        if not api_key:
            raise ValueError(
                "ANTHROPIC_API_KEY not found in settings. "
                "Please set it in your .env file."
            )
        # Imported lazily, only when an Anthropic provider is requested.
        from memory.common.llms.anthropic_provider import AnthropicProvider

        return AnthropicProvider(
            api_key=api_key, model=model_name, enable_thinking=enable_thinking
        )

    # NOTE: OpenAI is not wired up here yet; add an "openai" branch when ready.
    raise ValueError(
        f"Unknown provider {provider!r} for model: {model_name}. "
        f"Supported providers: anthropic (anthropic/claude-*)"
    )

View File

@ -0,0 +1,388 @@
"""OpenAI LLM provider implementation."""
import logging
from typing import Any, AsyncIterator, Iterator, Optional
import openai
from memory.common.llms.base import (
BaseLLMProvider,
ImageContent,
LLMSettings,
Message,
MessageContent,
MessageRole,
StreamEvent,
TextContent,
ThinkingContent,
ToolDefinition,
ToolResultContent,
ToolUseContent,
)
logger = logging.getLogger(__name__)
class OpenAIProvider(BaseLLMProvider):
"""OpenAI LLM provider with streaming and tool support."""
def _initialize_client(self) -> openai.OpenAI:
    """Build the synchronous OpenAI client from this provider's API key."""
    client = openai.OpenAI(api_key=self.api_key)
    return client
def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
    """
    Convert our Message format to OpenAI chat-completions format.

    Plain-string messages map one-to-one. For multi-part messages:
      * text and image parts become entries of the message's "content" list;
      * thinking parts are embedded as "[Thinking: ...]" text parts;
      * tool-use parts become the message's "tool_calls";
      * tool-result parts are emitted as separate "tool" role messages,
        placed before the message they were part of.

    A multi-part message that yields neither content nor tool calls (e.g.
    one holding only tool results) produces no message of its own.

    Args:
        messages: List of messages in our format

    Returns:
        List of messages in OpenAI format
    """
    import json  # local import: only needed to serialise tool-call arguments

    openai_messages: list[dict[str, Any]] = []
    for msg in messages:
        if isinstance(msg.content, str):
            openai_messages.append({"role": msg.role.value, "content": msg.content})
            continue

        content_parts: list[dict[str, Any]] = []
        tool_calls: list[dict[str, Any]] = []
        for item in msg.content:
            if isinstance(item, TextContent):
                content_parts.append({"type": "text", "text": item.text})
            elif isinstance(item, ImageContent):
                encoded_image = self.encode_image(item.image)
                # NOTE(review): the data URL always claims JPEG -- confirm
                # that encode_image() really produces JPEG bytes.
                image_part: dict[str, Any] = {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
                }
                if item.detail:
                    image_part["image_url"]["detail"] = item.detail
                content_parts.append(image_part)
            elif isinstance(item, ToolUseContent):
                # OpenAI expects function arguments as a JSON-encoded string;
                # str() would produce a Python repr (single quotes, None),
                # which is not valid JSON.
                if isinstance(item.input, str):
                    arguments = item.input
                else:
                    arguments = json.dumps(
                        item.input if item.input is not None else {}
                    )
                tool_calls.append(
                    {
                        "id": item.id,
                        "type": "function",
                        "function": {"name": item.name, "arguments": arguments},
                    }
                )
            elif isinstance(item, ToolResultContent):
                # OpenAI handles tool results as separate "tool" role messages
                openai_messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": item.tool_use_id,
                        "content": item.content,
                    }
                )
            elif isinstance(item, ThinkingContent):
                # No native thinking parts here; carry it as marked-up text.
                content_parts.append(
                    {"type": "text", "text": f"[Thinking: {item.thinking}]"}
                )

        if not content_parts and not tool_calls:
            # Nothing left for this message (e.g. it held only tool results);
            # a bare {"role": ...} entry without content would be invalid.
            continue

        message_dict: dict[str, Any] = {"role": msg.role.value}
        if content_parts:
            message_dict["content"] = content_parts
        if tool_calls:
            message_dict["tool_calls"] = tool_calls
        openai_messages.append(message_dict)
    return openai_messages
def _convert_tools(
    self, tools: Optional[list[ToolDefinition]]
) -> Optional[list[dict[str, Any]]]:
    """
    Wrap each tool definition in OpenAI's "function" tool envelope.

    Args:
        tools: List of tool definitions

    Returns:
        List of tools in OpenAI format, or None when no tools were given
    """
    if not tools:
        return None

    converted: list[dict[str, Any]] = []
    for tool in tools:
        function_spec = {
            "name": tool.name,
            "description": tool.description,
            "parameters": tool.input_schema,
        }
        converted.append({"type": "function", "function": function_spec})
    return converted
def generate(
    self,
    messages: list[Message],
    system_prompt: Optional[str] = None,
    tools: Optional[list[ToolDefinition]] = None,
    settings: Optional[LLMSettings] = None,
) -> str:
    """
    Run one blocking chat completion and return its text.

    The system prompt, when given, is sent as the first message. Returns
    an empty string when the reply carries no text content. API errors are
    logged and re-raised.
    """
    settings = settings or LLMSettings()

    conversation = self._convert_messages(messages)
    if system_prompt:
        conversation.insert(0, {"role": "system", "content": system_prompt})

    request: dict[str, Any] = dict(
        model=self.model,
        messages=conversation,
        temperature=settings.temperature,
        max_tokens=settings.max_tokens,
        top_p=settings.top_p,
    )
    if settings.stop_sequences:
        request["stop"] = settings.stop_sequences
    if tools:
        request["tools"] = self._convert_tools(tools)
        request["tool_choice"] = "auto"

    try:
        completion = self.client.chat.completions.create(**request)
        return completion.choices[0].message.content or ""
    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        raise
def stream(
    self,
    messages: list[Message],
    system_prompt: Optional[str] = None,
    tools: Optional[list[ToolDefinition]] = None,
    settings: Optional[LLMSettings] = None,
) -> Iterator[StreamEvent]:
    """
    Stream a chat completion as StreamEvents.

    Emits "text" events for content deltas, a "tool_use" event for each
    tool call once it is complete (data holds "id", "name" and the raw
    accumulated "arguments" string), and "done" when a chunk reports a
    finish reason. Any exception is logged and reported as a single
    "error" event instead of being raised.
    """
    settings = settings or LLMSettings()

    conversation = self._convert_messages(messages)
    if system_prompt:
        conversation.insert(0, {"role": "system", "content": system_prompt})

    request: dict[str, Any] = dict(
        model=self.model,
        messages=conversation,
        temperature=settings.temperature,
        max_tokens=settings.max_tokens,
        top_p=settings.top_p,
        stream=True,
    )
    if settings.stop_sequences:
        request["stop"] = settings.stop_sequences
    if tools:
        request["tools"] = self._convert_tools(tools)
        request["tool_choice"] = "auto"

    try:
        chunks = self.client.chat.completions.create(**request)
        # Tool call currently being assembled from streamed fragments.
        pending_call: Optional[dict[str, Any]] = None
        for chunk in chunks:
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta

            if delta.content:
                yield StreamEvent(type="text", data=delta.content)

            for fragment in delta.tool_calls or ():
                if fragment.id:
                    # An id marks the start of a new call: flush the old one.
                    if pending_call:
                        yield StreamEvent(type="tool_use", data=pending_call)
                    pending_call = {
                        "id": fragment.id,
                        "name": fragment.function.name or "",
                        "arguments": fragment.function.arguments or "",
                    }
                elif pending_call and fragment.function.arguments:
                    pending_call["arguments"] += fragment.function.arguments

            if chunk.choices[0].finish_reason:
                if pending_call:
                    yield StreamEvent(type="tool_use", data=pending_call)
                    pending_call = None
                yield StreamEvent(type="done")
    except Exception as e:
        logger.error(f"OpenAI streaming error: {e}")
        yield StreamEvent(type="error", data=str(e))
async def agenerate(
    self,
    messages: list[Message],
    system_prompt: Optional[str] = None,
    tools: Optional[list[ToolDefinition]] = None,
    settings: Optional[LLMSettings] = None,
) -> str:
    """
    Generate a non-streaming response asynchronously.

    A dedicated AsyncOpenAI client is created for the call and closed when
    the call finishes, so its HTTP connections are not left open.

    Args:
        messages: Conversation history
        system_prompt: Optional system prompt, sent as the first message
        tools: Optional list of tools the LLM can use
        settings: Optional settings for the generation

    Returns:
        Generated text response ("" when the reply has no text content)

    Raises:
        Exception: Any API/client error is logged and re-raised.
    """
    settings = settings or LLMSettings()
    openai_messages = self._convert_messages(messages)
    # Add system prompt as first message if provided
    if system_prompt:
        openai_messages.insert(0, {"role": "system", "content": system_prompt})

    kwargs: dict[str, Any] = {
        "model": self.model,
        "messages": openai_messages,
        "temperature": settings.temperature,
        "max_tokens": settings.max_tokens,
        "top_p": settings.top_p,
    }
    if settings.stop_sequences:
        kwargs["stop"] = settings.stop_sequences
    if tools:
        kwargs["tools"] = self._convert_tools(tools)
        kwargs["tool_choice"] = "auto"

    try:
        # The context manager closes the per-call client on exit.
        async with openai.AsyncOpenAI(api_key=self.api_key) as async_client:
            response = await async_client.chat.completions.create(**kwargs)
        return response.choices[0].message.content or ""
    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        raise
async def astream(
    self,
    messages: list[Message],
    system_prompt: Optional[str] = None,
    tools: Optional[list[ToolDefinition]] = None,
    settings: Optional[LLMSettings] = None,
) -> AsyncIterator[StreamEvent]:
    """
    Generate a streaming response asynchronously.

    A dedicated AsyncOpenAI client is created for the call and closed once
    the stream has been consumed (or has failed), so its HTTP connections
    are not left open.

    Args:
        messages: Conversation history
        system_prompt: Optional system prompt, sent as the first message
        tools: Optional list of tools the LLM can use
        settings: Optional settings for the generation

    Yields:
        StreamEvent objects: "text" for content deltas, "tool_use" for each
        completed tool call (data holds "id", "name" and the accumulated
        "arguments" string), "done" when a finish reason is reported, and
        a single "error" event if anything raises.
    """
    settings = settings or LLMSettings()
    openai_messages = self._convert_messages(messages)
    # Add system prompt as first message if provided
    if system_prompt:
        openai_messages.insert(0, {"role": "system", "content": system_prompt})

    kwargs: dict[str, Any] = {
        "model": self.model,
        "messages": openai_messages,
        "temperature": settings.temperature,
        "max_tokens": settings.max_tokens,
        "top_p": settings.top_p,
        "stream": True,
    }
    if settings.stop_sequences:
        kwargs["stop"] = settings.stop_sequences
    if tools:
        kwargs["tools"] = self._convert_tools(tools)
        kwargs["tool_choice"] = "auto"

    try:
        # The context manager closes the per-call client on exit.
        async with openai.AsyncOpenAI(api_key=self.api_key) as async_client:
            stream = await async_client.chat.completions.create(**kwargs)
            # Tool call currently being assembled from streamed fragments.
            current_tool_call: Optional[dict[str, Any]] = None
            async for chunk in stream:
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta

                # Handle text content
                if delta.content:
                    yield StreamEvent(type="text", data=delta.content)

                # Handle tool calls
                if delta.tool_calls:
                    for tool_call in delta.tool_calls:
                        if tool_call.id:
                            # New tool call starting: flush the previous one
                            if current_tool_call:
                                yield StreamEvent(
                                    type="tool_use", data=current_tool_call
                                )
                            current_tool_call = {
                                "id": tool_call.id,
                                "name": tool_call.function.name or "",
                                "arguments": tool_call.function.arguments or "",
                            }
                        elif current_tool_call and tool_call.function.arguments:
                            # Continue building the current tool call
                            current_tool_call["arguments"] += (
                                tool_call.function.arguments
                            )

                # Check if stream is finished
                if chunk.choices[0].finish_reason:
                    if current_tool_call:
                        yield StreamEvent(type="tool_use", data=current_tool_call)
                        current_tool_call = None
                    yield StreamEvent(type="done")
    except Exception as e:
        logger.error(f"OpenAI streaming error: {e}")
        yield StreamEvent(type="error", data=str(e))

View File

@ -0,0 +1,36 @@
from dataclasses import dataclass
from typing import Any, Callable, TypedDict
# Payload handed to a tool: a raw string, a dict of arguments, or None.
ToolInput = str | dict[str, Any] | None
# Signature of a tool implementation: takes a ToolInput, returns a string.
ToolHandler = Callable[[ToolInput], str]
class ToolCall(TypedDict):
    """A request to invoke a tool."""

    # Name of the tool being called.
    name: str
    # Identifier of this particular call.
    id: str
    # Payload passed to the tool.
    input: ToolInput
class ToolResult(TypedDict):
    """The outcome of a tool call, alongside the call's own details."""

    # Identifier of the call this result belongs to.
    id: str
    # Name of the tool that was called.
    name: str
    # Payload the tool was called with.
    input: ToolInput
    # What the tool returned.
    output: str
@dataclass
class ToolDefinition:
    """
    A tool the LLM can call: its metadata plus the handler implementing it.

    Instances are callable; calling one forwards the input to `function`.
    """

    name: str
    description: str
    input_schema: dict[str, Any]  # JSON Schema describing the tool's parameters
    function: ToolHandler

    def __call__(self, input: ToolInput) -> str:
        """Invoke the underlying handler with `input` and return its result."""
        handler = self.function
        return handler(input)

View File

@ -0,0 +1,42 @@
"""Ping tool for testing LLM tool integration."""
from memory.common.llms.tools import ToolDefinition, ToolInput
def handle_ping_call(message: ToolInput = None) -> str:
    """
    Handle a ping tool call.

    Args:
        message: Optional message to include in the response. May be the
            message string itself or the tool's input object (a dict with
            an optional "message" key), as declared by the ping schema.

    Returns:
        "pong: <message>" when a non-empty message was supplied,
        otherwise "pong".
    """
    # The ping schema declares an object with an optional "message"
    # property, so unwrap it rather than echoing the dict's repr.
    if isinstance(message, dict):
        message = message.get("message")
    if message:
        return f"pong: {message}"
    return "pong"
def get_ping_tool() -> ToolDefinition:
    """
    Build the "ping" tool definition used to test tool calling.

    The tool has no required parameters; it accepts an optional string
    "message" and is backed by handle_ping_call.
    """
    message_property = {
        "type": "string",
        "description": "Optional message to echo back",
    }
    schema = {
        "type": "object",
        "properties": {"message": message_property},
        "required": [],
    }
    return ToolDefinition(
        name="ping",
        description="A simple test tool that returns 'pong'. Use this to verify tool calling is working.",
        input_schema=schema,
        function=handle_ping_call,
    )

View File

@ -0,0 +1,9 @@
def process_message(
    msg: str,
    history: list[str],
    model: str | None = None,
    system_prompt: str | None = None,
    allowed_tools: list[str] | None = None,
    disallowed_tools: list[str] | None = None,
) -> str:
    """
    Placeholder message processor.

    Every argument is currently ignored and a fixed string is returned.
    """
    placeholder_reply = "asd"
    return placeholder_reply

View File

@ -172,6 +172,7 @@ DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat")
# Notifications need both the DISCORD_NOTIFICATIONS_ENABLED flag and a bot
# token: without a token the result is False whatever the flag says.
DISCORD_NOTIFICATIONS_ENABLED = bool(
    boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
)
# Switch for processing incoming Discord messages (the second argument is
# presumably the default when the variable is unset -- confirm in boolean_env).
DISCORD_PROCESS_MESSAGES = boolean_env("DISCORD_PROCESS_MESSAGES", True)
# Discord collector settings
DISCORD_COLLECTOR_ENABLED = boolean_env("DISCORD_COLLECTOR_ENABLED", True)

View File

@ -105,7 +105,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
prompt = llms.truncate(prompt, MAX_TOKENS - 20)
try:
response = llms.call(prompt, settings.SUMMARIZER_MODEL)
response = llms.summarize(prompt, settings.SUMMARIZER_MODEL)
result = parse_response(response)
summary = result.get("summary", "")

View File

@ -160,7 +160,7 @@ def should_track_message(
return False
if channel.channel_type in ("dm", "group_dm"):
return bool(user.allow_dm_tracking)
return bool(user.track_messages)
# Default: track the message
return True

View File

@ -16,7 +16,11 @@ from memory.workers.tasks.content_processing import (
create_task_result,
process_content_item,
)
from memory.common.celery_app import ADD_DISCORD_MESSAGE, EDIT_DISCORD_MESSAGE
from memory.common.celery_app import (
ADD_DISCORD_MESSAGE,
EDIT_DISCORD_MESSAGE,
PROCESS_DISCORD_MESSAGE,
)
from memory.common import settings
from sqlalchemy.orm import Session, scoped_session
@ -40,6 +44,41 @@ def get_prev(
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
def should_process(message: DiscordMessage) -> bool:
    """
    Decide whether a stored Discord message should be processed.

    Processing requires both DISCORD_PROCESS_MESSAGES and
    DISCORD_NOTIFICATIONS_ENABLED to be on, and none of the message's
    server, channel or Discord user to be flagged with ignore_messages.
    """
    globally_enabled = (
        settings.DISCORD_PROCESS_MESSAGES and settings.DISCORD_NOTIFICATIONS_ENABLED
    )
    # The generator keeps the checks lazy: message relations are only
    # touched when processing is globally enabled, in server/channel/user order.
    return globally_enabled and not any(
        source and source.ignore_messages
        for source in (
            getattr(message, attribute)
            for attribute in ("server", "channel", "discord_user")
        )
    )
@app.task(name=PROCESS_DISCORD_MESSAGE)
@safe_task_execution
def process_discord_message(message_id: int) -> dict[str, Any]:
    """
    Task that loads a stored Discord message by id and processes it.

    For now the message is only looked up and logged; no further
    processing happens here yet.

    Args:
        message_id: Id of the DiscordMessage to load.

    Returns:
        {"status": "processed", "message_id": ...} when the message exists,
        otherwise {"status": "error", "error": "Message not found",
        "message_id": ...}.
    """
    logger.info(f"Processing Discord message {message_id}")
    with make_session() as session:
        discord_message = session.query(DiscordMessage).get(message_id)
        if not discord_message:
            logger.info(f"Discord message not found: {message_id}")
            return {
                "status": "error",
                "error": "Message not found",
                "message_id": message_id,
            }
        # Use the logger instead of print(): worker stdout is not a reliable
        # log sink. Debug level keeps message content out of normal logs.
        logger.debug(
            "Processing message %s: %s", discord_message.id, discord_message.content
        )
        return {
            "status": "processed",
            "message_id": message_id,
        }
@app.task(name=ADD_DISCORD_MESSAGE)
@safe_task_execution
def add_discord_message(
@ -87,11 +126,8 @@ def add_discord_message(
discord_message.messages_before = get_prev(session, channel_id, sent_at_dt)
result = process_content_item(discord_message, session)
logger.info(
f"Discord message ID after process_content_item: {discord_message.id}"
)
logger.info(f"Process result: {result}")
if should_process(discord_message):
process_discord_message.delay(discord_message.id)
return result

View File

@ -88,11 +88,10 @@ def execute_scheduled_call(self, scheduled_call_id: str):
# Make the LLM call
if scheduled_call.model:
response = llms.call(
response = llms.summarize(
prompt=cast(str, scheduled_call.message),
model=cast(str, scheduled_call.model),
system_prompt=cast(str, scheduled_call.system_prompt)
or llms.SYSTEM_PROMPT,
system_prompt=cast(str, scheduled_call.system_prompt),
)
else:
response = cast(str, scheduled_call.message)

View File

@ -273,6 +273,27 @@ def mock_anthropic_client():
with patch.object(anthropic, "Anthropic", autospec=True) as mock_client:
client = mock_client()
client.messages = Mock()
# Mock stream as a context manager
mock_stream = Mock()
mock_stream.__enter__ = Mock(
return_value=Mock(
__iter__=lambda self: iter(
[
Mock(
type="content_block_delta",
delta=Mock(
type="text_delta",
text="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
),
)
]
)
)
)
mock_stream.__exit__ = Mock(return_value=False)
client.messages.stream = Mock(return_value=mock_stream)
client.messages.create = Mock(
return_value=Mock(
content=[

View File

@ -2,318 +2,250 @@ import pytest
from unittest.mock import Mock, patch
import requests
from memory.common import discord, settings
from memory.common import discord
@pytest.fixture
def mock_session_request():
with patch("requests.Session.request") as mock:
yield mock
def mock_api_url():
"""Mock the API URL to avoid using actual settings"""
with patch(
"memory.common.discord.get_api_url", return_value="http://localhost:8000"
):
yield
@pytest.fixture
def mock_get_channels_response():
return [
{"name": "memory-errors", "id": "error_channel_id"},
{"name": "memory-activity", "id": "activity_channel_id"},
{"name": "memory-discoveries", "id": "discovery_channel_id"},
{"name": "memory-chat", "id": "chat_channel_id"},
]
@patch("memory.common.settings.DISCORD_COLLECTOR_SERVER_URL", "testhost")
@patch("memory.common.settings.DISCORD_COLLECTOR_PORT", 9999)
def test_get_api_url():
"""Test API URL construction"""
assert discord.get_api_url() == "http://testhost:9999"
def test_discord_server_init(mock_session_request, mock_get_channels_response):
# Mock the channels API call
@patch("requests.post")
def test_send_dm_success(mock_post, mock_api_url):
"""Test successful DM sending"""
mock_response = Mock()
mock_response.json.return_value = mock_get_channels_response
mock_response.json.return_value = {"success": True}
mock_response.raise_for_status.return_value = None
mock_session_request.return_value = mock_response
mock_post.return_value = mock_response
server = discord.DiscordServer("server123", "Test Server")
result = discord.send_dm("user123", "Hello!")
assert server.server_id == "server123"
assert server.server_name == "Test Server"
assert hasattr(server, "channels")
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "memory-errors")
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "memory-activity")
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "memory-discoveries")
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "memory-chat")
def test_setup_channels_existing(mock_session_request, mock_get_channels_response):
# Mock the channels API call
mock_response = Mock()
mock_response.json.return_value = mock_get_channels_response
mock_response.raise_for_status.return_value = None
mock_session_request.return_value = mock_response
server = discord.DiscordServer("server123", "Test Server")
assert server.channels[discord.ERROR_CHANNEL] == "error_channel_id"
assert server.channels[discord.ACTIVITY_CHANNEL] == "activity_channel_id"
assert server.channels[discord.DISCOVERY_CHANNEL] == "discovery_channel_id"
assert server.channels[discord.CHAT_CHANNEL] == "chat_channel_id"
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "new-error-channel")
def test_setup_channels_create_missing(mock_session_request):
# Mock get channels (empty) and create channel calls
get_response = Mock()
get_response.json.return_value = []
get_response.raise_for_status.return_value = None
create_response = Mock()
create_response.json.return_value = {"id": "new_channel_id"}
create_response.raise_for_status.return_value = None
mock_session_request.side_effect = [
get_response,
create_response,
create_response,
create_response,
create_response,
]
server = discord.DiscordServer("server123", "Test Server")
assert server.channels[discord.ERROR_CHANNEL] == "new_channel_id"
def test_channel_properties():
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.channels = {
discord.ERROR_CHANNEL: "error_id",
discord.ACTIVITY_CHANNEL: "activity_id",
discord.DISCOVERY_CHANNEL: "discovery_id",
discord.CHAT_CHANNEL: "chat_id",
}
assert server.error_channel == "error_id"
assert server.activity_channel == "activity_id"
assert server.discovery_channel == "discovery_id"
assert server.chat_channel == "chat_id"
def test_channel_id_exists():
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.channels = {"test-channel": "channel123"}
assert server.channel_id("test-channel") == "channel123"
def test_channel_id_not_found():
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.channels = {}
with pytest.raises(ValueError, match="Channel nonexistent not found"):
server.channel_id("nonexistent")
def test_send_message(mock_session_request):
mock_response = Mock()
mock_response.raise_for_status.return_value = None
mock_session_request.return_value = mock_response
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.send_message("channel123", "Hello World")
mock_session_request.assert_called_with(
"POST",
"https://discord.com/api/v10/channels/channel123/messages",
data=None,
json={"content": "Hello World"},
headers={
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
"Content-Type": "application/json",
},
assert result is True
mock_post.assert_called_once_with(
"http://localhost:8000/send_dm",
json={"user_identifier": "user123", "message": "Hello!"},
timeout=10,
)
def test_create_channel(mock_session_request):
@patch("requests.post")
def test_send_dm_api_failure(mock_post, mock_api_url):
"""Test DM sending when API returns failure"""
mock_response = Mock()
mock_response.json.return_value = {"id": "new_channel_id"}
mock_response.json.return_value = {"success": False}
mock_response.raise_for_status.return_value = None
mock_session_request.return_value = mock_response
mock_post.return_value = mock_response
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.server_id = "server123"
result = discord.send_dm("user123", "Hello!")
channel_id = server.create_channel("new-channel")
assert channel_id == "new_channel_id"
mock_session_request.assert_called_with(
"POST",
"https://discord.com/api/v10/guilds/server123/channels",
data=None,
json={"name": "new-channel", "type": 0},
headers={
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
"Content-Type": "application/json",
},
)
assert result is False
def test_create_channel_custom_type(mock_session_request):
@patch("requests.post")
def test_send_dm_request_exception(mock_post, mock_api_url):
"""Test DM sending when request raises exception"""
mock_post.side_effect = requests.RequestException("Network error")
result = discord.send_dm("user123", "Hello!")
assert result is False
@patch("requests.post")
def test_send_dm_http_error(mock_post, mock_api_url):
"""Test DM sending when HTTP error occurs"""
mock_response = Mock()
mock_response.json.return_value = {"id": "voice_channel_id"}
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!")
assert result is False
@patch("requests.post")
def test_broadcast_message_success(mock_post, mock_api_url):
"""Test successful channel message broadcast"""
mock_response = Mock()
mock_response.json.return_value = {"success": True}
mock_response.raise_for_status.return_value = None
mock_session_request.return_value = mock_response
mock_post.return_value = mock_response
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.server_id = "server123"
result = discord.broadcast_message("general", "Announcement!")
channel_id = server.create_channel("voice-channel", channel_type=2)
assert channel_id == "voice_channel_id"
mock_session_request.assert_called_with(
"POST",
"https://discord.com/api/v10/guilds/server123/channels",
data=None,
json={"name": "voice-channel", "type": 2},
headers={
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
"Content-Type": "application/json",
},
assert result is True
mock_post.assert_called_once_with(
"http://localhost:8000/send_channel",
json={"channel_name": "general", "message": "Announcement!"},
timeout=10,
)
def test_str_representation():
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.server_id = "server123"
server.server_name = "Test Server"
@patch("requests.post")
def test_broadcast_message_failure(mock_post, mock_api_url):
"""Test channel message broadcast failure"""
mock_response = Mock()
mock_response.json.return_value = {"success": False}
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
assert str(server) == "DiscordServer(server_id=server123, server_name=Test Server)"
result = discord.broadcast_message("general", "Announcement!")
assert result is False
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token_123")
def test_request_adds_headers(mock_session_request):
server = discord.DiscordServer.__new__(discord.DiscordServer)
@patch("requests.post")
def test_broadcast_message_exception(mock_post, mock_api_url):
"""Test channel message broadcast with exception"""
mock_post.side_effect = requests.Timeout("Request timeout")
server.request("GET", "https://example.com", headers={"Custom": "header"})
result = discord.broadcast_message("general", "Announcement!")
expected_headers = {
"Custom": "header",
"Authorization": "Bot test_token_123",
"Content-Type": "application/json",
}
mock_session_request.assert_called_once_with(
"GET", "https://example.com", headers=expected_headers
)
assert result is False
def test_channels_url():
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.server_id = "server123"
assert (
server.channels_url == "https://discord.com/api/v10/guilds/server123/channels"
)
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
@patch("requests.get")
def test_get_bot_servers_success(mock_get):
def test_is_collector_healthy_true(mock_get, mock_api_url):
"""Test health check when collector is healthy"""
mock_response = Mock()
mock_response.json.return_value = [
{"id": "server1", "name": "Server 1"},
{"id": "server2", "name": "Server 2"},
]
mock_response.json.return_value = {"status": "healthy"}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
servers = discord.get_bot_servers()
result = discord.is_collector_healthy()
assert len(servers) == 2
assert servers[0] == {"id": "server1", "name": "Server 1"}
mock_get.assert_called_once_with(
"https://discord.com/api/v10/users/@me/guilds",
headers={"Authorization": "Bot test_token"},
)
assert result is True
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
def test_get_bot_servers_no_token():
assert discord.get_bot_servers() == []
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
@patch("requests.get")
def test_get_bot_servers_exception(mock_get):
mock_get.side_effect = requests.RequestException("API Error")
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
"""Test health check when collector returns unhealthy status"""
mock_response = Mock()
mock_response.json.return_value = {"status": "unhealthy"}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
servers = discord.get_bot_servers()
result = discord.is_collector_healthy()
assert servers == []
assert result is False
@patch("memory.common.discord.get_bot_servers")
@patch("memory.common.discord.DiscordServer")
def test_load_servers(mock_discord_server_class, mock_get_servers):
mock_get_servers.return_value = [
{"id": "server1", "name": "Server 1"},
{"id": "server2", "name": "Server 2"},
]
@patch("requests.get")
def test_is_collector_healthy_exception(mock_get, mock_api_url):
"""Test health check when request fails"""
mock_get.side_effect = requests.ConnectionError("Connection refused")
discord.load_servers()
result = discord.is_collector_healthy()
assert mock_discord_server_class.call_count == 2
mock_discord_server_class.assert_any_call("server1", "Server 1")
mock_discord_server_class.assert_any_call("server2", "Server 2")
assert result is False
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_broadcast_message():
mock_server1 = Mock()
mock_server2 = Mock()
discord.servers = {"1": mock_server1, "2": mock_server2}
@patch("requests.post")
def test_refresh_discord_metadata_success(mock_post, mock_api_url):
"""Test successful metadata refresh"""
mock_response = Mock()
mock_response.json.return_value = {
"servers": 5,
"channels": 20,
"users": 100,
}
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
discord.broadcast_message("test-channel", "Hello")
result = discord.refresh_discord_metadata()
mock_server1.send_message.assert_called_once_with(
mock_server1.channel_id.return_value, "Hello"
)
mock_server2.send_message.assert_called_once_with(
mock_server2.channel_id.return_value, "Hello"
assert result == {"servers": 5, "channels": 20, "users": 100}
mock_post.assert_called_once_with(
"http://localhost:8000/refresh_metadata", timeout=30
)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_broadcast_message_disabled():
mock_server = Mock()
discord.servers = {"1": mock_server}
@patch("requests.post")
def test_refresh_discord_metadata_failure(mock_post, mock_api_url):
"""Test metadata refresh failure"""
mock_post.side_effect = requests.RequestException("Failed to connect")
discord.broadcast_message("test-channel", "Hello")
result = discord.refresh_discord_metadata()
mock_server.send_message.assert_not_called()
assert result is None
@patch("requests.post")
def test_refresh_discord_metadata_http_error(mock_post, mock_api_url):
"""Test metadata refresh with HTTP error"""
mock_response = Mock()
mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error")
mock_post.return_value = mock_response
result = discord.refresh_discord_metadata()
assert result is None
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "errors")
def test_send_error_message(mock_broadcast):
discord.send_error_message("Error occurred")
mock_broadcast.assert_called_once_with(discord.ERROR_CHANNEL, "Error occurred")
"""Test sending error message to error channel"""
mock_broadcast.return_value = True
result = discord.send_error_message("Something broke")
assert result is True
mock_broadcast.assert_called_once_with("errors", "Something broke")
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "activity")
def test_send_activity_message(mock_broadcast):
discord.send_activity_message("Activity update")
mock_broadcast.assert_called_once_with(discord.ACTIVITY_CHANNEL, "Activity update")
"""Test sending activity message to activity channel"""
mock_broadcast.return_value = True
result = discord.send_activity_message("User logged in")
assert result is True
mock_broadcast.assert_called_once_with("activity", "User logged in")
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "discoveries")
def test_send_discovery_message(mock_broadcast):
discord.send_discovery_message("Discovery made")
mock_broadcast.assert_called_once_with(discord.DISCOVERY_CHANNEL, "Discovery made")
"""Test sending discovery message to discovery channel"""
mock_broadcast.return_value = True
result = discord.send_discovery_message("Found interesting pattern")
assert result is True
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "chat")
def test_send_chat_message(mock_broadcast):
discord.send_chat_message("Chat message")
mock_broadcast.assert_called_once_with(discord.CHAT_CHANNEL, "Chat message")
"""Test sending chat message to chat channel"""
mock_broadcast.return_value = True
result = discord.send_chat_message("Hello from bot")
assert result is True
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_basic(mock_send_error):
"""Test basic task failure notification"""
discord.notify_task_failure("test_task", "Something went wrong")
mock_send_error.assert_called_once()
@ -323,69 +255,181 @@ def test_notify_task_failure_basic(mock_send_error):
assert "**Error:** Something went wrong" in message
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_with_args(mock_send_error):
"""Test task failure notification with arguments"""
discord.notify_task_failure(
"test_task",
"Error message",
task_args=("arg1", "arg2"),
task_kwargs={"key": "value"},
"Error occurred",
task_args=("arg1", 42),
task_kwargs={"key": "value", "number": 123},
)
message = mock_send_error.call_args[0][0]
assert "**Args:** `('arg1', 'arg2')`" in message
assert "**Kwargs:** `{'key': 'value'}`" in message
assert "**Args:** `('arg1', 42)" in message
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_with_traceback(mock_send_error):
traceback = "Traceback (most recent call last):\n File ...\nError: Something"
"""Test task failure notification with traceback"""
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
discord.notify_task_failure("test_task", "Error message", traceback_str=traceback)
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
message = mock_send_error.call_args[0][0]
assert "**Traceback:**" in message
assert traceback in message
assert "Exception: test" in message
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_error(mock_send_error):
long_error = "x" * 600 # Longer than 500 char limit
"""Test that long error messages are truncated"""
long_error = "x" * 600
discord.notify_task_failure("test_task", long_error)
message = mock_send_error.call_args[0][0]
assert long_error[:500] in message
# Error should be truncated to 500 chars - check that the full 600 char string is not there
assert "**Error:** " + long_error[:500] in message
# The full 600-char error should not be present
error_section = message.split("**Error:** ")[1].split("\n")[0]
assert len(error_section) == 500
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_traceback(mock_send_error):
long_traceback = "x" * 1000 # Longer than 800 char limit
"""Test that long tracebacks are truncated"""
long_traceback = "x" * 1000
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
message = mock_send_error.call_args[0][0]
# Traceback should show last 800 chars
assert long_traceback[-800:] in message
# The full 1000-char traceback should not be present
traceback_section = message.split("**Traceback:**\n```\n")[1].split("\n```")[0]
assert len(traceback_section) == 800
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_args(mock_send_error):
"""Test that long task arguments are truncated"""
long_args = ("x" * 300,)
discord.notify_task_failure("test_task", "Error", task_args=long_args)
message = mock_send_error.call_args[0][0]
# Args should be truncated to 200 chars
assert (
len(message.split("**Args:**")[1].split("\n")[0]) <= 210
) # Some buffer for formatting
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
"""Test that long task kwargs are truncated"""
long_kwargs = {"key": "x" * 300}
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
message = mock_send_error.call_args[0][0]
# Kwargs should be truncated to 200 chars
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_notify_task_failure_disabled(mock_send_error):
discord.notify_task_failure("test_task", "Error message")
"""Test that notifications are not sent when disabled"""
discord.notify_task_failure("test_task", "Error occurred")
mock_send_error.assert_not_called()
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
def test_notify_task_failure_send_fails(mock_send_error):
mock_send_error.side_effect = Exception("Discord API error")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_send_error_exception(mock_send_error):
"""Test that exceptions in send_error_message don't propagate"""
mock_send_error.side_effect = Exception("Failed to send")
# Should not raise, just log the error
discord.notify_task_failure("test_task", "Error message")
# Should not raise
discord.notify_task_failure("test_task", "Error occurred")
mock_send_error.assert_called_once()
@pytest.mark.parametrize(
    "function,channel_setting,message",
    [
        (discord.send_error_message, "DISCORD_ERROR_CHANNEL", "Error!"),
        (discord.send_activity_message, "DISCORD_ACTIVITY_CHANNEL", "Activity!"),
        (discord.send_discovery_message, "DISCORD_DISCOVERY_CHANNEL", "Discovery!"),
        (discord.send_chat_message, "DISCORD_CHAT_CHANNEL", "Chat!"),
    ],
)
@patch("memory.common.discord.broadcast_message")
def test_convenience_functions_use_correct_channels(
    mock_broadcast, function, channel_setting, message
):
    """Each send_* helper must broadcast to the channel named by its own setting"""
    # Override only the setting that belongs to the helper under test
    setting_path = "memory.common.settings." + channel_setting
    with patch(setting_path, "test-channel"):
        function(message)
        mock_broadcast.assert_called_once_with("test-channel", message)
@patch("requests.post")
def test_send_dm_with_special_characters(mock_post, mock_api_url):
    """Emoji, mentions and channel tags must reach the request payload unchanged"""
    text = "Hello! 🎉 <@123> #general"
    mock_post.return_value = Mock(
        **{
            "json.return_value": {"success": True},
            "raise_for_status.return_value": None,
        }
    )

    sent = discord.send_dm("user123", text)

    assert sent is True
    # call_args unpacks into (positional args, keyword args)
    _, post_kwargs = mock_post.call_args
    assert post_kwargs["json"]["message"] == text
@patch("requests.post")
def test_broadcast_message_with_long_message(mock_post, mock_api_url):
    """A 2000-character message must be forwarded in full"""
    text = "A" * 2000
    mock_post.return_value = Mock(
        **{
            "json.return_value": {"success": True},
            "raise_for_status.return_value": None,
        }
    )

    sent = discord.broadcast_message("general", text)

    assert sent is True
    # call_args unpacks into (positional args, keyword args)
    _, post_kwargs = mock_post.call_args
    assert post_kwargs["json"]["message"] == text
@patch("requests.get")
def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
    """An empty JSON body (no status key) must be reported as not healthy"""
    reply = Mock()
    reply.raise_for_status.return_value = None
    reply.json.return_value = {}
    mock_get.return_value = reply

    assert discord.is_collector_healthy() is False

View File

@ -0,0 +1,435 @@
import pytest
from unittest.mock import Mock, patch
import requests
from memory.common import discord
@pytest.fixture
def mock_api_url():
    """Pin discord.get_api_url to http://localhost:8000 while a test runs"""
    url_patch = patch(
        "memory.common.discord.get_api_url", return_value="http://localhost:8000"
    )
    # The patch is active until the test using this fixture finishes
    with url_patch:
        yield
@patch("memory.common.settings.DISCORD_COLLECTOR_SERVER_URL", "testhost")
@patch("memory.common.settings.DISCORD_COLLECTOR_PORT", 9999)
def test_get_api_url():
    """With host and port settings patched, the URL must be http://testhost:9999"""
    built_url = discord.get_api_url()
    assert built_url == "http://testhost:9999"
@patch("requests.post")
def test_send_dm_success(mock_post, mock_api_url):
    """A reply with success=True must make send_dm return True"""
    mock_post.return_value = Mock(
        **{
            "json.return_value": {"success": True},
            "raise_for_status.return_value": None,
        }
    )

    delivered = discord.send_dm("user123", "Hello!")

    assert delivered is True
    # Exactly one POST, against the send_dm endpoint, with the expected payload
    expected_payload = {"user_identifier": "user123", "message": "Hello!"}
    mock_post.assert_called_once_with(
        "http://localhost:8000/send_dm",
        timeout=10,
        json=expected_payload,
    )
@patch("requests.post")
def test_send_dm_api_failure(mock_post, mock_api_url):
    """A reply with success=False must make send_dm return False"""
    reply = Mock()
    reply.raise_for_status.return_value = None
    reply.json.return_value = {"success": False}
    mock_post.return_value = reply

    assert discord.send_dm("user123", "Hello!") is False
@patch("requests.post")
def test_send_dm_request_exception(mock_post, mock_api_url):
    """A RequestException raised by the POST must turn into a False result"""
    network_error = requests.RequestException("Network error")
    mock_post.side_effect = network_error

    assert discord.send_dm("user123", "Hello!") is False
@patch("requests.post")
def test_send_dm_http_error(mock_post, mock_api_url):
    """An HTTPError from raise_for_status must turn into a False result"""
    http_error = requests.HTTPError("404 Not Found")
    mock_post.return_value = Mock(**{"raise_for_status.side_effect": http_error})

    outcome = discord.send_dm("user123", "Hello!")

    assert outcome is False
@patch("requests.post")
def test_broadcast_message_success(mock_post, mock_api_url):
    """A reply with success=True must make broadcast_message return True"""
    mock_post.return_value = Mock(
        **{
            "json.return_value": {"success": True},
            "raise_for_status.return_value": None,
        }
    )

    delivered = discord.broadcast_message("general", "Announcement!")

    assert delivered is True
    # Exactly one POST, against the send_channel endpoint, with the expected payload
    expected_payload = {"channel_name": "general", "message": "Announcement!"}
    mock_post.assert_called_once_with(
        "http://localhost:8000/send_channel",
        timeout=10,
        json=expected_payload,
    )
@patch("requests.post")
def test_broadcast_message_failure(mock_post, mock_api_url):
    """A reply with success=False must make broadcast_message return False"""
    reply = Mock()
    reply.raise_for_status.return_value = None
    reply.json.return_value = {"success": False}
    mock_post.return_value = reply

    assert discord.broadcast_message("general", "Announcement!") is False
@patch("requests.post")
def test_broadcast_message_exception(mock_post, mock_api_url):
    """A Timeout raised by the POST must turn into a False result"""
    timeout_error = requests.Timeout("Request timeout")
    mock_post.side_effect = timeout_error

    assert discord.broadcast_message("general", "Announcement!") is False
@patch("requests.get")
def test_is_collector_healthy_true(mock_get, mock_api_url):
    """A status of "healthy" must be reported as True"""
    mock_get.return_value = Mock(
        **{
            "json.return_value": {"status": "healthy"},
            "raise_for_status.return_value": None,
        }
    )

    healthy = discord.is_collector_healthy()

    assert healthy is True
    # The probe must be a single GET against the health endpoint
    mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
@patch("requests.get")
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
    """A status of "unhealthy" must be reported as False"""
    reply = Mock()
    reply.raise_for_status.return_value = None
    reply.json.return_value = {"status": "unhealthy"}
    mock_get.return_value = reply

    assert discord.is_collector_healthy() is False
@patch("requests.get")
def test_is_collector_healthy_exception(mock_get, mock_api_url):
    """A ConnectionError raised by the GET must be reported as False"""
    connection_error = requests.ConnectionError("Connection refused")
    mock_get.side_effect = connection_error

    assert discord.is_collector_healthy() is False
@patch("requests.post")
def test_refresh_discord_metadata_success(mock_post, mock_api_url):
    """The JSON body of a successful refresh must be returned to the caller"""
    reply = Mock()
    reply.raise_for_status.return_value = None
    reply.json.return_value = {"servers": 5, "channels": 20, "users": 100}
    mock_post.return_value = reply

    refreshed = discord.refresh_discord_metadata()

    assert refreshed == {"servers": 5, "channels": 20, "users": 100}
    # Exactly one POST against the refresh endpoint
    mock_post.assert_called_once_with(
        "http://localhost:8000/refresh_metadata", timeout=30
    )
@patch("requests.post")
def test_refresh_discord_metadata_failure(mock_post, mock_api_url):
    """A RequestException raised by the POST must yield None"""
    connect_error = requests.RequestException("Failed to connect")
    mock_post.side_effect = connect_error

    assert discord.refresh_discord_metadata() is None
@patch("requests.post")
def test_refresh_discord_metadata_http_error(mock_post, mock_api_url):
    """An HTTPError from raise_for_status must yield None"""
    http_error = requests.HTTPError("500 Server Error")
    mock_post.return_value = Mock(**{"raise_for_status.side_effect": http_error})

    outcome = discord.refresh_discord_metadata()

    assert outcome is None
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "errors")
def test_send_error_message(mock_broadcast):
    """Error messages must go to the configured error channel"""
    mock_broadcast.return_value = True

    outcome = discord.send_error_message("Something broke")

    # The helper hands back whatever broadcast_message returned
    assert outcome is True
    mock_broadcast.assert_called_once_with("errors", "Something broke")
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "activity")
def test_send_activity_message(mock_broadcast):
    """Activity messages must go to the configured activity channel"""
    mock_broadcast.return_value = True

    outcome = discord.send_activity_message("User logged in")

    # The helper hands back whatever broadcast_message returned
    assert outcome is True
    mock_broadcast.assert_called_once_with("activity", "User logged in")
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "discoveries")
def test_send_discovery_message(mock_broadcast):
    """Discovery messages must go to the configured discovery channel"""
    mock_broadcast.return_value = True

    outcome = discord.send_discovery_message("Found interesting pattern")

    # The helper hands back whatever broadcast_message returned
    assert outcome is True
    mock_broadcast.assert_called_once_with(
        "discoveries", "Found interesting pattern"
    )
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "chat")
def test_send_chat_message(mock_broadcast):
    """Chat messages must go to the configured chat channel"""
    mock_broadcast.return_value = True

    outcome = discord.send_chat_message("Hello from bot")

    # The helper hands back whatever broadcast_message returned
    assert outcome is True
    mock_broadcast.assert_called_once_with("chat", "Hello from bot")
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_basic(mock_send_error):
    """A failure notification must carry the task name and the error text"""
    discord.notify_task_failure("test_task", "Something went wrong")

    mock_send_error.assert_called_once()
    # call_args unpacks into (positional args, keyword args)
    positional, _ = mock_send_error.call_args
    sent_text = positional[0]
    for expected in (
        "🚨 **Task Failed: test_task**",
        "**Error:** Something went wrong",
    ):
        assert expected in sent_text
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_with_args(mock_send_error):
    """Task args and kwargs must be rendered into the notification"""
    failing_args = ("arg1", 42)
    failing_kwargs = {"key": "value", "number": 123}

    discord.notify_task_failure(
        "test_task",
        "Error occurred",
        task_args=failing_args,
        task_kwargs=failing_kwargs,
    )

    # call_args unpacks into (positional args, keyword args)
    positional, _ = mock_send_error.call_args
    sent_text = positional[0]
    assert "**Args:** `('arg1', 42)" in sent_text
    assert "**Kwargs:** `{'key': 'value', 'number': 123}" in sent_text
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_with_traceback(mock_send_error):
    """A supplied traceback must show up under a Traceback heading"""
    traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"

    discord.notify_task_failure(
        "test_task", "Error occurred", traceback_str=traceback
    )

    # call_args unpacks into (positional args, keyword args)
    positional, _ = mock_send_error.call_args
    sent_text = positional[0]
    assert "**Traceback:**" in sent_text
    assert "Exception: test" in sent_text
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_error(mock_send_error):
    """A 600-character error must be cut down to 500 characters"""
    oversized = "x" * 600

    discord.notify_task_failure("test_task", oversized)

    # call_args unpacks into (positional args, keyword args)
    positional, _ = mock_send_error.call_args
    sent_text = positional[0]
    # The first 500 characters must follow the error label directly
    assert "**Error:** " + oversized[:500] in sent_text
    # Only 500 characters may remain on the error line, so the tail was dropped
    after_label = sent_text.split("**Error:** ")[1]
    error_line = after_label.partition("\n")[0]
    assert len(error_line) == 500
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_traceback(mock_send_error):
    """A 1000-character traceback must be cut down to its last 800 characters"""
    oversized = "x" * 1000

    discord.notify_task_failure("test_task", "Error", traceback_str=oversized)

    # call_args unpacks into (positional args, keyword args)
    positional, _ = mock_send_error.call_args
    sent_text = positional[0]
    # The trailing 800 characters must be present
    assert oversized[-800:] in sent_text
    # The fenced traceback body must hold exactly 800 characters
    after_fence = sent_text.split("**Traceback:**\n```\n")[1]
    fenced_body = after_fence.partition("\n```")[0]
    assert len(fenced_body) == 800
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_args(mock_send_error):
    """Oversized task args must be shortened in the notification"""
    oversized_args = ("x" * 300,)

    discord.notify_task_failure("test_task", "Error", task_args=oversized_args)

    # call_args unpacks into (positional args, keyword args)
    positional, _ = mock_send_error.call_args
    sent_text = positional[0]
    # Args should be truncated to 200 chars; 210 leaves some buffer for formatting
    after_label = sent_text.split("**Args:**")[1]
    args_line = after_label.partition("\n")[0]
    assert len(args_line) <= 210
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
    """Oversized task kwargs must be shortened in the notification"""
    oversized_kwargs = {"key": "x" * 300}

    discord.notify_task_failure("test_task", "Error", task_kwargs=oversized_kwargs)

    # call_args unpacks into (positional args, keyword args)
    positional, _ = mock_send_error.call_args
    sent_text = positional[0]
    # Kwargs should be truncated to 200 chars; 210 leaves some buffer for formatting
    after_label = sent_text.split("**Kwargs:**")[1]
    kwargs_line = after_label.partition("\n")[0]
    assert len(kwargs_line) <= 210
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_notify_task_failure_disabled(mock_send_error):
    """With notifications switched off, nothing may be sent"""
    task_name, error_text = "test_task", "Error occurred"
    discord.notify_task_failure(task_name, error_text)

    mock_send_error.assert_not_called()
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_send_error_exception(mock_send_error):
    """A failure inside send_error_message must not escape notify_task_failure"""
    send_failure = Exception("Failed to send")
    mock_send_error.side_effect = send_failure

    # The call below must complete without raising
    discord.notify_task_failure("test_task", "Error occurred")

    mock_send_error.assert_called_once()
@pytest.mark.parametrize(
    "function,channel_setting,message",
    [
        (discord.send_error_message, "DISCORD_ERROR_CHANNEL", "Error!"),
        (discord.send_activity_message, "DISCORD_ACTIVITY_CHANNEL", "Activity!"),
        (discord.send_discovery_message, "DISCORD_DISCOVERY_CHANNEL", "Discovery!"),
        (discord.send_chat_message, "DISCORD_CHAT_CHANNEL", "Chat!"),
    ],
)
@patch("memory.common.discord.broadcast_message")
def test_convenience_functions_use_correct_channels(
    mock_broadcast, function, channel_setting, message
):
    """Each send_* helper must broadcast to the channel named by its own setting"""
    # Override only the setting that belongs to the helper under test
    setting_path = "memory.common.settings." + channel_setting
    with patch(setting_path, "test-channel"):
        function(message)
        mock_broadcast.assert_called_once_with("test-channel", message)
@patch("requests.post")
def test_send_dm_with_special_characters(mock_post, mock_api_url):
    """Emoji, mentions and channel tags must reach the request payload unchanged"""
    text = "Hello! 🎉 <@123> #general"
    mock_post.return_value = Mock(
        **{
            "json.return_value": {"success": True},
            "raise_for_status.return_value": None,
        }
    )

    sent = discord.send_dm("user123", text)

    assert sent is True
    # call_args unpacks into (positional args, keyword args)
    _, post_kwargs = mock_post.call_args
    assert post_kwargs["json"]["message"] == text
@patch("requests.post")
def test_broadcast_message_with_long_message(mock_post, mock_api_url):
    """A 2000-character message must be forwarded in full"""
    text = "A" * 2000
    mock_post.return_value = Mock(
        **{
            "json.return_value": {"success": True},
            "raise_for_status.return_value": None,
        }
    )

    sent = discord.broadcast_message("general", text)

    assert sent is True
    # call_args unpacks into (positional args, keyword args)
    _, post_kwargs = mock_post.call_args
    assert post_kwargs["json"]["message"] == text
@patch("requests.get")
def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
    """An empty JSON body (no status key) must be reported as not healthy"""
    reply = Mock()
    reply.raise_for_status.return_value = None
    reply.json.return_value = {}
    mock_get.return_value = reply

    assert discord.is_collector_healthy() is False

View File

@ -0,0 +1,840 @@
import pytest
from datetime import datetime, timezone
from unittest.mock import Mock, patch, AsyncMock, MagicMock
import discord
from memory.discord.collector import (
create_or_update_server,
determine_channel_metadata,
create_or_update_channel,
create_or_update_user,
determine_message_metadata,
should_track_message,
should_collect_bot_message,
sync_guild_metadata,
MessageCollector,
)
from memory.common.db.models.sources import (
DiscordServer,
DiscordChannel,
DiscordUser,
)
# Fixtures for Discord objects
@pytest.fixture
def mock_guild():
    """Guild stand-in (spec'd on discord.Guild) with fixed identifying fields"""
    fake_guild = Mock(spec=discord.Guild)
    # configure_mock sets plain attributes, including `name`
    fake_guild.configure_mock(
        id=123456789,
        name="Test Server",
        description="A test server",
        member_count=42,
    )
    return fake_guild
@pytest.fixture
def mock_text_channel():
    """TextChannel stand-in named "general" that belongs to guild 123456789"""
    owning_guild = Mock()
    owning_guild.id = 123456789

    text_channel = Mock(spec=discord.TextChannel)
    # configure_mock sets plain attributes, including `name`
    text_channel.configure_mock(id=987654321, name="general", guild=owning_guild)
    return text_channel
@pytest.fixture
def mock_dm_channel():
    """DMChannel stand-in whose recipient is named "TestUser" """
    other_party = Mock()
    other_party.name = "TestUser"

    dm_channel = Mock(spec=discord.DMChannel)
    dm_channel.id = 111222333
    dm_channel.recipient = other_party
    return dm_channel
@pytest.fixture
def mock_user():
    """Non-bot User stand-in (spec'd on discord.User)"""
    fake_user = Mock(spec=discord.User)
    # configure_mock sets plain attributes, including `name`
    fake_user.configure_mock(
        id=444555666,
        name="testuser",
        display_name="Test User",
        bot=False,
    )
    return fake_user
@pytest.fixture
def mock_message(mock_text_channel, mock_user):
    """Message stand-in posted by mock_user in mock_text_channel, not a reply"""
    sent_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)

    fake_message = Mock(spec=discord.Message)
    fake_message.id = 777888999
    fake_message.content = "Test message"
    fake_message.created_at = sent_at
    fake_message.reference = None
    # Wire the message to the channel, its guild and the author
    fake_message.channel = mock_text_channel
    fake_message.guild = mock_text_channel.guild
    fake_message.author = mock_user
    return fake_message
# Tests for create_or_update_server
def test_create_or_update_server_creates_new(db_session, mock_guild):
    """A guild not seen before must yield a record mirroring the guild's fields"""
    server = create_or_update_server(db_session, mock_guild)

    assert server is not None
    # Record and guild use the same attribute names for these fields
    for field in ("id", "name", "description", "member_count"):
        assert getattr(server, field) == getattr(mock_guild, field)
def test_create_or_update_server_updates_existing(db_session, mock_guild):
    """An already stored server must pick up the guild's current values"""
    # Seed a stale row that shares the guild's id
    db_session.add(
        DiscordServer(
            id=mock_guild.id,
            name="Old Name",
            description="Old Description",
            member_count=10,
        )
    )
    db_session.commit()

    # Change the guild so the stored values are out of date
    mock_guild.name = "New Name"
    mock_guild.description = "New Description"
    mock_guild.member_count = 50

    refreshed = create_or_update_server(db_session, mock_guild)

    assert refreshed.name == "New Name"
    assert refreshed.description == "New Description"
    assert refreshed.member_count == 50
    assert refreshed.last_sync_at is not None
def test_create_or_update_server_none_guild(db_session):
    """Passing None instead of a guild must return None"""
    assert create_or_update_server(db_session, None) is None
# Tests for determine_channel_metadata
def test_determine_channel_metadata_dm():
    """A DM channel must be typed "dm", have no server, and name the recipient"""
    other_party = Mock()
    other_party.name = "TestUser"
    dm_channel = Mock(spec=discord.DMChannel)
    dm_channel.recipient = other_party

    kind, guild_id, label = determine_channel_metadata(dm_channel)

    assert kind == "dm"
    assert guild_id is None
    assert "DM with TestUser" in label
def test_determine_channel_metadata_dm_no_recipient():
    """A DM channel without a recipient must be named "Unknown DM" """
    dm_channel = Mock(spec=discord.DMChannel)
    dm_channel.recipient = None

    kind, guild_id, label = determine_channel_metadata(dm_channel)

    assert kind == "dm"
    assert label == "Unknown DM"
def test_determine_channel_metadata_group_dm():
    """A named group channel must be typed "group_dm" and keep its own name"""
    group_channel = Mock(spec=discord.GroupChannel)
    group_channel.name = "Group Chat"

    kind, guild_id, label = determine_channel_metadata(group_channel)

    assert kind == "group_dm"
    assert guild_id is None
    assert label == "Group Chat"
def test_determine_channel_metadata_group_dm_no_name():
    """A group channel whose name is None must fall back to "Group DM" """
    group_channel = Mock(spec=discord.GroupChannel)
    group_channel.name = None

    kind, guild_id, label = determine_channel_metadata(group_channel)

    assert label == "Group DM"
def test_determine_channel_metadata_text_channel():
    """A text channel must report type "text", its guild's id and its name"""
    owning_guild = Mock()
    owning_guild.id = 123
    text_channel = Mock(spec=discord.TextChannel)
    text_channel.name = "general"
    text_channel.guild = owning_guild

    kind, guild_id, label = determine_channel_metadata(text_channel)

    assert kind == "text"
    assert guild_id == 123
    assert label == "general"
def test_determine_channel_metadata_voice_channel():
    """A voice channel must report type "voice", its guild's id and its name"""
    owning_guild = Mock()
    owning_guild.id = 456
    voice_channel = Mock(spec=discord.VoiceChannel)
    voice_channel.name = "voice-chat"
    voice_channel.guild = owning_guild

    kind, guild_id, label = determine_channel_metadata(voice_channel)

    assert kind == "voice"
    assert guild_id == 456
    assert label == "voice-chat"
def test_determine_channel_metadata_thread():
    """A thread must report type "thread", its guild's id and its name"""
    owning_guild = Mock()
    owning_guild.id = 789
    thread = Mock(spec=discord.Thread)
    thread.name = "thread-1"
    thread.guild = owning_guild

    kind, guild_id, label = determine_channel_metadata(thread)

    assert kind == "thread"
    assert guild_id == 789
    assert label == "thread-1"
def test_determine_channel_metadata_unknown():
    """An unrecognised channel without a name must be labelled "Unknown-<id>" """
    odd_channel = Mock()
    odd_channel.id = 999
    # Deleting the attribute makes later access to `name` raise AttributeError
    del odd_channel.name

    kind, guild_id, label = determine_channel_metadata(odd_channel)

    assert kind == "unknown"
    assert label == "Unknown-999"
# Tests for create_or_update_channel
def test_create_or_update_channel_creates_new(
    db_session, mock_text_channel, mock_guild
):
    """A channel not seen before must be stored as a "text" channel"""
    # The server row must exist first so the channel's foreign key is satisfied
    create_or_update_server(db_session, mock_guild)

    stored = create_or_update_channel(db_session, mock_text_channel)

    assert stored is not None
    assert stored.id == mock_text_channel.id
    assert stored.name == mock_text_channel.name
    assert stored.channel_type == "text"
def test_create_or_update_channel_updates_existing(db_session, mock_text_channel):
    """An already stored channel must pick up the channel's current name"""
    # Seed a stale row that shares the channel's id
    db_session.add(
        DiscordChannel(
            id=mock_text_channel.id,
            name="old-name",
            channel_type="text",
        )
    )
    db_session.commit()

    # Rename the channel so the stored name is out of date
    mock_text_channel.name = "new-name"

    refreshed = create_or_update_channel(db_session, mock_text_channel)

    assert refreshed.name == "new-name"
def test_create_or_update_channel_none_channel(db_session):
    """Passing None instead of a channel must return None"""
    assert create_or_update_channel(db_session, None) is None
# Tests for create_or_update_user
def test_create_or_update_user_creates_new(db_session, mock_user):
    """A user not seen before must yield a record mirroring the user's fields"""
    stored = create_or_update_user(db_session, mock_user)

    assert stored is not None
    # (record attribute, Discord user attribute) pairs that must agree
    for record_field, user_field in (
        ("id", "id"),
        ("username", "name"),
        ("display_name", "display_name"),
    ):
        assert getattr(stored, record_field) == getattr(mock_user, user_field)
def test_create_or_update_user_updates_existing(db_session, mock_user):
    """An already stored user must pick up the current username and display name"""
    # Seed a stale row that shares the user's id
    db_session.add(
        DiscordUser(
            id=mock_user.id,
            username="oldname",
            display_name="Old Display Name",
        )
    )
    db_session.commit()

    # Change the user so the stored values are out of date
    mock_user.name = "newname"
    mock_user.display_name = "New Display Name"

    refreshed = create_or_update_user(db_session, mock_user)

    assert refreshed.username == "newname"
    assert refreshed.display_name == "New Display Name"
def test_create_or_update_user_none_user(db_session):
    """Passing None instead of a user must return None"""
    assert create_or_update_user(db_session, None) is None
# Tests for determine_message_metadata
def test_determine_message_metadata_default():
    """A message with no reference, in a channel without a parent, is "default" """
    plain_message = Mock()
    plain_message.reference = None
    plain_message.channel = Mock()
    # A bare Mock would auto-create `parent`; delete it so the channel has none
    del plain_message.channel.parent

    kind, replied_to, in_thread = determine_message_metadata(plain_message)

    assert kind == "default"
    assert replied_to is None
    assert in_thread is None
def test_determine_message_metadata_reply():
    """A message carrying a reference must be typed "reply" with the target id"""
    pointer = Mock()
    pointer.message_id = 123456
    reply_message = Mock()
    reply_message.reference = pointer
    reply_message.channel = Mock()

    kind, replied_to, in_thread = determine_message_metadata(reply_message)

    assert kind == "reply"
    assert replied_to == 123456
def test_determine_message_metadata_thread():
    """A message whose channel has a parent must report that channel's id as thread id"""
    thread_message = Mock()
    thread_message.reference = None
    thread_message.channel = Mock()
    thread_message.channel.id = 999
    # Having a parent is what marks the channel as a thread in this test
    thread_message.channel.parent = Mock()

    kind, replied_to, in_thread = determine_message_metadata(thread_message)

    assert in_thread == 999
# Tests for should_track_message
def test_should_track_message_server_disabled(db_session):
    """A server with ``track_messages=False`` turns tracking off."""
    muted_server = DiscordServer(id=1, name="Server", track_messages=False)
    text_channel = DiscordChannel(id=2, name="Channel", channel_type="text")
    author = DiscordUser(id=3, username="User")

    assert should_track_message(muted_server, text_channel, author) is False
def test_should_track_message_channel_disabled(db_session):
    """A channel with ``track_messages=False`` wins over a tracking server."""
    tracking_server = DiscordServer(id=1, name="Server", track_messages=True)
    channel_fields = {
        "id": 2,
        "name": "Channel",
        "channel_type": "text",
        "track_messages": False,
    }
    muted_channel = DiscordChannel(**channel_fields)
    author = DiscordUser(id=3, username="User")

    assert should_track_message(tracking_server, muted_channel, author) is False
def test_should_track_message_dm_allowed(db_session):
    """With no server, a DM is tracked when channel and user both allow it."""
    dm_channel = DiscordChannel(
        id=2, name="DM", channel_type="dm", track_messages=True
    )
    consenting_user = DiscordUser(id=3, username="User", track_messages=True)

    assert should_track_message(None, dm_channel, consenting_user) is True
def test_should_track_message_dm_not_allowed(db_session):
    """With no server, a DM is not tracked when the user has tracking off."""
    dm_channel = DiscordChannel(
        id=2, name="DM", channel_type="dm", track_messages=True
    )
    opted_out_user = DiscordUser(id=3, username="User", track_messages=False)

    assert should_track_message(None, dm_channel, opted_out_user) is False
def test_should_track_message_default_true(db_session):
    """Tracking is on when server and channel enable it and the user sets nothing."""
    tracking_server = DiscordServer(id=1, name="Server", track_messages=True)
    channel_fields = {
        "id": 2,
        "name": "Channel",
        "channel_type": "text",
        "track_messages": True,
    }
    tracking_channel = DiscordChannel(**channel_fields)
    author = DiscordUser(id=3, username="User")

    assert should_track_message(tracking_server, tracking_channel, author) is True
# Tests for should_collect_bot_message
@patch("memory.common.settings.DISCORD_COLLECT_BOTS", False)
def test_should_collect_bot_message_bot_not_allowed():
    """A bot-authored message is rejected while DISCORD_COLLECT_BOTS is False."""
    bot_author = Mock()
    bot_author.bot = True
    message = Mock(author=bot_author)

    assert should_collect_bot_message(message) is False
@patch("memory.common.settings.DISCORD_COLLECT_BOTS", True)
def test_should_collect_bot_message_bot_allowed():
    """A bot-authored message is accepted while DISCORD_COLLECT_BOTS is True."""
    bot_author = Mock()
    bot_author.bot = True
    message = Mock(author=bot_author)

    assert should_collect_bot_message(message) is True
def test_should_collect_bot_message_human():
    """A message whose author is not a bot is collected."""
    human_author = Mock()
    human_author.bot = False
    message = Mock(author=human_author)

    assert should_collect_bot_message(message) is True
# Tests for sync_guild_metadata
@patch("memory.discord.collector.make_session")
def test_sync_guild_metadata(mock_make_session, mock_guild):
    """Syncing a guild with a text and a voice channel commits exactly once."""
    session = Mock()
    # Every ``session.query(...).get(...)`` lookup misses (new server).
    session.query.return_value.get.return_value = None

    # Make ``with make_session() as s`` hand back the mocked session.
    session_ctx = mock_make_session.return_value
    session_ctx.__enter__ = Mock(return_value=session)
    session_ctx.__exit__ = Mock(return_value=None)

    # One text and one voice channel, both attached to the guild.
    channels = []
    for channel_cls, channel_id, channel_name in (
        (discord.TextChannel, 1, "general"),
        (discord.VoiceChannel, 2, "voice"),
    ):
        channel = Mock(spec=channel_cls)
        channel.id = channel_id
        channel.name = channel_name
        channel.guild = mock_guild
        channels.append(channel)
    mock_guild.channels = channels

    sync_guild_metadata(mock_guild)

    session.commit.assert_called_once()
# Tests for MessageCollector class
def test_message_collector_init():
    """A fresh collector has the expected prefix, no help command, and four intents on."""
    collector = MessageCollector()

    assert collector.command_prefix == "!memory_"
    assert collector.help_command is None

    # Each intent flag below must be exactly ``True``.
    for flag in ("message_content", "guilds", "members", "dm_messages"):
        assert getattr(collector.intents, flag) is True
@pytest.mark.asyncio
async def test_on_ready():
    """``on_ready`` triggers one server/channel sync."""
    collector = MessageCollector()
    # NOTE(review): in discord.py ``Client.user`` and ``Client.guilds`` look
    # like read-only properties; confirm MessageCollector allows assigning
    # them, otherwise these two assignments raise AttributeError.
    collector.user = Mock()
    collector.user.name = "TestBot"
    collector.guilds = [Mock(), Mock()]

    sync_stub = AsyncMock()
    collector.sync_servers_and_channels = sync_stub

    await collector.on_ready()

    collector.sync_servers_and_channels.assert_called_once()
@pytest.mark.asyncio
@patch("memory.discord.collector.make_session")
@patch("memory.discord.collector.add_discord_message")
async def test_on_message_success(mock_add_task, mock_make_session, mock_message):
    """Handling a message queues one task carrying the message's fields."""
    session = Mock()
    session.query.return_value.get.return_value = None  # new entities

    # Make ``with make_session() as s`` hand back the mocked session.
    session_ctx = mock_make_session.return_value
    session_ctx.__enter__ = Mock(return_value=session)
    session_ctx.__exit__ = Mock(return_value=None)

    collector = MessageCollector()
    await collector.on_message(mock_message)

    # The task must be queued once, with keyword arguments from the message.
    mock_add_task.delay.assert_called_once()
    queued_kwargs = mock_add_task.delay.call_args[1]
    expected = {
        "message_id": mock_message.id,
        "channel_id": mock_message.channel.id,
        "author_id": mock_message.author.id,
        "content": mock_message.content,
    }
    for key, value in expected.items():
        assert queued_kwargs[key] == value
@pytest.mark.asyncio
@patch("memory.discord.collector.make_session")
async def test_on_message_bot_message_filtered(mock_make_session, mock_message):
    """A filtered bot message opens no session and queues no task.

    ``should_collect_bot_message`` is patched to return ``False``. The test
    then checks both halves of the expectation: ``make_session`` is never
    called and ``add_discord_message.delay`` is never called.
    """
    mock_message.author.bot = True

    with patch(
        "memory.discord.collector.should_collect_bot_message", return_value=False
    ):
        # Patch the task too, so "no task queued" is actually asserted.
        with patch("memory.discord.collector.add_discord_message") as mock_add_task:
            collector = MessageCollector()
            await collector.on_message(mock_message)

    # Should not create session or queue task
    mock_make_session.assert_not_called()
    mock_add_task.delay.assert_not_called()
@pytest.mark.asyncio
@patch("memory.discord.collector.make_session")
async def test_on_message_error_handling(mock_make_session, mock_message):
    """``on_message`` does not propagate a failure from ``make_session``."""
    mock_make_session.side_effect = Exception("Database error")

    # The test passes as long as this await completes without raising.
    await MessageCollector().on_message(mock_message)
@pytest.mark.asyncio
@patch("memory.discord.collector.edit_discord_message")
async def test_on_message_edit(mock_edit_task):
    """An edit queues one task with the edited message's id and content."""
    before = Mock()
    after = Mock(
        id=123,
        content="Edited content",
        edited_at=datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc),
    )

    collector = MessageCollector()
    await collector.on_message_edit(before, after)

    mock_edit_task.delay.assert_called_once()
    queued_kwargs = mock_edit_task.delay.call_args[1]
    assert queued_kwargs["message_id"] == 123
    assert queued_kwargs["content"] == "Edited content"
@pytest.mark.asyncio
async def test_on_message_edit_error_handling():
    """``on_message_edit`` does not propagate a failure from the queued task."""
    before = Mock()
    # ``edited_at`` is None so the handler has no edit timestamp to use.
    after = Mock(id=123, content="Edited", edited_at=None)

    with patch("memory.discord.collector.edit_discord_message") as mock_edit:
        mock_edit.delay.side_effect = Exception("Task error")
        collector = MessageCollector()

        # The test passes as long as this await completes without raising.
        await collector.on_message_edit(before, after)
@pytest.mark.asyncio
async def test_sync_servers_and_channels():
    """Each guild on the collector is passed to ``sync_guild_metadata`` once."""
    first_guild, second_guild = Mock(), Mock()
    collector = MessageCollector()
    collector.guilds = [first_guild, second_guild]

    with patch("memory.discord.collector.sync_guild_metadata") as mock_sync:
        await collector.sync_servers_and_channels()

    assert mock_sync.call_count == 2
    for guild in (first_guild, second_guild):
        mock_sync.assert_any_call(guild)
@pytest.mark.asyncio
@patch("memory.discord.collector.make_session")
async def test_refresh_metadata(mock_make_session):
    """Refreshing one empty guild reports one server and no channels or users."""
    session = Mock()
    session.query.return_value.get.return_value = None

    # Make ``with make_session() as s`` hand back the mocked session.
    session_ctx = mock_make_session.return_value
    session_ctx.__enter__ = Mock(return_value=session)
    session_ctx.__exit__ = Mock(return_value=None)

    # A guild with no channels and no members.
    guild = Mock()
    guild.id = 123
    guild.name = "Test"
    guild.channels = []
    guild.members = []

    collector = MessageCollector()
    collector.guilds = [guild]
    collector.intents = Mock()
    collector.intents.members = False

    result = await collector.refresh_metadata()

    expected_counts = (
        ("servers_updated", 1),
        ("channels_updated", 0),
        ("users_updated", 0),
    )
    for key, count in expected_counts:
        assert result[key] == count
@pytest.mark.asyncio
async def test_get_user_by_id():
    """Awaiting ``get_user`` with an integer id returns the stubbed user.

    ``get_user`` is awaited, so the stub must be an ``AsyncMock``. A plain
    ``Mock`` returns a non-awaitable object and ``await`` raises ``TypeError``.
    """
    user = Mock()
    user.id = 123

    collector = MessageCollector()
    # NOTE(review): this replaces the method under test, so only the stub is
    # exercised; consider stubbing the underlying lookup instead.
    collector.get_user = AsyncMock(return_value=user)

    result = await collector.get_user(123)

    assert result == user
    collector.get_user.assert_awaited_once_with(123)
@pytest.mark.asyncio
async def test_get_user_by_username():
    """Looking a user up by name returns the matching guild member."""
    member = Mock(display_name="Test User", discriminator="1234")
    # ``name`` is reserved by the Mock constructor, so set it afterwards.
    member.name = "testuser"

    guild = Mock(members=[member])
    collector = MessageCollector()
    collector.guilds = [guild]

    found = await collector.get_user("testuser")

    assert found == member
@pytest.mark.asyncio
async def test_get_user_not_found():
    """A lookup for an unknown id yields ``None``."""
    collector = MessageCollector()
    collector.guilds = []

    # NOTE(review): ``get_user`` itself is patched here, so the value
    # returned below comes from the patch; confirm that is intended.
    with patch.object(collector, "get_user", return_value=None), patch.object(
        collector, "fetch_user", side_effect=discord.NotFound(Mock(), Mock())
    ):
        result = await collector.get_user(999)

    assert result is None
@pytest.mark.asyncio
async def test_get_channel_by_name():
    """A text channel is found by its name among the collector's guilds."""
    general = Mock(spec=discord.TextChannel)
    general.name = "general"

    collector = MessageCollector()
    collector.guilds = [Mock(channels=[general])]

    found = await collector.get_channel_by_name("general")

    assert found == general
@pytest.mark.asyncio
async def test_get_channel_by_name_not_found():
    """Searching a guild without channels returns ``None``."""
    empty_guild = Mock(channels=[])
    collector = MessageCollector()
    collector.guilds = [empty_guild]

    assert await collector.get_channel_by_name("nonexistent") is None
@pytest.mark.asyncio
async def test_create_channel():
    """Creating a channel delegates to the guild's ``create_text_channel``."""
    created = Mock()
    guild = Mock()
    guild.name = "Test Server"
    guild.create_text_channel = AsyncMock(return_value=created)

    collector = MessageCollector()
    collector.get_guild = Mock(return_value=guild)

    outcome = await collector.create_channel("new-channel", guild_id=123)

    assert outcome == created
    guild.create_text_channel.assert_called_once_with("new-channel")
@pytest.mark.asyncio
async def test_create_channel_no_guild():
    """With no guild available, ``create_channel`` returns ``None``."""
    collector = MessageCollector()
    collector.get_guild = Mock(return_value=None)
    collector.guilds = []

    assert await collector.create_channel("new-channel") is None
@pytest.mark.asyncio
async def test_send_dm_success():
    """``send_dm`` returns True and forwards the text to the user's ``send``."""
    recipient = Mock()
    recipient.send = AsyncMock()

    collector = MessageCollector()
    collector.get_user = AsyncMock(return_value=recipient)

    delivered = await collector.send_dm(123, "Hello!")

    assert delivered is True
    recipient.send.assert_called_once_with("Hello!")
@pytest.mark.asyncio
async def test_send_dm_user_not_found():
    """``send_dm`` returns False when the user lookup yields ``None``."""
    collector = MessageCollector()
    collector.get_user = AsyncMock(return_value=None)

    assert await collector.send_dm(123, "Hello!") is False
@pytest.mark.asyncio
async def test_send_dm_exception():
    """``send_dm`` returns False when the user's ``send`` raises."""
    recipient = Mock()
    recipient.send = AsyncMock(side_effect=Exception("Send failed"))

    collector = MessageCollector()
    collector.get_user = AsyncMock(return_value=recipient)

    assert await collector.send_dm(123, "Hello!") is False
@pytest.mark.asyncio
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
async def test_send_to_channel_success():
    """With notifications on, the message is sent to the resolved channel."""
    target = Mock()
    target.send = AsyncMock()

    collector = MessageCollector()
    collector.get_channel_by_name = AsyncMock(return_value=target)

    sent = await collector.send_to_channel("general", "Announcement!")

    assert sent is True
    target.send.assert_called_once_with("Announcement!")
@pytest.mark.asyncio
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
async def test_send_to_channel_notifications_disabled():
    """With notifications off, ``send_to_channel`` returns False."""
    collector = MessageCollector()

    assert await collector.send_to_channel("general", "Announcement!") is False
@pytest.mark.asyncio
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
async def test_send_to_channel_not_found():
    """``send_to_channel`` returns False when no channel matches the name."""
    collector = MessageCollector()
    collector.get_channel_by_name = AsyncMock(return_value=None)

    assert await collector.send_to_channel("nonexistent", "Message") is False
@pytest.mark.asyncio
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
async def test_run_collector():
    """``run_collector`` starts a collector with the configured bot token."""
    from memory.discord.collector import run_collector

    with patch("memory.discord.collector.MessageCollector") as mock_collector_class:
        fake_collector = Mock()
        fake_collector.start = AsyncMock()
        mock_collector_class.return_value = fake_collector

        await run_collector()

        fake_collector.start.assert_called_once_with("test_token")
@pytest.mark.asyncio
@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
async def test_run_collector_no_token():
    """Without a bot token, ``run_collector`` returns without starting a collector.

    ``MessageCollector`` is patched so the early return is asserted
    explicitly. It also means no real collector is built if the token
    guard ever regresses.
    """
    from memory.discord.collector import run_collector

    with patch("memory.discord.collector.MessageCollector") as mock_collector_class:
        mock_collector_class.return_value.start = AsyncMock()

        # Should return early without raising
        await run_collector()

        mock_collector_class.return_value.start.assert_not_called()

View File

@ -0,0 +1,607 @@
import pytest
from datetime import datetime, timezone
from unittest.mock import Mock, patch
from memory.common.db.models import (
DiscordMessage,
DiscordUser,
DiscordServer,
DiscordChannel,
)
from memory.workers.tasks import discord
@pytest.fixture
def mock_discord_user(db_session):
    """Persist and return a ``DiscordUser`` with ``ignore_messages=False``."""
    record = DiscordUser(id=123456789, username="testuser", ignore_messages=False)
    db_session.add(record)
    db_session.commit()
    return record
@pytest.fixture
def mock_discord_server(db_session):
    """Persist and return a ``DiscordServer`` with ``ignore_messages=False``."""
    record = DiscordServer(id=987654321, name="Test Server", ignore_messages=False)
    db_session.add(record)
    db_session.commit()
    return record
@pytest.fixture
def mock_discord_channel(db_session, mock_discord_server):
    """Persist and return a text ``DiscordChannel`` under the test server."""
    channel_fields = {
        "id": 111222333,
        "name": "test-channel",
        "channel_type": "text",
        "server_id": mock_discord_server.id,
        "ignore_messages": False,
    }
    record = DiscordChannel(**channel_fields)
    db_session.add(record)
    db_session.commit()
    return record
@pytest.fixture
def sample_message_data(mock_discord_user, mock_discord_channel):
    """Return keyword arguments for ``add_discord_message`` describing one message."""
    payload = dict(
        message_id=999888777,
        channel_id=mock_discord_channel.id,
        author_id=mock_discord_user.id,
        content="This is a test Discord message with enough content to be processed.",
        sent_at="2024-01-01T12:00:00Z",
        server_id=None,
        message_reference_id=None,
    )
    return payload
def test_get_prev_returns_previous_messages(
    db_session, mock_discord_user, mock_discord_channel
):
    """``get_prev`` lists earlier messages oldest-first as "username: content"."""
    # (message_id, content, minute past 10:00, sha256 prefix)
    seeds = (
        (1, "First message", 0, b"hash1"),
        (2, "Second message", 5, b"hash2"),
        (3, "Third message", 10, b"hash3"),
    )
    stored = [
        DiscordMessage(
            message_id=message_id,
            channel_id=mock_discord_channel.id,
            discord_user_id=mock_discord_user.id,
            content=content,
            sent_at=datetime(2024, 1, 1, 10, minute, 0, tzinfo=timezone.utc),
            modality="text",
            sha256=hash_prefix + bytes(26),
        )
        for message_id, content, minute, hash_prefix in seeds
    ]
    db_session.add_all(stored)
    db_session.commit()

    # Ask for everything sent before 10:15.
    cutoff = datetime(2024, 1, 1, 10, 15, 0, tzinfo=timezone.utc)
    result = discord.get_prev(db_session, mock_discord_channel.id, cutoff)

    assert len(result) == 3
    expected_lines = (
        "testuser: First message",
        "testuser: Second message",
        "testuser: Third message",
    )
    for position, line in enumerate(expected_lines):
        assert result[position] == line
def test_get_prev_limits_context_window(
    db_session, mock_discord_user, mock_discord_channel
):
    """``get_prev`` returns only the 10 most recent of 15 stored messages.

    The expected count of 10 is the default context window; the setting is
    not patched here, so the test depends on that default.
    """
    total_messages = 15  # more than the default context window of 10
    for index in range(total_messages):
        db_session.add(
            DiscordMessage(
                message_id=index,
                channel_id=mock_discord_channel.id,
                discord_user_id=mock_discord_user.id,
                content=f"Message {index}",
                sent_at=datetime(2024, 1, 1, 10, index, 0, tzinfo=timezone.utc),
                modality="text",
                sha256=f"hash{index}".encode() + bytes(27),
            )
        )
    db_session.commit()

    cutoff = datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc)
    result = discord.get_prev(db_session, mock_discord_channel.id, cutoff)

    # Only the last 10 messages survive, oldest first.
    assert len(result) == 10
    oldest_in_window, most_recent = result[0], result[-1]
    assert oldest_in_window == "testuser: Message 5"
    assert most_recent == "testuser: Message 14"
def test_get_prev_empty_channel(db_session, mock_discord_channel):
    """``get_prev`` returns an empty list for a channel with no messages."""
    cutoff = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)

    assert discord.get_prev(db_session, mock_discord_channel.id, cutoff) == []
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_should_process_normal_message(
    db_session, mock_discord_user, mock_discord_server, mock_discord_channel
):
    """``should_process`` is True when no server, channel or user ignores messages."""
    message_fields = {
        "message_id": 1,
        "channel_id": mock_discord_channel.id,
        "discord_user_id": mock_discord_user.id,
        "server_id": mock_discord_server.id,
        "content": "Test",
        "sent_at": datetime.now(timezone.utc),
        "modality": "text",
        "sha256": b"hash" + bytes(27),
    }
    stored = DiscordMessage(**message_fields)
    db_session.add(stored)
    db_session.commit()
    # Reload the row after the commit before evaluating it.
    db_session.refresh(stored)

    assert discord.should_process(stored) is True
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
def test_should_process_disabled():
    """``should_process`` is False while DISCORD_PROCESS_MESSAGES is off."""
    placeholder_message = Mock()

    assert discord.should_process(placeholder_message) is False
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_should_process_notifications_disabled():
    """``should_process`` is False while DISCORD_NOTIFICATIONS_ENABLED is off."""
    placeholder_message = Mock()

    assert discord.should_process(placeholder_message) is False
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_should_process_server_ignored(
    db_session, mock_discord_user, mock_discord_channel
):
    """``should_process`` is False for a message on a server with ``ignore_messages=True``."""
    ignored_server = DiscordServer(id=123, name="Ignored Server", ignore_messages=True)
    db_session.add(ignored_server)
    db_session.commit()

    message_fields = {
        "message_id": 1,
        "channel_id": mock_discord_channel.id,
        "discord_user_id": mock_discord_user.id,
        "server_id": ignored_server.id,
        "content": "Test",
        "sent_at": datetime.now(timezone.utc),
        "modality": "text",
        "sha256": b"hash" + bytes(27),
    }
    stored = DiscordMessage(**message_fields)
    db_session.add(stored)
    db_session.commit()
    db_session.refresh(stored)

    assert discord.should_process(stored) is False
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_should_process_channel_ignored(
    db_session, mock_discord_user, mock_discord_server
):
    """``should_process`` is False for a message in a channel with ``ignore_messages=True``."""
    channel_fields = {
        "id": 456,
        "name": "ignored-channel",
        "channel_type": "text",
        "server_id": mock_discord_server.id,
        "ignore_messages": True,
    }
    ignored_channel = DiscordChannel(**channel_fields)
    db_session.add(ignored_channel)
    db_session.commit()

    message_fields = {
        "message_id": 1,
        "channel_id": ignored_channel.id,
        "discord_user_id": mock_discord_user.id,
        "server_id": mock_discord_server.id,
        "content": "Test",
        "sent_at": datetime.now(timezone.utc),
        "modality": "text",
        "sha256": b"hash" + bytes(27),
    }
    stored = DiscordMessage(**message_fields)
    db_session.add(stored)
    db_session.commit()
    db_session.refresh(stored)

    assert discord.should_process(stored) is False
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_should_process_user_ignored(
    db_session, mock_discord_server, mock_discord_channel
):
    """``should_process`` is False for a message from a user with ``ignore_messages=True``."""
    ignored_user = DiscordUser(id=789, username="ignoreduser", ignore_messages=True)
    db_session.add(ignored_user)
    db_session.commit()

    message_fields = {
        "message_id": 1,
        "channel_id": mock_discord_channel.id,
        "discord_user_id": ignored_user.id,
        "server_id": mock_discord_server.id,
        "content": "Test",
        "sent_at": datetime.now(timezone.utc),
        "modality": "text",
        "sha256": b"hash" + bytes(27),
    }
    stored = DiscordMessage(**message_fields)
    db_session.add(stored)
    db_session.commit()
    db_session.refresh(stored)

    assert discord.should_process(stored) is False
def test_add_discord_message_success(db_session, sample_message_data, qdrant):
    """Adding a message reports "processed" and stores a default-typed row."""
    outcome = discord.add_discord_message(**sample_message_data)

    assert outcome["status"] == "processed"
    assert "discordmessage_id" in outcome

    # The row must now exist, with the submitted content and no reply link.
    lookup = db_session.query(DiscordMessage).filter_by(
        message_id=sample_message_data["message_id"]
    )
    stored = lookup.first()
    assert stored is not None
    assert stored.content == sample_message_data["content"]
    assert stored.message_type == "default"
    assert stored.reply_to_message_id is None
def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant):
    """A message with a reference id is stored as a reply to that id."""
    sample_message_data["message_reference_id"] = 111222333

    discord.add_discord_message(**sample_message_data)

    lookup = db_session.query(DiscordMessage).filter_by(
        message_id=sample_message_data["message_id"]
    )
    stored = lookup.first()
    assert stored.message_type == "reply"
    assert stored.reply_to_message_id == 111222333
def test_add_discord_message_already_exists(db_session, sample_message_data, qdrant):
    """Submitting the same message twice reports "already_exists" and stores one row."""
    # The first submission creates the row.
    discord.add_discord_message(**sample_message_data)

    # The second submission of identical data must be recognised as a duplicate.
    second_outcome = discord.add_discord_message(**sample_message_data)
    assert second_outcome["status"] == "already_exists"
    assert second_outcome["message_id"] == sample_message_data["message_id"]

    # Exactly one row exists for this message id.
    lookup = db_session.query(DiscordMessage).filter_by(
        message_id=sample_message_data["message_id"]
    )
    assert len(lookup.all()) == 1
def test_add_discord_message_with_context(
    db_session, sample_message_data, mock_discord_user, qdrant
):
    """A message is stored and processed when the channel already has an earlier one."""
    # Store an earlier message in the same channel.
    earlier_fields = {
        "message_id": 111111,
        "channel_id": sample_message_data["channel_id"],
        "discord_user_id": mock_discord_user.id,
        "content": "Previous message",
        "sent_at": datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
        "modality": "text",
        "sha256": b"prev" + bytes(28),
    }
    db_session.add(DiscordMessage(**earlier_fields))
    db_session.commit()

    outcome = discord.add_discord_message(**sample_message_data)

    lookup = db_session.query(DiscordMessage).filter_by(
        message_id=sample_message_data["message_id"]
    )
    stored = lookup.first()
    assert stored is not None
    assert outcome["status"] == "processed"
@patch("memory.workers.tasks.discord.process_discord_message")
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_add_discord_message_triggers_processing(
    mock_process,
    db_session,
    sample_message_data,
    mock_discord_server,
    mock_discord_channel,
    qdrant,
):
    """With processing and notifications on, adding a message queues processing once."""
    delay_stub = Mock()
    mock_process.delay = delay_stub
    # Attach the message to the test server rather than leaving it server-less.
    sample_message_data["server_id"] = mock_discord_server.id

    discord.add_discord_message(**sample_message_data)

    mock_process.delay.assert_called_once()
@patch("memory.workers.tasks.discord.process_discord_message")
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
def test_add_discord_message_no_processing_when_disabled(
    mock_process, db_session, sample_message_data, qdrant
):
    """With DISCORD_PROCESS_MESSAGES off, adding a message queues no processing."""
    delay_stub = Mock()
    mock_process.delay = delay_stub

    discord.add_discord_message(**sample_message_data)

    mock_process.delay.assert_not_called()
def test_edit_discord_message_success(db_session, sample_message_data, qdrant):
    """Editing a stored message replaces its content and sets ``edited_at``."""
    # The message has to exist before it can be edited.
    discord.add_discord_message(**sample_message_data)

    target_id = sample_message_data["message_id"]
    replacement = (
        "This is the edited content with enough text to be meaningful and processed."
    )
    outcome = discord.edit_discord_message(
        target_id, replacement, "2024-01-01T13:00:00Z"
    )
    assert outcome["status"] == "processed"

    # The stored row carries the new content and an edit timestamp.
    stored = db_session.query(DiscordMessage).filter_by(message_id=target_id).first()
    assert stored.content == replacement
    assert stored.edited_at is not None
def test_edit_discord_message_not_found(db_session):
    """Editing an unknown message id returns a "Message not found" error payload."""
    missing_id = 999999

    outcome = discord.edit_discord_message(
        missing_id, "New content", "2024-01-01T13:00:00Z"
    )

    assert outcome["status"] == "error"
    assert outcome["error"] == "Message not found"
    assert outcome["message_id"] == missing_id
def test_edit_discord_message_updates_context(
    db_session, sample_message_data, mock_discord_user, qdrant
):
    """Editing works when the channel also holds an earlier message."""
    # Store an earlier message, then the message that will be edited.
    earlier_fields = {
        "message_id": 111111,
        "channel_id": sample_message_data["channel_id"],
        "discord_user_id": mock_discord_user.id,
        "content": "Previous message",
        "sent_at": datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
        "modality": "text",
        "sha256": b"prev" + bytes(28),
    }
    db_session.add(DiscordMessage(**earlier_fields))
    db_session.commit()
    discord.add_discord_message(**sample_message_data)

    target_id = sample_message_data["message_id"]
    outcome = discord.edit_discord_message(
        target_id,
        "Edited content that should have context updated properly.",
        "2024-01-01T13:00:00Z",
    )

    # The stored row carries the edited content.
    stored = db_session.query(DiscordMessage).filter_by(message_id=target_id).first()
    assert (
        stored.content == "Edited content that should have context updated properly."
    )
    assert outcome["status"] == "processed"
def test_process_discord_message_success(db_session, sample_message_data, qdrant):
    """Processing a stored message reports "processed" for that message's id."""
    # Store the message and take the id reported by the add step.
    stored_id = discord.add_discord_message(**sample_message_data)[
        "discordmessage_id"
    ]

    outcome = discord.process_discord_message(stored_id)

    assert outcome["status"] == "processed"
    assert outcome["message_id"] == stored_id
def test_process_discord_message_not_found(db_session):
    """Processing an unknown id returns a "Message not found" error payload."""
    missing_id = 999999

    outcome = discord.process_discord_message(missing_id)

    assert outcome["status"] == "error"
    assert outcome["error"] == "Message not found"
    assert outcome["message_id"] == missing_id
@pytest.mark.parametrize(
    "sent_at_str,expected_hour",
    [
        ("2024-01-01T12:00:00Z", 12),
        ("2024-01-01T00:00:00+00:00", 0),
        ("2024-01-01T23:59:59Z", 23),
    ],
)
def test_add_discord_message_datetime_parsing(
    db_session, sample_message_data, sent_at_str, expected_hour, qdrant
):
    """Each ``sent_at`` string is stored with the expected hour."""
    sample_message_data["sent_at"] = sent_at_str

    discord.add_discord_message(**sample_message_data)

    lookup = db_session.query(DiscordMessage).filter_by(
        message_id=sample_message_data["message_id"]
    )
    stored = lookup.first()
    assert stored.sent_at.hour == expected_hour
def test_add_discord_message_unique_hash(db_session, sample_message_data, qdrant):
    """Two messages with identical content but different ids are both stored."""
    # Store the original message.
    discord.add_discord_message(**sample_message_data)

    # Resubmit the same content under a different message id.
    sample_message_data["message_id"] = 888777666
    second_outcome = discord.add_discord_message(**sample_message_data)

    # The second message is processed, not rejected as a duplicate.
    assert second_outcome["status"] == "processed"

    # Both rows are present for the shared content.
    lookup = db_session.query(DiscordMessage).filter_by(
        content=sample_message_data["content"]
    )
    assert len(lookup.all()) == 2
def test_get_prev_only_returns_messages_from_same_channel(
    db_session, mock_discord_user, mock_discord_server
):
    """``get_prev`` returns only messages from the requested channel."""
    # Two text channels on the same server.
    channel1, channel2 = (
        DiscordChannel(
            id=channel_id,
            name=channel_name,
            channel_type="text",
            server_id=mock_discord_server.id,
        )
        for channel_id, channel_name in ((111, "channel-1"), (222, "channel-2"))
    )
    db_session.add_all([channel1, channel2])
    db_session.commit()

    # One message per channel, five minutes apart.
    in_channel1 = DiscordMessage(
        message_id=1,
        channel_id=channel1.id,
        discord_user_id=mock_discord_user.id,
        content="Message in channel 1",
        sent_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
        modality="text",
        sha256=b"hash1" + bytes(26),
    )
    in_channel2 = DiscordMessage(
        message_id=2,
        channel_id=channel2.id,
        discord_user_id=mock_discord_user.id,
        content="Message in channel 2",
        sent_at=datetime(2024, 1, 1, 10, 5, 0, tzinfo=timezone.utc),
        modality="text",
        sha256=b"hash2" + bytes(26),
    )
    db_session.add_all([in_channel1, in_channel2])
    db_session.commit()

    # Query channel 1 with a cutoff later than both messages.
    cutoff = datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc)
    result = discord.get_prev(
        db_session,
        channel1.id,  # type: ignore
        cutoff,
    )

    # Only the channel-1 message comes back.
    assert len(result) == 1
    only_line = result[0]
    assert "Message in channel 1" in only_line
    assert "Message in channel 2" not in only_line