mirror of
https://github.com/mruwnik/memory.git
synced 2025-12-16 09:01:17 +01:00
Compare commits
3 Commits
f454aa9afa
...
99d3843f47
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
99d3843f47 | ||
|
|
08d17c28dd | ||
|
|
e086b4a3a6 |
@ -9,7 +9,6 @@ Create Date: 2025-10-12 10:12:57.421009
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
@ -20,10 +19,8 @@ depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Rename prompt column to message in scheduled_llm_calls table
|
||||
op.alter_column("scheduled_llm_calls", "prompt", new_column_name="message")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Rename message column back to prompt in scheduled_llm_calls table
|
||||
op.alter_column("scheduled_llm_calls", "message", new_column_name="prompt")
|
||||
|
||||
176
db/migrations/versions/20251012_222827_add_discord_models.py
Normal file
176
db/migrations/versions/20251012_222827_add_discord_models.py
Normal file
@ -0,0 +1,176 @@
|
||||
"""add_discord_models
|
||||
|
||||
Revision ID: a8c8e8b17179
|
||||
Revises: c86079073c1d
|
||||
Create Date: 2025-10-12 22:28:27.856164
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "a8c8e8b17179"
|
||||
down_revision: Union[str, None] = "c86079073c1d"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"discord_servers",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("member_count", sa.Integer(), nullable=True),
|
||||
sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
|
||||
sa.Column(
|
||||
"ignore_messages", sa.Boolean(), server_default="false", nullable=True
|
||||
),
|
||||
sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"discord_servers_active_idx",
|
||||
"discord_servers",
|
||||
["track_messages", "last_sync_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"discord_channels",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("server_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("channel_type", sa.Text(), nullable=False),
|
||||
sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
|
||||
sa.Column(
|
||||
"ignore_messages", sa.Boolean(), server_default="false", nullable=True
|
||||
),
|
||||
sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["server_id"],
|
||||
["discord_servers.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"discord_channels_server_idx", "discord_channels", ["server_id"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"discord_users",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("username", sa.Text(), nullable=False),
|
||||
sa.Column("display_name", sa.Text(), nullable=True),
|
||||
sa.Column("system_user_id", sa.Integer(), nullable=True),
|
||||
sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
|
||||
sa.Column(
|
||||
"ignore_messages", sa.Boolean(), server_default="false", nullable=True
|
||||
),
|
||||
sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["system_user_id"],
|
||||
["users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"discord_users_system_user_idx",
|
||||
"discord_users",
|
||||
["system_user_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"discord_message",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("sent_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("server_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("channel_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("discord_user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("message_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("message_type", sa.Text(), server_default="default", nullable=True),
|
||||
sa.Column("reply_to_message_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("thread_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("edited_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["channel_id"],
|
||||
["discord_channels.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["discord_user_id"],
|
||||
["discord_users.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["server_id"],
|
||||
["discord_servers.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"discord_message_discord_id_idx", "discord_message", ["message_id"], unique=True
|
||||
)
|
||||
op.create_index(
|
||||
"discord_message_server_channel_idx",
|
||||
"discord_message",
|
||||
["server_id", "channel_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"discord_message_user_idx", "discord_message", ["discord_user_id"], unique=False
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("discord_message_user_idx", table_name="discord_message")
|
||||
op.drop_index("discord_message_server_channel_idx", table_name="discord_message")
|
||||
op.drop_index("discord_message_discord_id_idx", table_name="discord_message")
|
||||
op.drop_table("discord_message")
|
||||
op.drop_index("discord_users_system_user_idx", table_name="discord_users")
|
||||
op.drop_table("discord_users")
|
||||
op.drop_index("discord_channels_server_idx", table_name="discord_channels")
|
||||
op.drop_table("discord_channels")
|
||||
op.drop_index("discord_servers_active_idx", table_name="discord_servers")
|
||||
op.drop_table("discord_servers")
|
||||
@ -174,7 +174,7 @@ services:
|
||||
<<: *worker-base
|
||||
environment:
|
||||
<<: *worker-env
|
||||
QUEUES: "email,ebooks,comic,blogs,forums,maintenance,notes,scheduler"
|
||||
QUEUES: "email,ebooks,discord,comic,blogs,forums,maintenance,notes,scheduler"
|
||||
|
||||
ingest-hub:
|
||||
<<: *worker-base
|
||||
@ -183,6 +183,10 @@ services:
|
||||
dockerfile: docker/ingest_hub/Dockerfile
|
||||
environment:
|
||||
<<: *worker-env
|
||||
DISCORD_API_PORT: 8000
|
||||
DISCORD_BOT_TOKEN: ${DISCORD_BOT_TOKEN}
|
||||
DISCORD_NOTIFICATIONS_ENABLED: true
|
||||
DISCORD_COLLECTOR_ENABLED: true
|
||||
volumes:
|
||||
- ./memory_files:/app/memory_files:rw
|
||||
tmpfs:
|
||||
|
||||
@ -11,10 +11,10 @@ RUN apt-get update && apt-get install -y \
|
||||
COPY requirements ./requirements/
|
||||
COPY setup.py ./
|
||||
RUN mkdir src
|
||||
RUN pip install -e ".[common]"
|
||||
RUN pip install -e ".[ingesters]"
|
||||
|
||||
COPY src/ ./src/
|
||||
RUN pip install -e ".[common]"
|
||||
RUN pip install -e ".[ingesters]"
|
||||
|
||||
# Create and copy entrypoint script
|
||||
COPY docker/workers/entry.sh ./entry.sh
|
||||
|
||||
@ -14,3 +14,12 @@ stderr_logfile=/dev/stderr
|
||||
stderr_logfile_maxbytes=0
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
|
||||
[program:discord-api]
|
||||
command=uvicorn memory.discord.api:app --host 0.0.0.0 --port %(ENV_DISCORD_API_PORT)s
|
||||
stdout_logfile=/dev/stdout
|
||||
stdout_logfile_maxbytes=0
|
||||
stderr_logfile=/dev/stderr
|
||||
stderr_logfile_maxbytes=0
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
|
||||
@ -44,7 +44,7 @@ RUN git config --global user.email "${GIT_USER_EMAIL}" && \
|
||||
git config --global user.name "${GIT_USER_NAME}"
|
||||
|
||||
# Default queues to process
|
||||
ENV QUEUES="ebooks,email,comic,blogs,forums,photo_embed,maintenance"
|
||||
ENV QUEUES="ebooks,email,discord,comic,blogs,forums,photo_embed,maintenance"
|
||||
ENV PYTHONPATH="/app"
|
||||
|
||||
ENTRYPOINT ["./entry.sh"]
|
||||
@ -5,7 +5,7 @@ alembic==1.13.1
|
||||
dotenv==0.9.9
|
||||
voyageai==0.3.2
|
||||
qdrant-client==1.9.0
|
||||
anthropic==0.18.1
|
||||
anthropic==0.69.0
|
||||
openai==1.25.0
|
||||
# Pin the httpx version, as newer versions break the anthropic client
|
||||
httpx==0.27.0
|
||||
|
||||
3
requirements/requirements-ingesters.txt
Normal file
3
requirements/requirements-ingesters.txt
Normal file
@ -0,0 +1,3 @@
|
||||
discord.py==2.3.2
|
||||
uvicorn==0.29.0
|
||||
fastapi==0.112.2
|
||||
8
setup.py
8
setup.py
@ -17,6 +17,7 @@ common_requires = read_requirements("requirements-common.txt")
|
||||
parsers_requires = read_requirements("requirements-parsers.txt")
|
||||
api_requires = read_requirements("requirements-api.txt")
|
||||
dev_requires = read_requirements("requirements-dev.txt")
|
||||
ingesters_requires = read_requirements("requirements-ingesters.txt")
|
||||
|
||||
setup(
|
||||
name="memory",
|
||||
@ -28,6 +29,11 @@ setup(
|
||||
"api": api_requires + common_requires + parsers_requires,
|
||||
"common": common_requires + parsers_requires,
|
||||
"dev": dev_requires,
|
||||
"all": api_requires + common_requires + dev_requires + parsers_requires,
|
||||
"ingesters": common_requires + parsers_requires + ingesters_requires,
|
||||
"all": api_requires
|
||||
+ common_requires
|
||||
+ dev_requires
|
||||
+ parsers_requires
|
||||
+ ingesters_requires,
|
||||
},
|
||||
)
|
||||
|
||||
@ -40,7 +40,7 @@ async def score_chunk(query: str, chunk: Chunk) -> Chunk:
|
||||
prompt = SCORE_CHUNK_PROMPT.format(query=query, chunk=chunk_text)
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
llms.call,
|
||||
llms.summarize,
|
||||
prompt,
|
||||
settings.RANKER_MODEL,
|
||||
images=images,
|
||||
|
||||
@ -12,6 +12,10 @@ MAINTENANCE_ROOT = "memory.workers.tasks.maintenance"
|
||||
NOTES_ROOT = "memory.workers.tasks.notes"
|
||||
OBSERVATIONS_ROOT = "memory.workers.tasks.observations"
|
||||
SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
|
||||
DISCORD_ROOT = "memory.workers.tasks.discord"
|
||||
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
|
||||
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
|
||||
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
|
||||
|
||||
SYNC_NOTES = f"{NOTES_ROOT}.sync_notes"
|
||||
SYNC_NOTE = f"{NOTES_ROOT}.sync_note"
|
||||
@ -72,17 +76,18 @@ app.conf.update(
|
||||
task_reject_on_worker_lost=True,
|
||||
worker_prefetch_multiplier=1,
|
||||
task_routes={
|
||||
f"{EMAIL_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-email"},
|
||||
f"{PHOTO_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-photo-embed"},
|
||||
f"{COMIC_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-comic"},
|
||||
f"{EBOOK_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-ebooks"},
|
||||
f"{BLOGS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-blogs"},
|
||||
f"{COMIC_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-comic"},
|
||||
f"{DISCORD_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-discord"},
|
||||
f"{EMAIL_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-email"},
|
||||
f"{FORUMS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-forums"},
|
||||
f"{MAINTENANCE_ROOT}.*": {
|
||||
"queue": f"{settings.CELERY_QUEUE_PREFIX}-maintenance"
|
||||
},
|
||||
f"{NOTES_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
|
||||
f"{OBSERVATIONS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
|
||||
f"{PHOTO_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-photo-embed"},
|
||||
f"{SCHEDULED_CALLS_ROOT}.*": {
|
||||
"queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler"
|
||||
},
|
||||
|
||||
@ -11,6 +11,7 @@ from memory.common.db.models.source_items import (
|
||||
EmailAttachment,
|
||||
AgentObservation,
|
||||
ChatMessage,
|
||||
DiscordMessage,
|
||||
BlogPost,
|
||||
Comic,
|
||||
BookSection,
|
||||
@ -40,6 +41,9 @@ from memory.common.db.models.sources import (
|
||||
Book,
|
||||
ArticleFeed,
|
||||
EmailAccount,
|
||||
DiscordServer,
|
||||
DiscordChannel,
|
||||
DiscordUser,
|
||||
)
|
||||
from memory.common.db.models.users import (
|
||||
User,
|
||||
@ -74,6 +78,7 @@ __all__ = [
|
||||
"EmailAttachment",
|
||||
"AgentObservation",
|
||||
"ChatMessage",
|
||||
"DiscordMessage",
|
||||
"BlogPost",
|
||||
"Comic",
|
||||
"BookSection",
|
||||
@ -93,6 +98,9 @@ __all__ = [
|
||||
"Book",
|
||||
"ArticleFeed",
|
||||
"EmailAccount",
|
||||
"DiscordServer",
|
||||
"DiscordChannel",
|
||||
"DiscordUser",
|
||||
# Users
|
||||
"User",
|
||||
"UserSession",
|
||||
|
||||
@ -70,7 +70,7 @@ class ScheduledLLMCall(Base):
|
||||
"created_at": print_datetime(cast(datetime, self.created_at)),
|
||||
"executed_at": print_datetime(cast(datetime, self.executed_at)),
|
||||
"model": self.model,
|
||||
"prompt": self.message,
|
||||
"message": self.message,
|
||||
"system_prompt": self.system_prompt,
|
||||
"allowed_tools": self.allowed_tools,
|
||||
"discord_channel": self.discord_channel,
|
||||
|
||||
@ -262,7 +262,7 @@ class ChatMessage(SourceItem):
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
platform = Column(Text)
|
||||
channel_id = Column(Text)
|
||||
channel_id = Column(Text) # Keep as Text for cross-platform compatibility
|
||||
author = Column(Text)
|
||||
sent_at = Column(DateTime(timezone=True))
|
||||
|
||||
@ -274,6 +274,64 @@ class ChatMessage(SourceItem):
|
||||
__table_args__ = (Index("chat_channel_idx", "platform", "channel_id"),)
|
||||
|
||||
|
||||
class DiscordMessage(SourceItem):
|
||||
"""Discord-specific chat message with rich metadata"""
|
||||
|
||||
__tablename__ = "discord_message"
|
||||
|
||||
id = Column(
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
|
||||
sent_at = Column(DateTime(timezone=True), nullable=False)
|
||||
server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
|
||||
channel_id = Column(BigInteger, ForeignKey("discord_channels.id"), nullable=False)
|
||||
discord_user_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False)
|
||||
message_id = Column(BigInteger, nullable=False) # Discord message snowflake ID
|
||||
|
||||
# Discord-specific metadata
|
||||
message_type = Column(
|
||||
Text, server_default="default"
|
||||
) # "default", "reply", "thread_starter"
|
||||
reply_to_message_id = Column(
|
||||
BigInteger, nullable=True
|
||||
) # Discord message snowflake ID if replying
|
||||
thread_id = Column(
|
||||
BigInteger, nullable=True
|
||||
) # Discord thread snowflake ID if in thread
|
||||
edited_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
channel = relationship("DiscordChannel", foreign_keys=[channel_id])
|
||||
server = relationship("DiscordServer", foreign_keys=[server_id])
|
||||
discord_user = relationship("DiscordUser", foreign_keys=[discord_user_id])
|
||||
|
||||
@property
|
||||
def title(self) -> str:
|
||||
return f"{self.discord_user.username}: {self.content}"
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "discord_message",
|
||||
}
|
||||
|
||||
__table_args__ = (
|
||||
Index("discord_message_discord_id_idx", "message_id", unique=True),
|
||||
Index(
|
||||
"discord_message_server_channel_idx",
|
||||
"server_id",
|
||||
"channel_id",
|
||||
),
|
||||
Index("discord_message_user_idx", "discord_user_id"),
|
||||
)
|
||||
|
||||
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
||||
content = cast(str | None, self.content)
|
||||
if not content:
|
||||
return []
|
||||
prev = getattr(self, "messages_before", [])
|
||||
content = "\n\n".join(prev) + "\n\n" + self.title
|
||||
return extract.extract_text(content)
|
||||
|
||||
|
||||
class GitCommit(SourceItem):
|
||||
__tablename__ = "git_commit"
|
||||
|
||||
|
||||
@ -10,12 +10,14 @@ from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Text,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from memory.common.db.models.base import Base
|
||||
|
||||
@ -123,3 +125,74 @@ class EmailAccount(Base):
|
||||
Index("email_accounts_active_idx", "active", "last_sync_at"),
|
||||
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
|
||||
)
|
||||
|
||||
|
||||
class MessageProcessor:
|
||||
track_messages = Column(Boolean, nullable=False, server_default="true")
|
||||
ignore_messages = Column(Boolean, nullable=True, default=False)
|
||||
|
||||
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
|
||||
|
||||
class DiscordServer(Base, MessageProcessor):
|
||||
"""Discord server configuration and metadata"""
|
||||
|
||||
__tablename__ = "discord_servers"
|
||||
|
||||
id = Column(BigInteger, primary_key=True) # Discord guild snowflake ID
|
||||
name = Column(Text, nullable=False)
|
||||
description = Column(Text)
|
||||
member_count = Column(Integer)
|
||||
|
||||
# Collection settings
|
||||
last_sync_at = Column(DateTime(timezone=True))
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
channels = relationship(
|
||||
"DiscordChannel", back_populates="server", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("discord_servers_active_idx", "track_messages", "last_sync_at"),
|
||||
)
|
||||
|
||||
|
||||
class DiscordChannel(Base, MessageProcessor):
|
||||
"""Discord channel metadata and configuration"""
|
||||
|
||||
__tablename__ = "discord_channels"
|
||||
|
||||
id = Column(BigInteger, primary_key=True) # Discord channel snowflake ID
|
||||
server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
|
||||
name = Column(Text, nullable=False)
|
||||
channel_type = Column(Text, nullable=False) # "text", "voice", "dm", "group_dm"
|
||||
|
||||
# Collection settings (null = inherit from server)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
server = relationship("DiscordServer", back_populates="channels")
|
||||
__table_args__ = (Index("discord_channels_server_idx", "server_id"),)
|
||||
|
||||
|
||||
class DiscordUser(Base, MessageProcessor):
|
||||
"""Discord user metadata and preferences"""
|
||||
|
||||
__tablename__ = "discord_users"
|
||||
|
||||
id = Column(BigInteger, primary_key=True) # Discord user snowflake ID
|
||||
username = Column(Text, nullable=False)
|
||||
display_name = Column(Text)
|
||||
|
||||
# Link to system user if registered
|
||||
system_user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
|
||||
|
||||
# Basic DM settings
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
system_user = relationship("User", back_populates="discord_users")
|
||||
|
||||
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
|
||||
|
||||
@ -50,6 +50,7 @@ class User(Base):
|
||||
oauth_states = relationship(
|
||||
"OAuthState", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
discord_users = relationship("DiscordUser", back_populates="system_user")
|
||||
|
||||
def serialize(self) -> dict:
|
||||
return {
|
||||
|
||||
@ -1,221 +1,101 @@
|
||||
"""
|
||||
Discord integration.
|
||||
|
||||
Simple HTTP client that communicates with the Discord collector's API server.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import requests
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from memory.common import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ERROR_CHANNEL = "memory-errors"
|
||||
ACTIVITY_CHANNEL = "memory-activity"
|
||||
DISCOVERY_CHANNEL = "memory-discoveries"
|
||||
CHAT_CHANNEL = "memory-chat"
|
||||
|
||||
def get_api_url() -> str:
|
||||
"""Get the Discord API server URL"""
|
||||
host = settings.DISCORD_COLLECTOR_SERVER_URL
|
||||
port = settings.DISCORD_COLLECTOR_PORT
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
class DiscordServer(requests.Session):
|
||||
def __init__(self, server_id: str, server_name: str, *args, **kwargs):
|
||||
self.server_id = server_id
|
||||
self.server_name = server_name
|
||||
self.channels = {}
|
||||
super().__init__(*args, **kwargs)
|
||||
self.setup_channels()
|
||||
self.members = self.fetch_all_members()
|
||||
|
||||
def setup_channels(self):
|
||||
resp = self.get(self.channels_url)
|
||||
resp.raise_for_status()
|
||||
channels = {channel["name"]: channel["id"] for channel in resp.json()}
|
||||
|
||||
if not (error_channel := channels.get(settings.DISCORD_ERROR_CHANNEL)):
|
||||
error_channel = self.create_channel(settings.DISCORD_ERROR_CHANNEL)
|
||||
self.channels[ERROR_CHANNEL] = error_channel
|
||||
|
||||
if not (activity_channel := channels.get(settings.DISCORD_ACTIVITY_CHANNEL)):
|
||||
activity_channel = self.create_channel(settings.DISCORD_ACTIVITY_CHANNEL)
|
||||
self.channels[ACTIVITY_CHANNEL] = activity_channel
|
||||
|
||||
if not (discovery_channel := channels.get(settings.DISCORD_DISCOVERY_CHANNEL)):
|
||||
discovery_channel = self.create_channel(settings.DISCORD_DISCOVERY_CHANNEL)
|
||||
self.channels[DISCOVERY_CHANNEL] = discovery_channel
|
||||
|
||||
if not (chat_channel := channels.get(settings.DISCORD_CHAT_CHANNEL)):
|
||||
chat_channel = self.create_channel(settings.DISCORD_CHAT_CHANNEL)
|
||||
self.channels[CHAT_CHANNEL] = chat_channel
|
||||
|
||||
@property
|
||||
def error_channel(self) -> str:
|
||||
return self.channels[ERROR_CHANNEL]
|
||||
|
||||
@property
|
||||
def activity_channel(self) -> str:
|
||||
return self.channels[ACTIVITY_CHANNEL]
|
||||
|
||||
@property
|
||||
def discovery_channel(self) -> str:
|
||||
return self.channels[DISCOVERY_CHANNEL]
|
||||
|
||||
@property
|
||||
def chat_channel(self) -> str:
|
||||
return self.channels[CHAT_CHANNEL]
|
||||
|
||||
def channel_id(self, channel_name: str) -> str:
|
||||
if not (channel_id := self.channels.get(channel_name)):
|
||||
raise ValueError(f"Channel {channel_name} not found")
|
||||
return channel_id
|
||||
|
||||
def send_message(self, channel_id: str, content: str):
|
||||
payload: dict[str, Any] = {"content": content}
|
||||
mentions = re.findall(r"@(\S*)", content)
|
||||
users = {u: i for u, i in self.members.items() if u in mentions}
|
||||
if users:
|
||||
for u, i in users.items():
|
||||
payload["content"] = payload["content"].replace(f"@{u}", f"<@{i}>")
|
||||
payload["allowed_mentions"] = {
|
||||
"parse": [],
|
||||
"users": list(users.values()),
|
||||
}
|
||||
|
||||
return self.post(
|
||||
f"https://discord.com/api/v10/channels/{channel_id}/messages",
|
||||
json=payload,
|
||||
)
|
||||
|
||||
def create_channel(self, channel_name: str, channel_type: int = 0) -> str | None:
|
||||
resp = self.post(
|
||||
self.channels_url, json={"name": channel_name, "type": channel_type}
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()["id"]
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"DiscordServer(server_id={self.server_id}, server_name={self.server_name})"
|
||||
)
|
||||
|
||||
def request(self, method: str, url: str, **kwargs):
|
||||
headers = kwargs.get("headers", {})
|
||||
headers["Authorization"] = f"Bot {settings.DISCORD_BOT_TOKEN}"
|
||||
headers["Content-Type"] = "application/json"
|
||||
kwargs["headers"] = headers
|
||||
return super().request(method, url, **kwargs)
|
||||
|
||||
@property
|
||||
def channels_url(self) -> str:
|
||||
return f"https://discord.com/api/v10/guilds/{self.server_id}/channels"
|
||||
|
||||
@property
|
||||
def members_url(self) -> str:
|
||||
return f"https://discord.com/api/v10/guilds/{self.server_id}/members"
|
||||
|
||||
@property
|
||||
def dm_create_url(self) -> str:
|
||||
return "https://discord.com/api/v10/users/@me/channels"
|
||||
|
||||
def list_members(
|
||||
self, limit: int = 1000, after: str | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""List up to `limit` members in this guild, starting after a user ID.
|
||||
|
||||
Requires the bot to have the Server Members Intent enabled in the Discord developer portal.
|
||||
"""
|
||||
params: dict[str, Any] = {"limit": limit}
|
||||
if after:
|
||||
params["after"] = after
|
||||
resp = self.get(self.members_url, params=params)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
def fetch_all_members(self, page_size: int = 1000) -> dict[str, str]:
|
||||
"""Retrieve all members in the guild by paginating the members list.
|
||||
|
||||
Note: Large guilds may take multiple requests. Rate limits are respected by requests.Session automatically.
|
||||
"""
|
||||
members: dict[str, str] = {}
|
||||
after: str | None = None
|
||||
while batch := self.list_members(limit=page_size, after=after):
|
||||
for member in batch:
|
||||
user = member.get("user", {})
|
||||
members[user.get("global_name") or user.get("username", "")] = user.get(
|
||||
"id", ""
|
||||
)
|
||||
after = user.get("id", "")
|
||||
return members
|
||||
|
||||
def create_dm_channel(self, user_id: str) -> str:
|
||||
"""Create (or retrieve) a DM channel with the given user and return the channel ID.
|
||||
|
||||
The bot must share a guild with the user, and the user's privacy settings must allow DMs from server members.
|
||||
"""
|
||||
resp = self.post(self.dm_create_url, json={"recipient_id": user_id})
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return data["id"]
|
||||
|
||||
def send_dm(self, user_id: str, content: str):
|
||||
"""Send a direct message to a specific user by ID."""
|
||||
channel_id = self.create_dm_channel(self.members.get(user_id) or user_id)
|
||||
return self.post(
|
||||
f"https://discord.com/api/v10/channels/{channel_id}/messages",
|
||||
json={"content": content},
|
||||
)
|
||||
|
||||
|
||||
def get_bot_servers() -> list[dict[str, Any]]:
|
||||
"""Get list of servers the bot is in."""
|
||||
if not settings.DISCORD_BOT_TOKEN:
|
||||
return []
|
||||
|
||||
def send_dm(user_identifier: str, message: str) -> bool:
|
||||
"""Send a DM via the Discord collector API"""
|
||||
try:
|
||||
headers = {"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}"}
|
||||
response = requests.get(
|
||||
"https://discord.com/api/v10/users/@me/guilds", headers=headers
|
||||
response = requests.post(
|
||||
f"{get_api_url()}/send_dm",
|
||||
json={"user_identifier": user_identifier, "message": message},
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return result.get("success", False)
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to send DM to {user_identifier}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def broadcast_message(channel_name: str, message: str) -> bool:
|
||||
"""Send a message to a channel via the Discord collector API"""
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{get_api_url()}/send_channel",
|
||||
json={"channel_name": channel_name, "message": message},
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return result.get("success", False)
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to send message to channel {channel_name}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def is_collector_healthy() -> bool:
|
||||
"""Check if the Discord collector is running and healthy"""
|
||||
try:
|
||||
response = requests.get(f"{get_api_url()}/health", timeout=5)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return result.get("status") == "healthy"
|
||||
|
||||
except requests.RequestException:
|
||||
return False
|
||||
|
||||
|
||||
def refresh_discord_metadata() -> dict[str, int] | None:
|
||||
"""Refresh Discord server/channel/user metadata from Discord API"""
|
||||
try:
|
||||
response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get bot servers: {e}")
|
||||
return []
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to refresh Discord metadata: {e}")
|
||||
return None
|
||||
|
||||
|
||||
servers: dict[str, DiscordServer] = {}
|
||||
# Convenience functions
|
||||
def send_error_message(message: str) -> bool:
|
||||
"""Send an error message to the error channel"""
|
||||
return broadcast_message(settings.DISCORD_ERROR_CHANNEL, message)
|
||||
|
||||
|
||||
def load_servers():
|
||||
for server in get_bot_servers():
|
||||
servers[server["id"]] = DiscordServer(server["id"], server["name"])
|
||||
def send_activity_message(message: str) -> bool:
|
||||
"""Send an activity message to the activity channel"""
|
||||
return broadcast_message(settings.DISCORD_ACTIVITY_CHANNEL, message)
|
||||
|
||||
|
||||
def broadcast_message(channel: str, message: str):
|
||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||
return
|
||||
|
||||
for server in servers.values():
|
||||
server.send_message(server.channel_id(channel), message)
|
||||
def send_discovery_message(message: str) -> bool:
|
||||
"""Send a discovery message to the discovery channel"""
|
||||
return broadcast_message(settings.DISCORD_DISCOVERY_CHANNEL, message)
|
||||
|
||||
|
||||
def send_error_message(message: str):
|
||||
broadcast_message(ERROR_CHANNEL, message)
|
||||
|
||||
|
||||
def send_activity_message(message: str):
|
||||
broadcast_message(ACTIVITY_CHANNEL, message)
|
||||
|
||||
|
||||
def send_discovery_message(message: str):
|
||||
broadcast_message(DISCOVERY_CHANNEL, message)
|
||||
|
||||
|
||||
def send_chat_message(message: str):
|
||||
broadcast_message(CHAT_CHANNEL, message)
|
||||
|
||||
|
||||
def send_dm(user_id: str, message: str):
|
||||
for server in servers.values():
|
||||
if not server.members.get(user_id) and user_id not in server.members.values():
|
||||
continue
|
||||
|
||||
server.send_dm(user_id, message)
|
||||
def send_chat_message(message: str) -> bool:
|
||||
"""Send a chat message to the chat channel"""
|
||||
return broadcast_message(settings.DISCORD_CHAT_CHANNEL, message)
|
||||
|
||||
|
||||
def notify_task_failure(
|
||||
@ -234,9 +114,6 @@ def notify_task_failure(
|
||||
task_args: Task arguments
|
||||
task_kwargs: Task keyword arguments
|
||||
traceback_str: Full traceback string
|
||||
|
||||
Returns:
|
||||
True if notification sent successfully
|
||||
"""
|
||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||
logger.debug("Discord notifications disabled")
|
||||
|
||||
@ -1,122 +0,0 @@
|
||||
import logging
|
||||
import base64
|
||||
import io
|
||||
from typing import Any
|
||||
from PIL import Image
|
||||
|
||||
from memory.common import settings, tokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
You are a helpful assistant that creates concise summaries and identifies key topics.
|
||||
"""
|
||||
|
||||
|
||||
def encode_image(image: Image.Image) -> str:
|
||||
"""Encode PIL Image to base64 string."""
|
||||
buffer = io.BytesIO()
|
||||
# Convert to RGB if necessary (for RGBA, etc.)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format="JPEG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
def call_openai(
|
||||
prompt: str,
|
||||
model: str,
|
||||
images: list[Image.Image] = [],
|
||||
system_prompt: str = SYSTEM_PROMPT,
|
||||
) -> str:
|
||||
"""Call OpenAI API for summarization."""
|
||||
import openai
|
||||
|
||||
client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
|
||||
try:
|
||||
user_content: Any = [{"type": "text", "text": prompt}]
|
||||
if images:
|
||||
for image in images:
|
||||
encoded_image = encode_image(image)
|
||||
user_content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
|
||||
}
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model.split("/")[1],
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{"role": "user", "content": user_content},
|
||||
],
|
||||
temperature=0.3,
|
||||
max_tokens=2048,
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def call_anthropic(
|
||||
prompt: str,
|
||||
model: str,
|
||||
images: list[Image.Image] = [],
|
||||
system_prompt: str = SYSTEM_PROMPT,
|
||||
) -> str:
|
||||
"""Call Anthropic API for summarization."""
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
|
||||
try:
|
||||
# Prepare the message content
|
||||
content: Any = [{"type": "text", "text": prompt}]
|
||||
if images:
|
||||
# Add images if provided
|
||||
for image in images:
|
||||
encoded_image = encode_image(image)
|
||||
content.append(
|
||||
{ # type: ignore
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": encoded_image,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
response = client.messages.create(
|
||||
model=model.split("/")[1],
|
||||
messages=[{"role": "user", "content": content}], # type: ignore
|
||||
system=system_prompt,
|
||||
temperature=0.3,
|
||||
max_tokens=2048,
|
||||
)
|
||||
return response.content[0].text
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def call(
|
||||
prompt: str,
|
||||
model: str,
|
||||
images: list[Image.Image] = [],
|
||||
system_prompt: str = SYSTEM_PROMPT,
|
||||
) -> str:
|
||||
if model.startswith("anthropic"):
|
||||
return call_anthropic(prompt, model, images, system_prompt)
|
||||
return call_openai(prompt, model, images, system_prompt)
|
||||
|
||||
|
||||
def truncate(content: str, target_tokens: int) -> str:
|
||||
target_chars = target_tokens * tokens.CHARS_PER_TOKEN
|
||||
if len(content) > target_chars:
|
||||
return content[:target_chars].rsplit(" ", 1)[0] + "..."
|
||||
return content
|
||||
79
src/memory/common/llms/__init__.py
Normal file
79
src/memory/common/llms/__init__.py
Normal file
@ -0,0 +1,79 @@
|
||||
"""LLM provider module for unified LLM access."""
|
||||
|
||||
# Legacy imports for backwards compatibility
|
||||
import logging
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# New provider system
|
||||
from memory.common.llms.base import (
|
||||
BaseLLMProvider,
|
||||
ImageContent,
|
||||
LLMSettings,
|
||||
Message,
|
||||
MessageContent,
|
||||
MessageRole,
|
||||
StreamEvent,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
ToolDefinition,
|
||||
ToolResultContent,
|
||||
ToolUseContent,
|
||||
create_provider,
|
||||
)
|
||||
from memory.common import tokens
|
||||
|
||||
__all__ = [
|
||||
"BaseLLMProvider",
|
||||
"Message",
|
||||
"MessageRole",
|
||||
"MessageContent",
|
||||
"TextContent",
|
||||
"ImageContent",
|
||||
"ToolUseContent",
|
||||
"ToolResultContent",
|
||||
"ThinkingContent",
|
||||
"ToolDefinition",
|
||||
"StreamEvent",
|
||||
"LLMSettings",
|
||||
"create_provider",
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def summarize(
|
||||
prompt: str,
|
||||
model: str,
|
||||
images: list[Image.Image] = [],
|
||||
system_prompt: str = "",
|
||||
) -> str:
|
||||
provider = create_provider(model=model)
|
||||
try:
|
||||
# Build message content
|
||||
content: list[MessageContent] = [TextContent(text=prompt)]
|
||||
for image in images:
|
||||
content.append(ImageContent(image=image))
|
||||
|
||||
messages = [Message(role=MessageRole.USER, content=content)]
|
||||
settings_obj = LLMSettings(temperature=0.3, max_tokens=2048)
|
||||
|
||||
res = provider.run_with_tools(
|
||||
messages=messages,
|
||||
system_prompt=system_prompt
|
||||
or "You are a helpful assistant that creates concise summaries and identifies key topics.",
|
||||
settings=settings_obj,
|
||||
tools={},
|
||||
)
|
||||
return res.response or ""
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def truncate(content: str, target_tokens: int) -> str:
|
||||
target_chars = target_tokens * tokens.CHARS_PER_TOKEN
|
||||
if len(content) > target_chars:
|
||||
return content[:target_chars].rsplit(" ", 1)[0] + "..."
|
||||
return content
|
||||
451
src/memory/common/llms/anthropic_provider.py
Normal file
451
src/memory/common/llms/anthropic_provider.py
Normal file
@ -0,0 +1,451 @@
|
||||
"""Anthropic LLM provider implementation."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncIterator, Iterator, Optional
|
||||
|
||||
import anthropic
|
||||
|
||||
from memory.common.llms.base import (
|
||||
BaseLLMProvider,
|
||||
ImageContent,
|
||||
LLMSettings,
|
||||
Message,
|
||||
MessageRole,
|
||||
StreamEvent,
|
||||
ToolDefinition,
|
||||
ToolUseContent,
|
||||
ThinkingContent,
|
||||
TextContent,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicProvider(BaseLLMProvider):
|
||||
"""Anthropic LLM provider with streaming, tool support, and extended thinking."""
|
||||
|
||||
# Models that support extended thinking
|
||||
THINKING_MODELS = {
|
||||
"claude-opus-4",
|
||||
"claude-opus-4-1",
|
||||
"claude-sonnet-4-0",
|
||||
"claude-sonnet-3-7",
|
||||
"claude-sonnet-4-5",
|
||||
}
|
||||
|
||||
def __init__(self, api_key: str, model: str, enable_thinking: bool = False):
|
||||
"""
|
||||
Initialize the Anthropic provider.
|
||||
|
||||
Args:
|
||||
api_key: Anthropic API key
|
||||
model: Model identifier
|
||||
enable_thinking: Enable extended thinking for supported models
|
||||
"""
|
||||
super().__init__(api_key, model)
|
||||
self.enable_thinking = enable_thinking
|
||||
self._async_client: Optional[anthropic.AsyncAnthropic] = None
|
||||
|
||||
def _initialize_client(self) -> anthropic.Anthropic:
|
||||
"""Initialize the Anthropic client."""
|
||||
return anthropic.Anthropic(api_key=self.api_key)
|
||||
|
||||
@property
|
||||
def async_client(self) -> anthropic.AsyncAnthropic:
|
||||
"""Lazy-load the async client."""
|
||||
if self._async_client is None:
|
||||
self._async_client = anthropic.AsyncAnthropic(api_key=self.api_key)
|
||||
return self._async_client
|
||||
|
||||
def _convert_image_content(self, content: ImageContent) -> dict[str, Any]:
|
||||
"""Convert ImageContent to Anthropic's base64 source format."""
|
||||
encoded_image = self.encode_image(content.image)
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": encoded_image,
|
||||
},
|
||||
}
|
||||
|
||||
def _convert_message(self, message: Message) -> dict[str, Any]:
|
||||
converted = message.to_dict()
|
||||
if converted["role"] == MessageRole.ASSISTANT and isinstance(
|
||||
converted["content"], list
|
||||
):
|
||||
content = sorted(
|
||||
converted["content"], key=lambda x: x["type"] != "thinking"
|
||||
)
|
||||
return converted | {"content": content}
|
||||
return converted
|
||||
|
||||
def _should_include_message(self, message: Message) -> bool:
|
||||
"""Filter out system messages (handled separately in Anthropic)."""
|
||||
return message.role != MessageRole.SYSTEM
|
||||
|
||||
def _supports_thinking(self) -> bool:
|
||||
"""Check if the current model supports extended thinking."""
|
||||
model_lower = self.model.lower()
|
||||
return any(supported in model_lower for supported in self.THINKING_MODELS)
|
||||
|
||||
def _build_request_kwargs(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None,
|
||||
tools: list[ToolDefinition] | None,
|
||||
settings: LLMSettings,
|
||||
) -> dict[str, Any]:
|
||||
"""Build common request kwargs for API calls."""
|
||||
anthropic_messages = self._convert_messages(messages)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": anthropic_messages,
|
||||
"temperature": settings.temperature,
|
||||
"max_tokens": settings.max_tokens,
|
||||
}
|
||||
|
||||
# Only include top_p if explicitly set
|
||||
if settings.top_p is not None:
|
||||
kwargs["top_p"] = settings.top_p
|
||||
|
||||
if system_prompt:
|
||||
kwargs["system"] = system_prompt
|
||||
|
||||
if settings.stop_sequences:
|
||||
kwargs["stop_sequences"] = settings.stop_sequences
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._convert_tools(tools)
|
||||
|
||||
# Enable extended thinking if requested and model supports it
|
||||
if self.enable_thinking and self._supports_thinking():
|
||||
thinking_budget = min(10000, settings.max_tokens - 1024)
|
||||
if thinking_budget >= 1024:
|
||||
kwargs["thinking"] = {
|
||||
"type": "enabled",
|
||||
"budget_tokens": thinking_budget,
|
||||
}
|
||||
# When thinking is enabled: temperature must be 1, can't use top_p
|
||||
kwargs["temperature"] = 1.0
|
||||
kwargs.pop("top_p", None)
|
||||
|
||||
return kwargs
|
||||
|
||||
def _handle_stream_event(
|
||||
self, event: Any, current_tool_use: dict[str, Any] | None
|
||||
) -> tuple[StreamEvent | None, dict[str, Any] | None]:
|
||||
"""
|
||||
Handle a streaming event and return StreamEvent and updated tool state.
|
||||
|
||||
Returns:
|
||||
Tuple of (StreamEvent or None, updated current_tool_use or None)
|
||||
"""
|
||||
event_type = getattr(event, "type", None)
|
||||
|
||||
# Handle error events
|
||||
if event_type == "error":
|
||||
error = getattr(event, "error", None)
|
||||
error_msg = str(error) if error else "Unknown error"
|
||||
return StreamEvent(type="error", data=error_msg), current_tool_use
|
||||
|
||||
if event_type == "content_block_start":
|
||||
block = getattr(event, "content_block", None)
|
||||
if not block:
|
||||
return None, current_tool_use
|
||||
|
||||
block_type = getattr(block, "type", None)
|
||||
|
||||
# Handle various tool types (tool_use, mcp_tool_use, server_tool_use)
|
||||
if block_type in ("tool_use", "mcp_tool_use", "server_tool_use"):
|
||||
# In content_block_start, input may already be present (empty dict)
|
||||
block_input = getattr(block, "input", None)
|
||||
current_tool_use = {
|
||||
"id": getattr(block, "id", ""),
|
||||
"name": getattr(block, "name", ""),
|
||||
"input": block_input if block_input is not None else "",
|
||||
"server_name": getattr(block, "server_name", None),
|
||||
"is_server_call": block_type != "tool_use",
|
||||
}
|
||||
|
||||
# Handle tool result blocks
|
||||
elif hasattr(block, "tool_use_id"):
|
||||
tool_result = {
|
||||
"id": getattr(block, "tool_use_id", ""),
|
||||
"result": getattr(block, "content", ""),
|
||||
}
|
||||
return StreamEvent(
|
||||
type="tool_result", data=tool_result
|
||||
), current_tool_use
|
||||
|
||||
# For non-tool blocks (text, thinking), we don't need to track state
|
||||
return None, current_tool_use
|
||||
|
||||
elif event_type == "content_block_delta":
|
||||
delta = getattr(event, "delta", None)
|
||||
if not delta:
|
||||
return None, current_tool_use
|
||||
|
||||
delta_type = getattr(delta, "type", None)
|
||||
|
||||
if delta_type == "text_delta":
|
||||
text = getattr(delta, "text", "")
|
||||
return StreamEvent(type="text", data=text), current_tool_use
|
||||
|
||||
elif delta_type == "thinking_delta":
|
||||
thinking = getattr(delta, "thinking", "")
|
||||
return StreamEvent(type="thinking", data=thinking), current_tool_use
|
||||
|
||||
elif delta_type == "signature_delta":
|
||||
# Handle thinking signature for extended thinking
|
||||
signature = getattr(delta, "signature", "")
|
||||
return StreamEvent(
|
||||
type="thinking", signature=signature
|
||||
), current_tool_use
|
||||
|
||||
elif delta_type == "input_json_delta":
|
||||
if current_tool_use is None:
|
||||
# Edge case: received input_json_delta without tool_use start
|
||||
logger.warning("Received input_json_delta without tool_use context")
|
||||
return None, None
|
||||
|
||||
# Only accumulate if input is still a string (being built up)
|
||||
if isinstance(current_tool_use.get("input"), str):
|
||||
partial_json = getattr(delta, "partial_json", "")
|
||||
current_tool_use["input"] += partial_json
|
||||
# else: input was already set as a dict in content_block_start
|
||||
|
||||
return None, current_tool_use
|
||||
|
||||
elif event_type == "content_block_stop":
|
||||
if current_tool_use:
|
||||
# Use the parsed input from the content block if available
|
||||
# This handles empty inputs {} more reliably than parsing
|
||||
content_block = getattr(event, "content_block", None)
|
||||
if content_block and hasattr(content_block, "input"):
|
||||
current_tool_use["input"] = content_block.input
|
||||
else:
|
||||
# Fallback: parse accumulated JSON string
|
||||
input_str = current_tool_use.get("input", "")
|
||||
if isinstance(input_str, str):
|
||||
# Need to parse the accumulated string
|
||||
if not input_str or input_str.isspace():
|
||||
# Empty or whitespace-only input
|
||||
current_tool_use["input"] = {}
|
||||
else:
|
||||
try:
|
||||
current_tool_use["input"] = json.loads(input_str)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(
|
||||
f"Failed to parse tool input '{input_str}': {e}"
|
||||
)
|
||||
current_tool_use["input"] = {}
|
||||
# else: input is already parsed
|
||||
|
||||
tool_data = {
|
||||
"id": current_tool_use.get("id", ""),
|
||||
"name": current_tool_use.get("name", ""),
|
||||
"input": current_tool_use.get("input", {}),
|
||||
}
|
||||
# Include server info if present
|
||||
if current_tool_use.get("server_name"):
|
||||
tool_data["server_name"] = current_tool_use["server_name"]
|
||||
if current_tool_use.get("is_server_call"):
|
||||
tool_data["is_server_call"] = current_tool_use["is_server_call"]
|
||||
|
||||
return StreamEvent(type="tool_use", data=tool_data), None
|
||||
|
||||
elif event_type == "message_delta":
|
||||
delta = getattr(event, "delta", None)
|
||||
if delta:
|
||||
stop_reason = getattr(delta, "stop_reason", None)
|
||||
if stop_reason == "max_tokens":
|
||||
return StreamEvent(
|
||||
type="error", data="Max tokens reached"
|
||||
), current_tool_use
|
||||
|
||||
# Handle token usage information
|
||||
usage = getattr(event, "usage", None)
|
||||
if usage:
|
||||
usage_data = {
|
||||
"input_tokens": getattr(usage, "input_tokens", 0),
|
||||
"output_tokens": getattr(usage, "output_tokens", 0),
|
||||
"cache_creation_input_tokens": getattr(
|
||||
usage, "cache_creation_input_tokens", None
|
||||
),
|
||||
"cache_read_input_tokens": getattr(
|
||||
usage, "cache_read_input_tokens", None
|
||||
),
|
||||
}
|
||||
# Could emit this as a separate event type if needed
|
||||
logger.debug(f"Token usage: {usage_data}")
|
||||
|
||||
return None, current_tool_use
|
||||
|
||||
elif event_type == "message_stop":
|
||||
# Final event - clean up any pending state
|
||||
if current_tool_use:
|
||||
logger.warning(
|
||||
f"Message ended with incomplete tool use: {current_tool_use}"
|
||||
)
|
||||
return StreamEvent(type="done"), None
|
||||
|
||||
# Unknown event type - log but don't fail
|
||||
if event_type and event_type not in (
|
||||
"message_start",
|
||||
"message_delta",
|
||||
"content_block_start",
|
||||
"content_block_delta",
|
||||
"content_block_stop",
|
||||
"message_stop",
|
||||
):
|
||||
logger.debug(f"Unknown event type: {event_type}")
|
||||
|
||||
return None, current_tool_use
|
||||
|
||||
def generate(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> str:
|
||||
"""Generate a non-streaming response."""
|
||||
settings = settings or LLMSettings()
|
||||
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
|
||||
|
||||
try:
|
||||
response = self.client.messages.create(**kwargs)
|
||||
return "".join(
|
||||
block.text for block in response.content if block.type == "text"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}")
|
||||
raise
|
||||
|
||||
def stream(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> Iterator[StreamEvent]:
|
||||
"""Generate a streaming response."""
|
||||
settings = settings or LLMSettings()
|
||||
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
|
||||
|
||||
try:
|
||||
with self.client.messages.stream(**kwargs) as stream:
|
||||
current_tool_use: dict[str, Any] | None = None
|
||||
|
||||
for event in stream:
|
||||
stream_event, current_tool_use = self._handle_stream_event(
|
||||
event, current_tool_use
|
||||
)
|
||||
if stream_event:
|
||||
yield stream_event
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic streaming error: {e}")
|
||||
yield StreamEvent(type="error", data=str(e))
|
||||
|
||||
async def agenerate(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> str:
|
||||
"""Generate a non-streaming response asynchronously."""
|
||||
settings = settings or LLMSettings()
|
||||
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
|
||||
|
||||
try:
|
||||
response = await self.async_client.messages.create(**kwargs)
|
||||
return "".join(
|
||||
block.text for block in response.content if block.type == "text"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}")
|
||||
raise
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Generate a streaming response asynchronously."""
|
||||
settings = settings or LLMSettings()
|
||||
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
|
||||
|
||||
try:
|
||||
async with self.async_client.messages.stream(**kwargs) as stream:
|
||||
current_tool_use: dict[str, Any] | None = None
|
||||
|
||||
async for event in stream:
|
||||
stream_event, current_tool_use = self._handle_stream_event(
|
||||
event, current_tool_use
|
||||
)
|
||||
if stream_event:
|
||||
yield stream_event
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic streaming error: {e}")
|
||||
yield StreamEvent(type="error", data=str(e))
|
||||
|
||||
def stream_with_tools(
|
||||
self,
|
||||
messages: list[Message],
|
||||
tools: dict[str, ToolDefinition],
|
||||
settings: LLMSettings | None = None,
|
||||
system_prompt: str | None = None,
|
||||
max_iterations: int = 10,
|
||||
) -> Iterator[StreamEvent]:
|
||||
if max_iterations <= 0:
|
||||
return
|
||||
|
||||
response = TextContent(text="")
|
||||
thinking = ThinkingContent(thinking="", signature="")
|
||||
|
||||
for event in self.stream(
|
||||
messages=messages,
|
||||
system_prompt=system_prompt,
|
||||
tools=list(tools.values()),
|
||||
settings=settings,
|
||||
):
|
||||
if event.type == "text":
|
||||
response.text += event.data
|
||||
yield event
|
||||
elif event.type == "thinking" and event.signature:
|
||||
thinking.signature = event.signature
|
||||
elif event.type == "thinking":
|
||||
thinking.thinking += event.data
|
||||
yield event
|
||||
elif event.type == "tool_use":
|
||||
yield event
|
||||
tool_result = self.execute_tool(event.data, tools)
|
||||
yield StreamEvent(type="tool_result", data=tool_result.to_dict())
|
||||
messages.append(
|
||||
Message.assistant(
|
||||
response,
|
||||
thinking,
|
||||
ToolUseContent(
|
||||
id=event.data["id"],
|
||||
name=event.data["name"],
|
||||
input=event.data["input"],
|
||||
),
|
||||
)
|
||||
)
|
||||
messages.append(Message.user(tool_result=tool_result))
|
||||
yield from self.stream_with_tools(
|
||||
messages, tools, settings, system_prompt, max_iterations - 1
|
||||
)
|
||||
elif event.type == "tool_result":
|
||||
yield event
|
||||
elif event.type == "error":
|
||||
logger.error(f"LLM error: {event.data}")
|
||||
raise RuntimeError(f"LLM error: {event.data}")
|
||||
561
src/memory/common/llms/base.py
Normal file
561
src/memory/common/llms/base.py
Normal file
@ -0,0 +1,561 @@
|
||||
"""Base classes and types for LLM providers."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, AsyncIterator, Iterator, Literal, Optional, Union
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
"""Message roles for chat history."""
|
||||
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextContent:
|
||||
"""Text content in a message."""
|
||||
|
||||
type: Literal["text"] = "text"
|
||||
text: str = ""
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {"type": "text", "text": self.text}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageContent:
|
||||
"""Image content in a message."""
|
||||
|
||||
type: Literal["image"] = "image"
|
||||
image: Image.Image = None # type: ignore
|
||||
detail: Optional[str] = None # For OpenAI: "low", "high", "auto"
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
# Note: Image will be encoded by provider-specific implementation
|
||||
return {"type": "image", "image": self.image}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolUseContent:
|
||||
"""Tool use request from the assistant."""
|
||||
|
||||
type: Literal["tool_use"] = "tool_use"
|
||||
id: str = ""
|
||||
name: str = ""
|
||||
input: dict[str, Any] = None # type: ignore
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"type": "tool_use",
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"input": self.input,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResultContent:
|
||||
"""Tool result from tool execution."""
|
||||
|
||||
type: Literal["tool_result"] = "tool_result"
|
||||
tool_use_id: str = ""
|
||||
content: str = ""
|
||||
is_error: bool = False
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": self.tool_use_id,
|
||||
"content": self.content,
|
||||
"is_error": self.is_error,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThinkingContent:
|
||||
"""Thinking/reasoning content from the assistant (extended thinking)."""
|
||||
|
||||
type: Literal["thinking"] = "thinking"
|
||||
thinking: str = ""
|
||||
signature: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"type": "thinking",
|
||||
"thinking": self.thinking,
|
||||
"signature": self.signature,
|
||||
}
|
||||
|
||||
|
||||
MessageContent = Union[
|
||||
TextContent, ImageContent, ToolUseContent, ToolResultContent, ThinkingContent
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Turn:
|
||||
"""A turn in the conversation."""
|
||||
|
||||
response: str | None
|
||||
thinking: str | None
|
||||
tool_calls: dict[str, ToolResult] | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""A message in the conversation history."""
|
||||
|
||||
role: MessageRole
|
||||
content: Union[str, list[MessageContent]]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert message to dictionary format."""
|
||||
if isinstance(self.content, str):
|
||||
return {"role": self.role.value, "content": self.content}
|
||||
content_list = [item.to_dict() for item in self.content]
|
||||
return {"role": self.role.value, "content": content_list}
|
||||
|
||||
@staticmethod
|
||||
def assistant(
|
||||
text: TextContent | None = None,
|
||||
thinking: ThinkingContent | None = None,
|
||||
tool_use: ToolUseContent | None = None,
|
||||
) -> "Message":
|
||||
parts = []
|
||||
if text:
|
||||
parts.append(text)
|
||||
if thinking:
|
||||
parts.append(thinking)
|
||||
if tool_use:
|
||||
parts.append(tool_use)
|
||||
return Message(role=MessageRole.ASSISTANT, content=parts)
|
||||
|
||||
@staticmethod
|
||||
def user(
|
||||
text: str | None = None, tool_result: ToolResultContent | None = None
|
||||
) -> "Message":
|
||||
parts = []
|
||||
if text:
|
||||
parts.append(TextContent(text=text))
|
||||
if tool_result:
|
||||
parts.append(tool_result)
|
||||
return Message(role=MessageRole.USER, content=parts)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamEvent:
|
||||
"""An event from the streaming response."""
|
||||
|
||||
type: Literal["text", "tool_use", "tool_result", "thinking", "error", "done"]
|
||||
data: Any = None
|
||||
signature: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMSettings:
|
||||
"""Settings for LLM API calls."""
|
||||
|
||||
temperature: float = 0.7
|
||||
max_tokens: int = 2048
|
||||
# Don't set by default - some models don't allow both temp and top_p
|
||||
top_p: float | None = None
|
||||
stop_sequences: list[str] | None = None
|
||||
stream: bool = False
|
||||
|
||||
|
||||
class BaseLLMProvider(ABC):
|
||||
"""Base class for LLM providers."""
|
||||
|
||||
def __init__(self, api_key: str, model: str):
|
||||
"""
|
||||
Initialize the LLM provider.
|
||||
|
||||
Args:
|
||||
api_key: API key for the provider
|
||||
model: Model identifier
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self._client: Any = None
|
||||
|
||||
@abstractmethod
|
||||
def _initialize_client(self) -> Any:
|
||||
"""Initialize the provider-specific client."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def client(self) -> Any:
|
||||
"""Lazy-load the client."""
|
||||
if self._client is None:
|
||||
self._client = self._initialize_client()
|
||||
return self._client
|
||||
|
||||
def execute_tool(
|
||||
self,
|
||||
tool_call: ToolCall,
|
||||
tool_handlers: dict[str, ToolDefinition],
|
||||
) -> ToolResultContent:
|
||||
"""
|
||||
Execute a tool call.
|
||||
|
||||
Args:
|
||||
tool_call: Tool call
|
||||
tool_handlers: Dict mapping tool names to handler functions
|
||||
|
||||
Returns:
|
||||
ToolResultContent with result or error
|
||||
"""
|
||||
name = tool_call.get("name")
|
||||
tool_use_id = tool_call.get("id")
|
||||
input = tool_call.get("input")
|
||||
|
||||
if not name:
|
||||
return ToolResultContent(
|
||||
tool_use_id=tool_use_id,
|
||||
content="Tool name missing",
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
if not (tool := tool_handlers.get(name)):
|
||||
return ToolResultContent(
|
||||
tool_use_id=tool_use_id,
|
||||
content=f"Tool '{name}' not found",
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
try:
|
||||
return ToolResultContent(
|
||||
tool_use_id=tool_use_id,
|
||||
content=tool(input),
|
||||
is_error=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Tool '{name}' failed: {e}", exc_info=True)
|
||||
return ToolResultContent(
|
||||
tool_use_id=tool_use_id,
|
||||
content=str(e),
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
def encode_image(self, image: Image.Image) -> str:
|
||||
"""
|
||||
Encode PIL Image to base64 string.
|
||||
|
||||
Args:
|
||||
image: PIL Image to encode
|
||||
|
||||
Returns:
|
||||
Base64 encoded string
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
# Convert to RGB if necessary (for RGBA, etc.)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format="JPEG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
def _convert_text_content(self, content: TextContent) -> dict[str, Any]:
|
||||
"""Convert TextContent to provider format. Override for custom format."""
|
||||
return content.to_dict()
|
||||
|
||||
def _convert_image_content(self, content: ImageContent) -> dict[str, Any]:
|
||||
"""Convert ImageContent to provider format. Override for custom format."""
|
||||
return content.to_dict()
|
||||
|
||||
def _convert_tool_use_content(self, content: ToolUseContent) -> dict[str, Any]:
|
||||
"""Convert ToolUseContent to provider format. Override for custom format."""
|
||||
return content.to_dict()
|
||||
|
||||
def _convert_tool_result_content(
|
||||
self, content: ToolResultContent
|
||||
) -> dict[str, Any]:
|
||||
"""Convert ToolResultContent to provider format. Override for custom format."""
|
||||
return content.to_dict()
|
||||
|
||||
def _convert_thinking_content(self, content: ThinkingContent) -> dict[str, Any]:
|
||||
"""Convert ThinkingContent to provider format. Override for custom format."""
|
||||
return content.to_dict()
|
||||
|
||||
def _convert_message_content(self, content: MessageContent) -> dict[str, Any]:
|
||||
"""
|
||||
Convert a MessageContent item to provider format.
|
||||
|
||||
Dispatches to type-specific converters that can be overridden.
|
||||
"""
|
||||
if isinstance(content, TextContent):
|
||||
return self._convert_text_content(content)
|
||||
elif isinstance(content, ImageContent):
|
||||
return self._convert_image_content(content)
|
||||
elif isinstance(content, ToolUseContent):
|
||||
return self._convert_tool_use_content(content)
|
||||
elif isinstance(content, ToolResultContent):
|
||||
return self._convert_tool_result_content(content)
|
||||
elif isinstance(content, ThinkingContent):
|
||||
return self._convert_thinking_content(content)
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {type(content)}")
|
||||
|
||||
def _convert_message(self, message: Message) -> dict[str, Any]:
|
||||
"""
|
||||
Convert a Message to provider format.
|
||||
|
||||
Can be overridden for provider-specific handling (e.g., filtering system messages).
|
||||
"""
|
||||
return message.to_dict()
|
||||
|
||||
def _should_include_message(self, message: Message) -> bool:
|
||||
"""
|
||||
Determine if a message should be included in the request.
|
||||
|
||||
Override to filter messages (e.g., Anthropic filters SYSTEM messages).
|
||||
|
||||
Args:
|
||||
message: Message to check
|
||||
|
||||
Returns:
|
||||
True if message should be included
|
||||
"""
|
||||
return True
|
||||
|
||||
def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert a list of messages to provider format.
|
||||
|
||||
Uses _should_include_message for filtering and _convert_message for conversion.
|
||||
"""
|
||||
return [
|
||||
self._convert_message(msg)
|
||||
for msg in messages
|
||||
if self._should_include_message(msg)
|
||||
]
|
||||
|
||||
def _convert_tool(self, tool: ToolDefinition) -> dict[str, Any]:
|
||||
"""
|
||||
Convert a single ToolDefinition to provider format.
|
||||
|
||||
Default format matches Anthropic. Override for other providers (e.g., OpenAI uses functions).
|
||||
"""
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.input_schema,
|
||||
}
|
||||
|
||||
def _convert_tools(
|
||||
self, tools: list[ToolDefinition] | None
|
||||
) -> Optional[list[dict[str, Any]]]:
|
||||
"""Convert tool definitions to provider format."""
|
||||
if not tools:
|
||||
return None
|
||||
return [self._convert_tool(tool) for tool in tools]
|
||||
|
||||
@abstractmethod
|
||||
def generate(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a non-streaming response.
|
||||
|
||||
Args:
|
||||
messages: Conversation history
|
||||
system_prompt: Optional system prompt
|
||||
tools: Optional list of tools the LLM can use
|
||||
settings: Optional settings for the generation
|
||||
|
||||
Returns:
|
||||
Generated text response
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stream(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> Iterator[StreamEvent]:
|
||||
"""
|
||||
Generate a streaming response.
|
||||
|
||||
Args:
|
||||
messages: Conversation history
|
||||
system_prompt: Optional system prompt
|
||||
tools: Optional list of tools the LLM can use
|
||||
settings: Optional settings for the generation
|
||||
|
||||
Yields:
|
||||
StreamEvent objects containing text chunks, tool uses, or errors
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def agenerate(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a non-streaming response asynchronously.
|
||||
|
||||
Args:
|
||||
messages: Conversation history
|
||||
system_prompt: Optional system prompt
|
||||
tools: Optional list of tools the LLM can use
|
||||
settings: Optional settings for the generation
|
||||
|
||||
Returns:
|
||||
Generated text response
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def astream(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""
|
||||
Generate a streaming response asynchronously.
|
||||
|
||||
Args:
|
||||
messages: Conversation history
|
||||
system_prompt: Optional system prompt
|
||||
tools: Optional list of tools the LLM can use
|
||||
settings: Optional settings for the generation
|
||||
|
||||
Yields:
|
||||
StreamEvent objects containing text chunks, tool uses, or errors
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stream_with_tools(
|
||||
self,
|
||||
messages: list[Message],
|
||||
tools: dict[str, ToolDefinition],
|
||||
settings: LLMSettings | None = None,
|
||||
system_prompt: str | None = None,
|
||||
max_iterations: int = 10,
|
||||
) -> Iterator[StreamEvent]:
|
||||
pass
|
||||
|
||||
def run_with_tools(
|
||||
self,
|
||||
messages: list[Message],
|
||||
tools: dict[str, ToolDefinition],
|
||||
settings: LLMSettings | None = None,
|
||||
system_prompt: str | None = None,
|
||||
max_iterations: int = 10,
|
||||
) -> Turn:
|
||||
thinking, response, tool_calls = "", "", {}
|
||||
for event in self.stream_with_tools(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
settings=settings,
|
||||
system_prompt=system_prompt,
|
||||
max_iterations=max_iterations,
|
||||
):
|
||||
if event.type == "thinking":
|
||||
thinking += event.data
|
||||
elif event.type == "tool_use":
|
||||
tool_calls[event.data["id"]] = {
|
||||
"name": event.data["name"],
|
||||
"input": event.data["input"],
|
||||
"output": "",
|
||||
}
|
||||
elif event.type == "text":
|
||||
response += event.data
|
||||
elif event.type == "tool_result":
|
||||
current = tool_calls.get(event.data["tool_use_id"]) or {}
|
||||
tool_calls[event.data["tool_use_id"]] = {
|
||||
"name": event.data.get("name") or current.get("name"),
|
||||
"input": event.data.get("input") or current.get("input"),
|
||||
"output": event.data.get("content"),
|
||||
}
|
||||
return Turn(
|
||||
thinking=thinking or None,
|
||||
response=response or None,
|
||||
tool_calls=tool_calls or None,
|
||||
)
|
||||
|
||||
|
||||
def create_provider(
|
||||
model: str | None = None,
|
||||
api_key: str | None = None,
|
||||
enable_thinking: bool = False,
|
||||
) -> BaseLLMProvider:
|
||||
"""
|
||||
Create an LLM provider based on the model name.
|
||||
|
||||
Args:
|
||||
model: Model identifier (e.g., "claude-3-opus-20240229", "gpt-4").
|
||||
If not provided, uses SUMMARIZER_MODEL from settings.
|
||||
api_key: Optional API key. If not provided, uses keys from settings.
|
||||
enable_thinking: Enable extended thinking for supported models (Claude Opus 4+, Sonnet 4+, Sonnet 3.7)
|
||||
|
||||
Returns:
|
||||
An initialized LLM provider
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider cannot be determined from the model name
|
||||
"""
|
||||
# Use default model from settings if not provided
|
||||
if model is None:
|
||||
model = settings.SUMMARIZER_MODEL
|
||||
|
||||
provider, model = model.split("/", 1)
|
||||
|
||||
if provider == "anthropic":
|
||||
# Anthropic models
|
||||
if api_key is None:
|
||||
api_key = settings.ANTHROPIC_API_KEY
|
||||
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"ANTHROPIC_API_KEY not found in settings. "
|
||||
"Please set it in your .env file."
|
||||
)
|
||||
|
||||
from memory.common.llms.anthropic_provider import AnthropicProvider
|
||||
|
||||
return AnthropicProvider(
|
||||
api_key=api_key, model=model, enable_thinking=enable_thinking
|
||||
)
|
||||
|
||||
# Could add OpenAI support here in the future
|
||||
# elif "gpt" in model_lower or model.startswith("openai"):
|
||||
# ...
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown provider for model: {model}. "
|
||||
f"Supported providers: Anthropic (claude-*)"
|
||||
)
|
||||
388
src/memory/common/llms/openai_provider.py
Normal file
388
src/memory/common/llms/openai_provider.py
Normal file
@ -0,0 +1,388 @@
|
||||
"""OpenAI LLM provider implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any, AsyncIterator, Iterator, Optional
|
||||
|
||||
import openai
|
||||
|
||||
from memory.common.llms.base import (
|
||||
BaseLLMProvider,
|
||||
ImageContent,
|
||||
LLMSettings,
|
||||
Message,
|
||||
MessageContent,
|
||||
MessageRole,
|
||||
StreamEvent,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
ToolDefinition,
|
||||
ToolResultContent,
|
||||
ToolUseContent,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIProvider(BaseLLMProvider):
|
||||
"""OpenAI LLM provider with streaming and tool support."""
|
||||
|
||||
def _initialize_client(self) -> openai.OpenAI:
|
||||
"""Initialize the OpenAI client."""
|
||||
return openai.OpenAI(api_key=self.api_key)
|
||||
|
||||
def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert our Message format to OpenAI format.
|
||||
|
||||
Args:
|
||||
messages: List of messages in our format
|
||||
|
||||
Returns:
|
||||
List of messages in OpenAI format
|
||||
"""
|
||||
openai_messages = []
|
||||
|
||||
for msg in messages:
|
||||
if isinstance(msg.content, str):
|
||||
openai_messages.append({"role": msg.role.value, "content": msg.content})
|
||||
else:
|
||||
# Handle multi-part content
|
||||
content_parts = []
|
||||
for item in msg.content:
|
||||
if isinstance(item, TextContent):
|
||||
content_parts.append({"type": "text", "text": item.text})
|
||||
elif isinstance(item, ImageContent):
|
||||
encoded_image = self.encode_image(item.image)
|
||||
image_part: dict[str, Any] = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{encoded_image}"
|
||||
},
|
||||
}
|
||||
if item.detail:
|
||||
image_part["image_url"]["detail"] = item.detail
|
||||
content_parts.append(image_part)
|
||||
elif isinstance(item, ToolUseContent):
|
||||
# OpenAI doesn't have tool_use in content, it's a separate field
|
||||
# We'll handle this by adding a tool_calls field to the message
|
||||
pass
|
||||
elif isinstance(item, ToolResultContent):
|
||||
# OpenAI handles tool results as separate "tool" role messages
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": item.tool_use_id,
|
||||
"content": item.content,
|
||||
}
|
||||
)
|
||||
continue
|
||||
elif isinstance(item, ThinkingContent):
|
||||
# OpenAI doesn't have native thinking support in most models
|
||||
# We can add it as text with a special marker
|
||||
content_parts.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"[Thinking: {item.thinking}]",
|
||||
}
|
||||
)
|
||||
|
||||
# Check if this message has tool calls
|
||||
tool_calls = [
|
||||
item for item in msg.content if isinstance(item, ToolUseContent)
|
||||
]
|
||||
|
||||
message_dict: dict[str, Any] = {"role": msg.role.value}
|
||||
|
||||
if content_parts:
|
||||
message_dict["content"] = content_parts
|
||||
|
||||
if tool_calls:
|
||||
message_dict["tool_calls"] = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {"name": tc.name, "arguments": str(tc.input)},
|
||||
}
|
||||
for tc in tool_calls
|
||||
]
|
||||
|
||||
openai_messages.append(message_dict)
|
||||
|
||||
return openai_messages
|
||||
|
||||
def _convert_tools(
|
||||
self, tools: Optional[list[ToolDefinition]]
|
||||
) -> Optional[list[dict[str, Any]]]:
|
||||
"""
|
||||
Convert our tool definitions to OpenAI format.
|
||||
|
||||
Args:
|
||||
tools: List of tool definitions
|
||||
|
||||
Returns:
|
||||
List of tools in OpenAI format
|
||||
"""
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.input_schema,
|
||||
},
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
|
||||
def generate(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list[ToolDefinition]] = None,
|
||||
settings: Optional[LLMSettings] = None,
|
||||
) -> str:
|
||||
"""Generate a non-streaming response."""
|
||||
settings = settings or LLMSettings()
|
||||
|
||||
openai_messages = self._convert_messages(messages)
|
||||
|
||||
# Add system prompt as first message if provided
|
||||
if system_prompt:
|
||||
openai_messages.insert(
|
||||
0, {"role": "system", "content": system_prompt}
|
||||
)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": openai_messages,
|
||||
"temperature": settings.temperature,
|
||||
"max_tokens": settings.max_tokens,
|
||||
"top_p": settings.top_p,
|
||||
}
|
||||
|
||||
if settings.stop_sequences:
|
||||
kwargs["stop"] = settings.stop_sequences
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._convert_tools(tools)
|
||||
kwargs["tool_choice"] = "auto"
|
||||
|
||||
try:
|
||||
response = self.client.chat.completions.create(**kwargs)
|
||||
return response.choices[0].message.content or ""
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
raise
|
||||
|
||||
def stream(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list[ToolDefinition]] = None,
|
||||
settings: Optional[LLMSettings] = None,
|
||||
) -> Iterator[StreamEvent]:
|
||||
"""Generate a streaming response."""
|
||||
settings = settings or LLMSettings()
|
||||
|
||||
openai_messages = self._convert_messages(messages)
|
||||
|
||||
# Add system prompt as first message if provided
|
||||
if system_prompt:
|
||||
openai_messages.insert(
|
||||
0, {"role": "system", "content": system_prompt}
|
||||
)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": openai_messages,
|
||||
"temperature": settings.temperature,
|
||||
"max_tokens": settings.max_tokens,
|
||||
"top_p": settings.top_p,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
if settings.stop_sequences:
|
||||
kwargs["stop"] = settings.stop_sequences
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._convert_tools(tools)
|
||||
kwargs["tool_choice"] = "auto"
|
||||
|
||||
try:
|
||||
stream = self.client.chat.completions.create(**kwargs)
|
||||
|
||||
current_tool_call: Optional[dict[str, Any]] = None
|
||||
|
||||
for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
# Handle text content
|
||||
if delta.content:
|
||||
yield StreamEvent(type="text", data=delta.content)
|
||||
|
||||
# Handle tool calls
|
||||
if delta.tool_calls:
|
||||
for tool_call in delta.tool_calls:
|
||||
if tool_call.id:
|
||||
# New tool call starting
|
||||
if current_tool_call:
|
||||
# Yield the previous one
|
||||
yield StreamEvent(
|
||||
type="tool_use", data=current_tool_call
|
||||
)
|
||||
current_tool_call = {
|
||||
"id": tool_call.id,
|
||||
"name": tool_call.function.name or "",
|
||||
"arguments": tool_call.function.arguments or "",
|
||||
}
|
||||
elif current_tool_call and tool_call.function.arguments:
|
||||
# Continue building the current tool call
|
||||
current_tool_call["arguments"] += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
|
||||
# Check if stream is finished
|
||||
if chunk.choices[0].finish_reason:
|
||||
if current_tool_call:
|
||||
yield StreamEvent(type="tool_use", data=current_tool_call)
|
||||
current_tool_call = None
|
||||
|
||||
yield StreamEvent(type="done")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI streaming error: {e}")
|
||||
yield StreamEvent(type="error", data=str(e))
|
||||
|
||||
async def agenerate(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list[ToolDefinition]] = None,
|
||||
settings: Optional[LLMSettings] = None,
|
||||
) -> str:
|
||||
"""Generate a non-streaming response asynchronously."""
|
||||
settings = settings or LLMSettings()
|
||||
|
||||
# Use async client
|
||||
async_client = openai.AsyncOpenAI(api_key=self.api_key)
|
||||
|
||||
openai_messages = self._convert_messages(messages)
|
||||
|
||||
# Add system prompt as first message if provided
|
||||
if system_prompt:
|
||||
openai_messages.insert(
|
||||
0, {"role": "system", "content": system_prompt}
|
||||
)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": openai_messages,
|
||||
"temperature": settings.temperature,
|
||||
"max_tokens": settings.max_tokens,
|
||||
"top_p": settings.top_p,
|
||||
}
|
||||
|
||||
if settings.stop_sequences:
|
||||
kwargs["stop"] = settings.stop_sequences
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._convert_tools(tools)
|
||||
kwargs["tool_choice"] = "auto"
|
||||
|
||||
try:
|
||||
response = await async_client.chat.completions.create(**kwargs)
|
||||
return response.choices[0].message.content or ""
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
raise
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
messages: list[Message],
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list[ToolDefinition]] = None,
|
||||
settings: Optional[LLMSettings] = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Generate a streaming response asynchronously."""
|
||||
settings = settings or LLMSettings()
|
||||
|
||||
# Use async client
|
||||
async_client = openai.AsyncOpenAI(api_key=self.api_key)
|
||||
|
||||
openai_messages = self._convert_messages(messages)
|
||||
|
||||
# Add system prompt as first message if provided
|
||||
if system_prompt:
|
||||
openai_messages.insert(
|
||||
0, {"role": "system", "content": system_prompt}
|
||||
)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": openai_messages,
|
||||
"temperature": settings.temperature,
|
||||
"max_tokens": settings.max_tokens,
|
||||
"top_p": settings.top_p,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
if settings.stop_sequences:
|
||||
kwargs["stop"] = settings.stop_sequences
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._convert_tools(tools)
|
||||
kwargs["tool_choice"] = "auto"
|
||||
|
||||
try:
|
||||
stream = await async_client.chat.completions.create(**kwargs)
|
||||
|
||||
current_tool_call: Optional[dict[str, Any]] = None
|
||||
|
||||
async for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
# Handle text content
|
||||
if delta.content:
|
||||
yield StreamEvent(type="text", data=delta.content)
|
||||
|
||||
# Handle tool calls
|
||||
if delta.tool_calls:
|
||||
for tool_call in delta.tool_calls:
|
||||
if tool_call.id:
|
||||
# New tool call starting
|
||||
if current_tool_call:
|
||||
# Yield the previous one
|
||||
yield StreamEvent(
|
||||
type="tool_use", data=current_tool_call
|
||||
)
|
||||
current_tool_call = {
|
||||
"id": tool_call.id,
|
||||
"name": tool_call.function.name or "",
|
||||
"arguments": tool_call.function.arguments or "",
|
||||
}
|
||||
elif current_tool_call and tool_call.function.arguments:
|
||||
# Continue building the current tool call
|
||||
current_tool_call["arguments"] += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
|
||||
# Check if stream is finished
|
||||
if chunk.choices[0].finish_reason:
|
||||
if current_tool_call:
|
||||
yield StreamEvent(type="tool_use", data=current_tool_call)
|
||||
current_tool_call = None
|
||||
|
||||
yield StreamEvent(type="done")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI streaming error: {e}")
|
||||
yield StreamEvent(type="error", data=str(e))
|
||||
36
src/memory/common/llms/tools/__init__.py
Normal file
36
src/memory/common/llms/tools/__init__.py
Normal file
@ -0,0 +1,36 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, TypedDict
|
||||
|
||||
|
||||
ToolInput = str | dict[str, Any] | None
|
||||
ToolHandler = Callable[[ToolInput], str]
|
||||
|
||||
|
||||
class ToolCall(TypedDict):
|
||||
"""A call to a tool."""
|
||||
|
||||
name: str
|
||||
id: str
|
||||
input: ToolInput
|
||||
|
||||
|
||||
class ToolResult(TypedDict):
|
||||
"""A result from a tool call."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
input: ToolInput
|
||||
output: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
"""Definition of a tool that can be called by the LLM."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
input_schema: dict[str, Any] # JSON Schema for the tool's parameters
|
||||
function: ToolHandler
|
||||
|
||||
def __call__(self, input: ToolInput) -> str:
|
||||
return self.function(input)
|
||||
42
src/memory/common/llms/tools/ping.py
Normal file
42
src/memory/common/llms/tools/ping.py
Normal file
@ -0,0 +1,42 @@
|
||||
"""Ping tool for testing LLM tool integration."""
|
||||
|
||||
from memory.common.llms.tools import ToolDefinition, ToolInput
|
||||
|
||||
|
||||
def handle_ping_call(message: ToolInput = None) -> str:
|
||||
"""
|
||||
Handle a ping tool call.
|
||||
|
||||
Args:
|
||||
message: Optional message to include in response
|
||||
|
||||
Returns:
|
||||
Response string
|
||||
"""
|
||||
if message:
|
||||
return f"pong: {message}"
|
||||
return "pong"
|
||||
|
||||
|
||||
def get_ping_tool() -> ToolDefinition:
|
||||
"""
|
||||
Get a ping tool definition for testing tool calls.
|
||||
|
||||
Returns a simple tool that takes no required parameters and can be used
|
||||
to verify that tool calling is working correctly.
|
||||
"""
|
||||
return ToolDefinition(
|
||||
name="ping",
|
||||
description="A simple test tool that returns 'pong'. Use this to verify tool calling is working.",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "Optional message to echo back",
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
function=handle_ping_call,
|
||||
)
|
||||
9
src/memory/common/messages.py
Normal file
9
src/memory/common/messages.py
Normal file
@ -0,0 +1,9 @@
|
||||
def process_message(
|
||||
msg: str,
|
||||
history: list[str],
|
||||
model: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
allowed_tools: list[str] | None = None,
|
||||
disallowed_tools: list[str] | None = None,
|
||||
) -> str:
|
||||
return "asd"
|
||||
@ -172,3 +172,12 @@ DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat")
|
||||
DISCORD_NOTIFICATIONS_ENABLED = bool(
|
||||
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
|
||||
)
|
||||
DISCORD_PROCESS_MESSAGES = boolean_env("DISCORD_PROCESS_MESSAGES", True)
|
||||
|
||||
# Discord collector settings
|
||||
DISCORD_COLLECTOR_ENABLED = boolean_env("DISCORD_COLLECTOR_ENABLED", True)
|
||||
DISCORD_COLLECT_DMS = boolean_env("DISCORD_COLLECT_DMS", True)
|
||||
DISCORD_COLLECT_BOTS = boolean_env("DISCORD_COLLECT_BOTS", True)
|
||||
DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8000))
|
||||
DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "127.0.0.1")
|
||||
DISCORD_CONTEXT_WINDOW = int(os.getenv("DISCORD_CONTEXT_WINDOW", 10))
|
||||
|
||||
@ -105,7 +105,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
|
||||
prompt = llms.truncate(prompt, MAX_TOKENS - 20)
|
||||
|
||||
try:
|
||||
response = llms.call(prompt, settings.SUMMARIZER_MODEL)
|
||||
response = llms.summarize(prompt, settings.SUMMARIZER_MODEL)
|
||||
result = parse_response(response)
|
||||
|
||||
summary = result.get("summary", "")
|
||||
|
||||
166
src/memory/discord/api.py
Normal file
166
src/memory/discord/api.py
Normal file
@ -0,0 +1,166 @@
|
||||
"""
|
||||
Discord API server.
|
||||
|
||||
FastAPI server that owns and manages a Discord collector instance,
|
||||
providing HTTP endpoints for sending Discord messages.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
|
||||
from memory.common import settings
|
||||
from memory.discord.collector import MessageCollector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SendDMRequest(BaseModel):
|
||||
user: str # Discord user ID or username
|
||||
message: str
|
||||
|
||||
|
||||
class SendChannelRequest(BaseModel):
|
||||
channel_name: str # Channel name (e.g., "memory-errors")
|
||||
message: str
|
||||
|
||||
|
||||
# Application state
|
||||
class AppState:
|
||||
def __init__(self):
|
||||
self.collector: MessageCollector | None = None
|
||||
self.collector_task: asyncio.Task | None = None
|
||||
|
||||
|
||||
app_state = AppState()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Manage Discord collector lifecycle"""
|
||||
if not settings.DISCORD_BOT_TOKEN:
|
||||
logger.error("DISCORD_BOT_TOKEN not configured")
|
||||
return
|
||||
|
||||
# Create and start the collector
|
||||
app_state.collector = MessageCollector()
|
||||
app_state.collector_task = asyncio.create_task(
|
||||
app_state.collector.start(settings.DISCORD_BOT_TOKEN)
|
||||
)
|
||||
logger.info("Discord collector started")
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
if app_state.collector and not app_state.collector.is_closed():
|
||||
await app_state.collector.close()
|
||||
|
||||
if app_state.collector_task:
|
||||
app_state.collector_task.cancel()
|
||||
try:
|
||||
await app_state.collector_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("Discord collector stopped")
|
||||
|
||||
|
||||
# FastAPI app with lifespan management
|
||||
app = FastAPI(title="Discord Collector API", version="1.0.0", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.post("/send_dm")
|
||||
async def send_dm_endpoint(request: SendDMRequest):
|
||||
"""Send a DM via the collector's Discord client"""
|
||||
if not app_state.collector:
|
||||
raise HTTPException(status_code=503, detail="Discord collector not running")
|
||||
|
||||
try:
|
||||
success = await app_state.collector.send_dm(request.user, request.message)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to send DM to {request.user}",
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"DM sent to {request.user}",
|
||||
"user": request.user,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send DM: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/send_channel")
|
||||
async def send_channel_endpoint(request: SendChannelRequest):
|
||||
"""Send a message to a channel via the collector's Discord client"""
|
||||
if not app_state.collector:
|
||||
raise HTTPException(status_code=503, detail="Discord collector not running")
|
||||
|
||||
try:
|
||||
success = await app_state.collector.send_to_channel(
|
||||
request.channel_name, request.message
|
||||
)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Message sent to channel {request.channel_name}",
|
||||
"channel": request.channel_name,
|
||||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to send message to channel {request.channel_name}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send channel message: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Check if the Discord collector is running and healthy"""
|
||||
if not app_state.collector:
|
||||
raise HTTPException(status_code=503, detail="Discord collector not running")
|
||||
|
||||
collector = app_state.collector
|
||||
return {
|
||||
"status": "healthy",
|
||||
"connected": not collector.is_closed(),
|
||||
"user": str(collector.user) if collector.user else None,
|
||||
"guilds": len(collector.guilds) if collector.guilds else 0,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/refresh_metadata")
|
||||
async def refresh_metadata():
|
||||
"""Refresh Discord server/channel/user metadata from Discord API"""
|
||||
if not app_state.collector:
|
||||
raise HTTPException(status_code=503, detail="Discord collector not running")
|
||||
|
||||
try:
|
||||
result = await app_state.collector.refresh_metadata()
|
||||
return {"success": True, "message": "Metadata refreshed successfully", **result}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refresh metadata: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001):
|
||||
"""Run the Discord API server"""
|
||||
uvicorn.run(app, host=host, port=port, log_level="debug")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# For testing the API server standalone
|
||||
host = settings.DISCORD_COLLECTOR_SERVER_URL
|
||||
port = settings.DISCORD_COLLECTOR_PORT
|
||||
run_discord_api_server(host, port)
|
||||
410
src/memory/discord/collector.py
Normal file
410
src/memory/discord/collector.py
Normal file
@ -0,0 +1,410 @@
|
||||
"""
|
||||
Discord message collector.
|
||||
|
||||
Core message collection functionality - stores Discord messages to database.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models.sources import (
|
||||
DiscordServer,
|
||||
DiscordChannel,
|
||||
DiscordUser,
|
||||
)
|
||||
from memory.workers.tasks.discord import add_discord_message, edit_discord_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Pure functions for Discord entity creation/updates
|
||||
def create_or_update_server(
|
||||
session: Session | scoped_session, guild: discord.Guild | None
|
||||
) -> DiscordServer | None:
|
||||
"""Get or create DiscordServer record (pure DB operation)"""
|
||||
if not guild:
|
||||
return None
|
||||
|
||||
server = session.query(DiscordServer).get(guild.id)
|
||||
|
||||
if not server:
|
||||
server = DiscordServer(
|
||||
id=guild.id,
|
||||
name=guild.name,
|
||||
description=guild.description,
|
||||
member_count=guild.member_count,
|
||||
)
|
||||
session.add(server)
|
||||
session.flush() # Get the ID
|
||||
logger.info(f"Created server record for {guild.name} ({guild.id})")
|
||||
else:
|
||||
# Update metadata
|
||||
server.name = guild.name
|
||||
server.description = guild.description
|
||||
server.member_count = guild.member_count
|
||||
server.last_sync_at = datetime.now(timezone.utc)
|
||||
|
||||
return server
|
||||
|
||||
|
||||
def determine_channel_metadata(channel) -> tuple[str, int | None, str]:
|
||||
"""Pure function to determine channel type, server_id, and name"""
|
||||
if isinstance(channel, discord.DMChannel):
|
||||
desc = (
|
||||
f"DM with {channel.recipient.name}" if channel.recipient else "Unknown DM"
|
||||
)
|
||||
return ("dm", None, desc)
|
||||
elif isinstance(channel, discord.GroupChannel):
|
||||
return "group_dm", None, channel.name or "Group DM"
|
||||
elif isinstance(
|
||||
channel, (discord.TextChannel, discord.VoiceChannel, discord.Thread)
|
||||
):
|
||||
return (
|
||||
channel.__class__.__name__.lower().replace("channel", ""),
|
||||
channel.guild.id,
|
||||
channel.name,
|
||||
)
|
||||
else:
|
||||
guild = getattr(channel, "guild", None)
|
||||
server_id = guild.id if guild else None
|
||||
name = getattr(channel, "name", f"Unknown-{channel.id}")
|
||||
return "unknown", server_id, name
|
||||
|
||||
|
||||
def create_or_update_channel(
|
||||
session: Session | scoped_session, channel
|
||||
) -> DiscordChannel | None:
|
||||
"""Get or create DiscordChannel record (pure DB operation)"""
|
||||
if not channel:
|
||||
return None
|
||||
|
||||
discord_channel = session.query(DiscordChannel).get(channel.id)
|
||||
|
||||
if not discord_channel:
|
||||
channel_type, server_id, name = determine_channel_metadata(channel)
|
||||
discord_channel = DiscordChannel(
|
||||
id=channel.id,
|
||||
server_id=server_id,
|
||||
name=name,
|
||||
channel_type=channel_type,
|
||||
)
|
||||
session.add(discord_channel)
|
||||
session.flush()
|
||||
logger.debug(f"Created channel: {name}")
|
||||
elif hasattr(channel, "name"):
|
||||
discord_channel.name = channel.name
|
||||
|
||||
return discord_channel
|
||||
|
||||
|
||||
def create_or_update_user(
|
||||
session: Session | scoped_session, user: discord.User | discord.Member
|
||||
) -> DiscordUser:
|
||||
"""Get or create DiscordUser record (pure DB operation)"""
|
||||
if not user:
|
||||
return None
|
||||
|
||||
discord_user = session.query(DiscordUser).get(user.id)
|
||||
|
||||
if not discord_user:
|
||||
discord_user = DiscordUser(
|
||||
id=user.id,
|
||||
username=user.name,
|
||||
display_name=user.display_name,
|
||||
)
|
||||
session.add(discord_user)
|
||||
session.flush()
|
||||
logger.debug(f"Created user: {user.name}")
|
||||
else:
|
||||
# Update user info in case it changed
|
||||
discord_user.username = user.name
|
||||
discord_user.display_name = user.display_name
|
||||
|
||||
return discord_user
|
||||
|
||||
|
||||
def determine_message_metadata(
|
||||
message: discord.Message,
|
||||
) -> tuple[str, int | None, int | None]:
|
||||
"""Pure function to determine message type, reply_to_id, and thread_id"""
|
||||
message_type = "default"
|
||||
reply_to_id = None
|
||||
thread_id = None
|
||||
|
||||
if message.reference and message.reference.message_id:
|
||||
message_type = "reply"
|
||||
reply_to_id = message.reference.message_id
|
||||
|
||||
if hasattr(message.channel, "parent") and message.channel.parent:
|
||||
thread_id = message.channel.id
|
||||
|
||||
return message_type, reply_to_id, thread_id
|
||||
|
||||
|
||||
def should_track_message(
|
||||
server: DiscordServer | None,
|
||||
channel: DiscordChannel,
|
||||
user: DiscordUser,
|
||||
) -> bool:
|
||||
"""Pure function to determine if we should track this message"""
|
||||
if server and not server.track_messages: # type: ignore
|
||||
return False
|
||||
|
||||
if not channel.track_messages:
|
||||
return False
|
||||
|
||||
if channel.channel_type in ("dm", "group_dm"):
|
||||
return bool(user.track_messages)
|
||||
|
||||
# Default: track the message
|
||||
return True
|
||||
|
||||
|
||||
def should_collect_bot_message(message: discord.Message) -> bool:
|
||||
"""Pure function to determine if we should collect bot messages"""
|
||||
return not message.author.bot or settings.DISCORD_COLLECT_BOTS
|
||||
|
||||
|
||||
def sync_guild_metadata(guild: discord.Guild) -> None:
|
||||
"""Sync a single guild's metadata (functional approach)"""
|
||||
with make_session() as session:
|
||||
create_or_update_server(session, guild)
|
||||
|
||||
for channel in guild.channels:
|
||||
if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)):
|
||||
create_or_update_channel(session, channel)
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
class MessageCollector(commands.Bot):
|
||||
"""Discord bot that collects and stores messages (thin event handler)"""
|
||||
|
||||
def __init__(self):
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
intents.guilds = True
|
||||
intents.members = True
|
||||
intents.dm_messages = True
|
||||
|
||||
super().__init__(
|
||||
command_prefix="!memory_", # Prefix to avoid conflicts
|
||||
intents=intents,
|
||||
help_command=None, # Disable default help
|
||||
)
|
||||
|
||||
async def on_ready(self):
|
||||
"""Called when bot connects to Discord"""
|
||||
logger.info(f"Discord collector connected as {self.user}")
|
||||
logger.info(f"Connected to {len(self.guilds)} servers")
|
||||
|
||||
# Sync server and channel metadata
|
||||
await self.sync_servers_and_channels()
|
||||
|
||||
logger.info("Discord message collector ready")
|
||||
|
||||
async def on_message(self, message: discord.Message):
|
||||
"""Queue incoming message for database storage"""
|
||||
try:
|
||||
if should_collect_bot_message(message):
|
||||
# Ensure Discord entities exist in database first
|
||||
with make_session() as session:
|
||||
create_or_update_user(session, message.author)
|
||||
create_or_update_channel(session, message.channel)
|
||||
if message.guild:
|
||||
create_or_update_server(session, message.guild)
|
||||
|
||||
session.commit()
|
||||
|
||||
# Queue the message for processing
|
||||
add_discord_message.delay(
|
||||
message_id=message.id,
|
||||
channel_id=message.channel.id,
|
||||
author_id=message.author.id,
|
||||
server_id=message.guild.id if message.guild else None,
|
||||
content=message.content or "",
|
||||
sent_at=message.created_at.isoformat(),
|
||||
message_reference_id=message.reference.message_id
|
||||
if message.reference
|
||||
else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error queuing message {message.id}: {e}")
|
||||
|
||||
async def on_message_edit(self, before: discord.Message, after: discord.Message):
|
||||
"""Queue message edit for database update"""
|
||||
try:
|
||||
edit_time = after.edited_at or datetime.now(timezone.utc)
|
||||
edit_discord_message.delay(
|
||||
message_id=after.id,
|
||||
content=after.content,
|
||||
edited_at=edit_time.isoformat(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error queuing message edit {after.id}: {e}")
|
||||
|
||||
async def sync_servers_and_channels(self):
|
||||
"""Sync server and channel metadata on startup"""
|
||||
for guild in self.guilds:
|
||||
sync_guild_metadata(guild)
|
||||
|
||||
logger.info(f"Synced {len(self.guilds)} servers and their channels")
|
||||
|
||||
async def refresh_metadata(self) -> dict[str, int]:
|
||||
"""Refresh server and channel metadata from Discord and update database"""
|
||||
print("🔄 Refreshing Discord metadata...")
|
||||
|
||||
servers_updated = 0
|
||||
channels_updated = 0
|
||||
users_updated = 0
|
||||
|
||||
with make_session() as session:
|
||||
# Refresh all servers
|
||||
for guild in self.guilds:
|
||||
create_or_update_server(session, guild)
|
||||
servers_updated += 1
|
||||
|
||||
# Refresh all channels in this server
|
||||
for channel in guild.channels:
|
||||
if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)):
|
||||
create_or_update_channel(session, channel)
|
||||
channels_updated += 1
|
||||
|
||||
# Refresh all members in this server (if members intent is enabled)
|
||||
if self.intents.members:
|
||||
for member in guild.members:
|
||||
create_or_update_user(session, member)
|
||||
users_updated += 1
|
||||
|
||||
session.commit()
|
||||
|
||||
result = {
|
||||
"servers_updated": servers_updated,
|
||||
"channels_updated": channels_updated,
|
||||
"users_updated": users_updated,
|
||||
}
|
||||
|
||||
print(f"✅ Metadata refresh complete: {result}")
|
||||
logger.info(f"Metadata refresh complete: {result}")
|
||||
|
||||
return result
|
||||
|
||||
async def get_user(self, user_identifier: int | str) -> discord.User | None:
|
||||
"""Get a Discord user by ID or username"""
|
||||
if isinstance(user_identifier, int):
|
||||
# Direct user ID lookup
|
||||
if user := super().get_user(user_identifier):
|
||||
return user
|
||||
try:
|
||||
return await self.fetch_user(user_identifier)
|
||||
except discord.NotFound:
|
||||
return None
|
||||
else:
|
||||
# Username lookup - search through all guilds
|
||||
for guild in self.guilds:
|
||||
for member in guild.members:
|
||||
if (
|
||||
member.name == user_identifier
|
||||
or member.display_name == user_identifier
|
||||
or f"{member.name}#{member.discriminator}" == user_identifier
|
||||
):
|
||||
return member
|
||||
return None
|
||||
|
||||
async def get_channel_by_name(
|
||||
self, channel_name: str
|
||||
) -> discord.TextChannel | None:
|
||||
"""Get a Discord channel by name (does not create if missing)"""
|
||||
# Search all guilds for the channel
|
||||
for guild in self.guilds:
|
||||
for ch in guild.channels:
|
||||
if isinstance(ch, discord.TextChannel) and ch.name == channel_name:
|
||||
return ch
|
||||
return None
|
||||
|
||||
async def create_channel(
|
||||
self, channel_name: str, guild_id: int | None = None
|
||||
) -> discord.TextChannel | None:
|
||||
"""Create a Discord channel in the specified guild (or first guild if none specified)"""
|
||||
target_guild = None
|
||||
|
||||
if guild_id:
|
||||
target_guild = self.get_guild(guild_id)
|
||||
elif self.guilds:
|
||||
target_guild = self.guilds[0] # Default to first guild
|
||||
|
||||
if not target_guild:
|
||||
logger.error(f"No guild available to create channel {channel_name}")
|
||||
return None
|
||||
|
||||
try:
|
||||
channel = await target_guild.create_text_channel(channel_name)
|
||||
logger.info(f"Created channel {channel_name} in {target_guild.name}")
|
||||
return channel
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create channel {channel_name} in {target_guild.name}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
async def send_dm(self, user_identifier: int | str, message: str) -> bool:
|
||||
"""Send a DM using this collector's Discord client"""
|
||||
try:
|
||||
user = await self.get_user(user_identifier)
|
||||
if not user:
|
||||
logger.error(f"User {user_identifier} not found")
|
||||
return False
|
||||
|
||||
await user.send(message)
|
||||
logger.info(f"Sent DM to {user_identifier}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send DM to {user_identifier}: {e}")
|
||||
return False
|
||||
|
||||
async def send_to_channel(self, channel_name: str, message: str) -> bool:
|
||||
"""Send a message to a channel by name across all guilds"""
|
||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||
return False
|
||||
|
||||
try:
|
||||
channel = await self.get_channel_by_name(channel_name)
|
||||
if not channel:
|
||||
logger.error(f"Channel {channel_name} not found")
|
||||
return False
|
||||
|
||||
await channel.send(message)
|
||||
logger.info(f"Sent message to channel {channel_name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send message to channel {channel_name}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def run_collector():
|
||||
"""Run the Discord message collector"""
|
||||
if not settings.DISCORD_BOT_TOKEN:
|
||||
logger.error("DISCORD_BOT_TOKEN not configured")
|
||||
return
|
||||
|
||||
collector = MessageCollector()
|
||||
|
||||
try:
|
||||
await collector.start(settings.DISCORD_BOT_TOKEN)
|
||||
except Exception as e:
|
||||
logger.error(f"Discord collector failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(run_collector())
|
||||
@ -6,6 +6,7 @@ from memory.workers.tasks import (
|
||||
email,
|
||||
comic,
|
||||
blogs,
|
||||
discord,
|
||||
ebook,
|
||||
forums,
|
||||
maintenance,
|
||||
@ -20,6 +21,7 @@ __all__ = [
|
||||
"comic",
|
||||
"blogs",
|
||||
"ebook",
|
||||
"discord",
|
||||
"forums",
|
||||
"maintenance",
|
||||
"notes",
|
||||
|
||||
@ -115,6 +115,9 @@ def sync_article_feed(feed_id: int) -> dict:
|
||||
|
||||
try:
|
||||
for feed_item in parser.parse_feed():
|
||||
if not feed_item.url:
|
||||
continue
|
||||
|
||||
articles_found += 1
|
||||
|
||||
existing = check_content_exists(session, BlogPost, url=feed_item.url)
|
||||
|
||||
@ -10,9 +10,9 @@ from collections import defaultdict
|
||||
import hashlib
|
||||
import traceback
|
||||
import logging
|
||||
from typing import Any, Callable, Iterable, Sequence, cast
|
||||
from typing import Any, Callable, Sequence, cast
|
||||
|
||||
from memory.common import embedding, qdrant, settings
|
||||
from memory.common import embedding, qdrant
|
||||
from memory.common.db.models import SourceItem, Chunk
|
||||
from memory.common.discord import notify_task_failure
|
||||
|
||||
@ -38,19 +38,12 @@ def check_content_exists(
|
||||
Returns:
|
||||
Existing SourceItem if found, None otherwise
|
||||
"""
|
||||
query = session.query(model_class)
|
||||
for key, value in kwargs.items():
|
||||
if not hasattr(model_class, key):
|
||||
continue
|
||||
if hasattr(model_class, key):
|
||||
query = query.filter(getattr(model_class, key) == value)
|
||||
|
||||
existing = (
|
||||
session.query(model_class)
|
||||
.filter(getattr(model_class, key) == value)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
return None
|
||||
return query.first()
|
||||
|
||||
|
||||
def create_content_hash(content: str, *additional_data: str) -> bytes:
|
||||
@ -286,6 +279,6 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
|
||||
traceback_str=traceback_str,
|
||||
)
|
||||
|
||||
return {"status": "error", "error": str(e)}
|
||||
return {"status": "error", "error": str(e), "traceback": traceback_str}
|
||||
|
||||
return wrapper
|
||||
|
||||
166
src/memory/workers/tasks/discord.py
Normal file
166
src/memory/workers/tasks/discord.py
Normal file
@ -0,0 +1,166 @@
|
||||
"""
|
||||
Celery tasks for Discord message processing.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from memory.common.celery_app import app
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import DiscordMessage, DiscordUser
|
||||
from memory.workers.tasks.content_processing import (
|
||||
safe_task_execution,
|
||||
check_content_exists,
|
||||
create_task_result,
|
||||
process_content_item,
|
||||
)
|
||||
from memory.common.celery_app import (
|
||||
ADD_DISCORD_MESSAGE,
|
||||
EDIT_DISCORD_MESSAGE,
|
||||
PROCESS_DISCORD_MESSAGE,
|
||||
)
|
||||
from memory.common import settings
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_prev(
|
||||
session: Session | scoped_session, channel_id: int, sent_at: datetime
|
||||
) -> list[str]:
|
||||
prev = (
|
||||
session.query(DiscordUser.username, DiscordMessage.content)
|
||||
.join(DiscordUser, DiscordMessage.discord_user_id == DiscordUser.id)
|
||||
.filter(
|
||||
DiscordMessage.channel_id == channel_id,
|
||||
DiscordMessage.sent_at < sent_at,
|
||||
)
|
||||
.order_by(DiscordMessage.sent_at.desc())
|
||||
.limit(settings.DISCORD_CONTEXT_WINDOW)
|
||||
.all()
|
||||
)
|
||||
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
|
||||
|
||||
|
||||
def should_process(message: DiscordMessage) -> bool:
|
||||
return (
|
||||
settings.DISCORD_PROCESS_MESSAGES
|
||||
and settings.DISCORD_NOTIFICATIONS_ENABLED
|
||||
and not (
|
||||
(message.server and message.server.ignore_messages)
|
||||
or (message.channel and message.channel.ignore_messages)
|
||||
or (message.discord_user and message.discord_user.ignore_messages)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@app.task(name=PROCESS_DISCORD_MESSAGE)
|
||||
@safe_task_execution
|
||||
def process_discord_message(message_id: int) -> dict[str, Any]:
|
||||
logger.info(f"Processing Discord message {message_id}")
|
||||
|
||||
with make_session() as session:
|
||||
discord_message = session.query(DiscordMessage).get(message_id)
|
||||
if not discord_message:
|
||||
logger.info(f"Discord message not found: {message_id}")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "Message not found",
|
||||
"message_id": message_id,
|
||||
}
|
||||
|
||||
print("Processing message", discord_message.id, discord_message.content)
|
||||
|
||||
return {
|
||||
"status": "processed",
|
||||
"message_id": message_id,
|
||||
}
|
||||
|
||||
|
||||
@app.task(name=ADD_DISCORD_MESSAGE)
|
||||
@safe_task_execution
|
||||
def add_discord_message(
|
||||
message_id: int,
|
||||
channel_id: int,
|
||||
author_id: int,
|
||||
content: str,
|
||||
sent_at: str,
|
||||
server_id: int | None = None,
|
||||
message_reference_id: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Add a Discord message to the database.
|
||||
|
||||
This task is queued by the Discord collector when messages are received.
|
||||
"""
|
||||
logger.info(f"Adding Discord message {message_id}: {content}")
|
||||
# Include message_id in hash to ensure uniqueness across duplicate content
|
||||
content_hash = hashlib.sha256(f"{message_id}:{content}".encode()).digest()
|
||||
sent_at_dt = datetime.fromisoformat(sent_at.replace("Z", "+00:00"))
|
||||
|
||||
with make_session() as session:
|
||||
discord_message = DiscordMessage(
|
||||
modality="text",
|
||||
sha256=content_hash,
|
||||
content=content,
|
||||
channel_id=channel_id,
|
||||
sent_at=sent_at_dt,
|
||||
server_id=server_id,
|
||||
discord_user_id=author_id,
|
||||
message_id=message_id,
|
||||
message_type="reply" if message_reference_id else "default",
|
||||
reply_to_message_id=message_reference_id,
|
||||
)
|
||||
existing_msg = check_content_exists(
|
||||
session, DiscordMessage, message_id=message_id, sha256=content_hash
|
||||
)
|
||||
if existing_msg:
|
||||
logger.info(f"Discord message already exists: {existing_msg.message_id}")
|
||||
return create_task_result(
|
||||
existing_msg, "already_exists", message_id=message_id
|
||||
)
|
||||
|
||||
if channel_id:
|
||||
discord_message.messages_before = get_prev(session, channel_id, sent_at_dt)
|
||||
|
||||
result = process_content_item(discord_message, session)
|
||||
if should_process(discord_message):
|
||||
process_discord_message.delay(discord_message.id)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@app.task(name=EDIT_DISCORD_MESSAGE)
|
||||
@safe_task_execution
|
||||
def edit_discord_message(
|
||||
message_id: int, content: str, edited_at: str
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Edit a Discord message in the database.
|
||||
|
||||
This task is queued by the Discord collector when messages are edited.
|
||||
"""
|
||||
logger.info(f"Editing Discord message {message_id}: {content}")
|
||||
with make_session() as session:
|
||||
existing_msg = check_content_exists(
|
||||
session, DiscordMessage, message_id=message_id
|
||||
)
|
||||
if not existing_msg:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "Message not found",
|
||||
"message_id": message_id,
|
||||
}
|
||||
|
||||
existing_msg.content = content # type: ignore
|
||||
if existing_msg.channel_id:
|
||||
existing_msg.messages_before = get_prev(
|
||||
session, existing_msg.channel_id, existing_msg.sent_at
|
||||
)
|
||||
existing_msg.edited_at = datetime.fromisoformat(
|
||||
edited_at.replace("Z", "+00:00")
|
||||
)
|
||||
|
||||
return process_content_item(existing_msg, session)
|
||||
@ -88,11 +88,10 @@ def execute_scheduled_call(self, scheduled_call_id: str):
|
||||
|
||||
# Make the LLM call
|
||||
if scheduled_call.model:
|
||||
response = llms.call(
|
||||
response = llms.summarize(
|
||||
prompt=cast(str, scheduled_call.message),
|
||||
model=cast(str, scheduled_call.model),
|
||||
system_prompt=cast(str, scheduled_call.system_prompt)
|
||||
or llms.SYSTEM_PROMPT,
|
||||
system_prompt=cast(str, scheduled_call.system_prompt),
|
||||
)
|
||||
else:
|
||||
response = cast(str, scheduled_call.message)
|
||||
|
||||
@ -273,6 +273,27 @@ def mock_anthropic_client():
|
||||
with patch.object(anthropic, "Anthropic", autospec=True) as mock_client:
|
||||
client = mock_client()
|
||||
client.messages = Mock()
|
||||
|
||||
# Mock stream as a context manager
|
||||
mock_stream = Mock()
|
||||
mock_stream.__enter__ = Mock(
|
||||
return_value=Mock(
|
||||
__iter__=lambda self: iter(
|
||||
[
|
||||
Mock(
|
||||
type="content_block_delta",
|
||||
delta=Mock(
|
||||
type="text_delta",
|
||||
text="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
)
|
||||
mock_stream.__exit__ = Mock(return_value=False)
|
||||
client.messages.stream = Mock(return_value=mock_stream)
|
||||
|
||||
client.messages.create = Mock(
|
||||
return_value=Mock(
|
||||
content=[
|
||||
|
||||
@ -2,318 +2,250 @@ import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
import requests
|
||||
|
||||
from memory.common import discord, settings
|
||||
from memory.common import discord
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_request():
|
||||
with patch("requests.Session.request") as mock:
|
||||
yield mock
|
||||
def mock_api_url():
|
||||
"""Mock the API URL to avoid using actual settings"""
|
||||
with patch(
|
||||
"memory.common.discord.get_api_url", return_value="http://localhost:8000"
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_channels_response():
|
||||
return [
|
||||
{"name": "memory-errors", "id": "error_channel_id"},
|
||||
{"name": "memory-activity", "id": "activity_channel_id"},
|
||||
{"name": "memory-discoveries", "id": "discovery_channel_id"},
|
||||
{"name": "memory-chat", "id": "chat_channel_id"},
|
||||
]
|
||||
@patch("memory.common.settings.DISCORD_COLLECTOR_SERVER_URL", "testhost")
|
||||
@patch("memory.common.settings.DISCORD_COLLECTOR_PORT", 9999)
|
||||
def test_get_api_url():
|
||||
"""Test API URL construction"""
|
||||
assert discord.get_api_url() == "http://testhost:9999"
|
||||
|
||||
|
||||
def test_discord_server_init(mock_session_request, mock_get_channels_response):
|
||||
# Mock the channels API call
|
||||
@patch("requests.post")
|
||||
def test_send_dm_success(mock_post, mock_api_url):
|
||||
"""Test successful DM sending"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = mock_get_channels_response
|
||||
mock_response.json.return_value = {"success": True}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_session_request.return_value = mock_response
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
server = discord.DiscordServer("server123", "Test Server")
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
|
||||
assert server.server_id == "server123"
|
||||
assert server.server_name == "Test Server"
|
||||
assert hasattr(server, "channels")
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "memory-errors")
|
||||
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "memory-activity")
|
||||
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "memory-discoveries")
|
||||
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "memory-chat")
|
||||
def test_setup_channels_existing(mock_session_request, mock_get_channels_response):
|
||||
# Mock the channels API call
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = mock_get_channels_response
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_session_request.return_value = mock_response
|
||||
|
||||
server = discord.DiscordServer("server123", "Test Server")
|
||||
|
||||
assert server.channels[discord.ERROR_CHANNEL] == "error_channel_id"
|
||||
assert server.channels[discord.ACTIVITY_CHANNEL] == "activity_channel_id"
|
||||
assert server.channels[discord.DISCOVERY_CHANNEL] == "discovery_channel_id"
|
||||
assert server.channels[discord.CHAT_CHANNEL] == "chat_channel_id"
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "new-error-channel")
|
||||
def test_setup_channels_create_missing(mock_session_request):
|
||||
# Mock get channels (empty) and create channel calls
|
||||
get_response = Mock()
|
||||
get_response.json.return_value = []
|
||||
get_response.raise_for_status.return_value = None
|
||||
|
||||
create_response = Mock()
|
||||
create_response.json.return_value = {"id": "new_channel_id"}
|
||||
create_response.raise_for_status.return_value = None
|
||||
|
||||
mock_session_request.side_effect = [
|
||||
get_response,
|
||||
create_response,
|
||||
create_response,
|
||||
create_response,
|
||||
create_response,
|
||||
]
|
||||
|
||||
server = discord.DiscordServer("server123", "Test Server")
|
||||
|
||||
assert server.channels[discord.ERROR_CHANNEL] == "new_channel_id"
|
||||
|
||||
|
||||
def test_channel_properties():
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.channels = {
|
||||
discord.ERROR_CHANNEL: "error_id",
|
||||
discord.ACTIVITY_CHANNEL: "activity_id",
|
||||
discord.DISCOVERY_CHANNEL: "discovery_id",
|
||||
discord.CHAT_CHANNEL: "chat_id",
|
||||
}
|
||||
|
||||
assert server.error_channel == "error_id"
|
||||
assert server.activity_channel == "activity_id"
|
||||
assert server.discovery_channel == "discovery_id"
|
||||
assert server.chat_channel == "chat_id"
|
||||
|
||||
|
||||
def test_channel_id_exists():
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.channels = {"test-channel": "channel123"}
|
||||
|
||||
assert server.channel_id("test-channel") == "channel123"
|
||||
|
||||
|
||||
def test_channel_id_not_found():
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.channels = {}
|
||||
|
||||
with pytest.raises(ValueError, match="Channel nonexistent not found"):
|
||||
server.channel_id("nonexistent")
|
||||
|
||||
|
||||
def test_send_message(mock_session_request):
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_session_request.return_value = mock_response
|
||||
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
|
||||
server.send_message("channel123", "Hello World")
|
||||
|
||||
mock_session_request.assert_called_with(
|
||||
"POST",
|
||||
"https://discord.com/api/v10/channels/channel123/messages",
|
||||
data=None,
|
||||
json={"content": "Hello World"},
|
||||
headers={
|
||||
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_dm",
|
||||
json={"user_identifier": "user123", "message": "Hello!"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
|
||||
def test_create_channel(mock_session_request):
|
||||
@patch("requests.post")
|
||||
def test_send_dm_api_failure(mock_post, mock_api_url):
|
||||
"""Test DM sending when API returns failure"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"id": "new_channel_id"}
|
||||
mock_response.json.return_value = {"success": False}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_session_request.return_value = mock_response
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.server_id = "server123"
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
|
||||
channel_id = server.create_channel("new-channel")
|
||||
|
||||
assert channel_id == "new_channel_id"
|
||||
mock_session_request.assert_called_with(
|
||||
"POST",
|
||||
"https://discord.com/api/v10/guilds/server123/channels",
|
||||
data=None,
|
||||
json={"name": "new-channel", "type": 0},
|
||||
headers={
|
||||
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_create_channel_custom_type(mock_session_request):
|
||||
@patch("requests.post")
|
||||
def test_send_dm_request_exception(mock_post, mock_api_url):
|
||||
"""Test DM sending when request raises exception"""
|
||||
mock_post.side_effect = requests.RequestException("Network error")
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_send_dm_http_error(mock_post, mock_api_url):
|
||||
"""Test DM sending when HTTP error occurs"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"id": "voice_channel_id"}
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_broadcast_message_success(mock_post, mock_api_url):
|
||||
"""Test successful channel message broadcast"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": True}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_session_request.return_value = mock_response
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.server_id = "server123"
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
|
||||
channel_id = server.create_channel("voice-channel", channel_type=2)
|
||||
|
||||
assert channel_id == "voice_channel_id"
|
||||
mock_session_request.assert_called_with(
|
||||
"POST",
|
||||
"https://discord.com/api/v10/guilds/server123/channels",
|
||||
data=None,
|
||||
json={"name": "voice-channel", "type": 2},
|
||||
headers={
|
||||
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_channel",
|
||||
json={"channel_name": "general", "message": "Announcement!"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
|
||||
def test_str_representation():
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.server_id = "server123"
|
||||
server.server_name = "Test Server"
|
||||
@patch("requests.post")
|
||||
def test_broadcast_message_failure(mock_post, mock_api_url):
|
||||
"""Test channel message broadcast failure"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": False}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
assert str(server) == "DiscordServer(server_id=server123, server_name=Test Server)"
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token_123")
|
||||
def test_request_adds_headers(mock_session_request):
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
@patch("requests.post")
|
||||
def test_broadcast_message_exception(mock_post, mock_api_url):
|
||||
"""Test channel message broadcast with exception"""
|
||||
mock_post.side_effect = requests.Timeout("Request timeout")
|
||||
|
||||
server.request("GET", "https://example.com", headers={"Custom": "header"})
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
|
||||
expected_headers = {
|
||||
"Custom": "header",
|
||||
"Authorization": "Bot test_token_123",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
mock_session_request.assert_called_once_with(
|
||||
"GET", "https://example.com", headers=expected_headers
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_channels_url():
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.server_id = "server123"
|
||||
|
||||
assert (
|
||||
server.channels_url == "https://discord.com/api/v10/guilds/server123/channels"
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
|
||||
@patch("requests.get")
|
||||
def test_get_bot_servers_success(mock_get):
|
||||
def test_is_collector_healthy_true(mock_get, mock_api_url):
|
||||
"""Test health check when collector is healthy"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = [
|
||||
{"id": "server1", "name": "Server 1"},
|
||||
{"id": "server2", "name": "Server 2"},
|
||||
]
|
||||
mock_response.json.return_value = {"status": "healthy"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
servers = discord.get_bot_servers()
|
||||
result = discord.is_collector_healthy()
|
||||
|
||||
assert len(servers) == 2
|
||||
assert servers[0] == {"id": "server1", "name": "Server 1"}
|
||||
mock_get.assert_called_once_with(
|
||||
"https://discord.com/api/v10/users/@me/guilds",
|
||||
headers={"Authorization": "Bot test_token"},
|
||||
)
|
||||
assert result is True
|
||||
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
|
||||
def test_get_bot_servers_no_token():
|
||||
assert discord.get_bot_servers() == []
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
|
||||
@patch("requests.get")
|
||||
def test_get_bot_servers_exception(mock_get):
|
||||
mock_get.side_effect = requests.RequestException("API Error")
|
||||
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
|
||||
"""Test health check when collector returns unhealthy status"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"status": "unhealthy"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
servers = discord.get_bot_servers()
|
||||
result = discord.is_collector_healthy()
|
||||
|
||||
assert servers == []
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("memory.common.discord.get_bot_servers")
|
||||
@patch("memory.common.discord.DiscordServer")
|
||||
def test_load_servers(mock_discord_server_class, mock_get_servers):
|
||||
mock_get_servers.return_value = [
|
||||
{"id": "server1", "name": "Server 1"},
|
||||
{"id": "server2", "name": "Server 2"},
|
||||
]
|
||||
@patch("requests.get")
|
||||
def test_is_collector_healthy_exception(mock_get, mock_api_url):
|
||||
"""Test health check when request fails"""
|
||||
mock_get.side_effect = requests.ConnectionError("Connection refused")
|
||||
|
||||
discord.load_servers()
|
||||
result = discord.is_collector_healthy()
|
||||
|
||||
assert mock_discord_server_class.call_count == 2
|
||||
mock_discord_server_class.assert_any_call("server1", "Server 1")
|
||||
mock_discord_server_class.assert_any_call("server2", "Server 2")
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_broadcast_message():
|
||||
mock_server1 = Mock()
|
||||
mock_server2 = Mock()
|
||||
discord.servers = {"1": mock_server1, "2": mock_server2}
|
||||
@patch("requests.post")
|
||||
def test_refresh_discord_metadata_success(mock_post, mock_api_url):
|
||||
"""Test successful metadata refresh"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"servers": 5,
|
||||
"channels": 20,
|
||||
"users": 100,
|
||||
}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
discord.broadcast_message("test-channel", "Hello")
|
||||
result = discord.refresh_discord_metadata()
|
||||
|
||||
mock_server1.send_message.assert_called_once_with(
|
||||
mock_server1.channel_id.return_value, "Hello"
|
||||
)
|
||||
mock_server2.send_message.assert_called_once_with(
|
||||
mock_server2.channel_id.return_value, "Hello"
|
||||
assert result == {"servers": 5, "channels": 20, "users": 100}
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/refresh_metadata", timeout=30
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
def test_broadcast_message_disabled():
|
||||
mock_server = Mock()
|
||||
discord.servers = {"1": mock_server}
|
||||
@patch("requests.post")
|
||||
def test_refresh_discord_metadata_failure(mock_post, mock_api_url):
|
||||
"""Test metadata refresh failure"""
|
||||
mock_post.side_effect = requests.RequestException("Failed to connect")
|
||||
|
||||
discord.broadcast_message("test-channel", "Hello")
|
||||
result = discord.refresh_discord_metadata()
|
||||
|
||||
mock_server.send_message.assert_not_called()
|
||||
assert result is None
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_refresh_discord_metadata_http_error(mock_post, mock_api_url):
|
||||
"""Test metadata refresh with HTTP error"""
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.refresh_discord_metadata()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "errors")
|
||||
def test_send_error_message(mock_broadcast):
|
||||
discord.send_error_message("Error occurred")
|
||||
mock_broadcast.assert_called_once_with(discord.ERROR_CHANNEL, "Error occurred")
|
||||
"""Test sending error message to error channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_error_message("Something broke")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("errors", "Something broke")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "activity")
|
||||
def test_send_activity_message(mock_broadcast):
|
||||
discord.send_activity_message("Activity update")
|
||||
mock_broadcast.assert_called_once_with(discord.ACTIVITY_CHANNEL, "Activity update")
|
||||
"""Test sending activity message to activity channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_activity_message("User logged in")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("activity", "User logged in")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "discoveries")
|
||||
def test_send_discovery_message(mock_broadcast):
|
||||
discord.send_discovery_message("Discovery made")
|
||||
mock_broadcast.assert_called_once_with(discord.DISCOVERY_CHANNEL, "Discovery made")
|
||||
"""Test sending discovery message to discovery channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_discovery_message("Found interesting pattern")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "chat")
|
||||
def test_send_chat_message(mock_broadcast):
|
||||
discord.send_chat_message("Chat message")
|
||||
mock_broadcast.assert_called_once_with(discord.CHAT_CHANNEL, "Chat message")
|
||||
"""Test sending chat message to chat channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_chat_message("Hello from bot")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_basic(mock_send_error):
|
||||
"""Test basic task failure notification"""
|
||||
discord.notify_task_failure("test_task", "Something went wrong")
|
||||
|
||||
mock_send_error.assert_called_once()
|
||||
@ -323,69 +255,181 @@ def test_notify_task_failure_basic(mock_send_error):
|
||||
assert "**Error:** Something went wrong" in message
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_with_args(mock_send_error):
|
||||
"""Test task failure notification with arguments"""
|
||||
discord.notify_task_failure(
|
||||
"test_task",
|
||||
"Error message",
|
||||
task_args=("arg1", "arg2"),
|
||||
task_kwargs={"key": "value"},
|
||||
"Error occurred",
|
||||
task_args=("arg1", 42),
|
||||
task_kwargs={"key": "value", "number": 123},
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
assert "**Args:** `('arg1', 'arg2')`" in message
|
||||
assert "**Kwargs:** `{'key': 'value'}`" in message
|
||||
assert "**Args:** `('arg1', 42)" in message
|
||||
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_with_traceback(mock_send_error):
|
||||
traceback = "Traceback (most recent call last):\n File ...\nError: Something"
|
||||
"""Test task failure notification with traceback"""
|
||||
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
|
||||
|
||||
discord.notify_task_failure("test_task", "Error message", traceback_str=traceback)
|
||||
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
assert "**Traceback:**" in message
|
||||
assert traceback in message
|
||||
assert "Exception: test" in message
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_truncates_long_error(mock_send_error):
|
||||
long_error = "x" * 600 # Longer than 500 char limit
|
||||
"""Test that long error messages are truncated"""
|
||||
long_error = "x" * 600
|
||||
|
||||
discord.notify_task_failure("test_task", long_error)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
assert long_error[:500] in message
|
||||
|
||||
# Error should be truncated to 500 chars - check that the full 600 char string is not there
|
||||
assert "**Error:** " + long_error[:500] in message
|
||||
# The full 600-char error should not be present
|
||||
error_section = message.split("**Error:** ")[1].split("\n")[0]
|
||||
assert len(error_section) == 500
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_truncates_long_traceback(mock_send_error):
|
||||
long_traceback = "x" * 1000 # Longer than 800 char limit
|
||||
"""Test that long tracebacks are truncated"""
|
||||
long_traceback = "x" * 1000
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
# Traceback should show last 800 chars
|
||||
assert long_traceback[-800:] in message
|
||||
# The full 1000-char traceback should not be present
|
||||
traceback_section = message.split("**Traceback:**\n```\n")[1].split("\n```")[0]
|
||||
assert len(traceback_section) == 800
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_truncates_long_args(mock_send_error):
|
||||
"""Test that long task arguments are truncated"""
|
||||
long_args = ("x" * 300,)
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", task_args=long_args)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
# Args should be truncated to 200 chars
|
||||
assert (
|
||||
len(message.split("**Args:**")[1].split("\n")[0]) <= 210
|
||||
) # Some buffer for formatting
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
||||
"""Test that long task kwargs are truncated"""
|
||||
long_kwargs = {"key": "x" * 300}
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
# Kwargs should be truncated to 200 chars
|
||||
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
def test_notify_task_failure_disabled(mock_send_error):
|
||||
discord.notify_task_failure("test_task", "Error message")
|
||||
"""Test that notifications are not sent when disabled"""
|
||||
discord.notify_task_failure("test_task", "Error occurred")
|
||||
|
||||
mock_send_error.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
def test_notify_task_failure_send_fails(mock_send_error):
|
||||
mock_send_error.side_effect = Exception("Discord API error")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_send_error_exception(mock_send_error):
|
||||
"""Test that exceptions in send_error_message don't propagate"""
|
||||
mock_send_error.side_effect = Exception("Failed to send")
|
||||
|
||||
# Should not raise, just log the error
|
||||
discord.notify_task_failure("test_task", "Error message")
|
||||
# Should not raise
|
||||
discord.notify_task_failure("test_task", "Error occurred")
|
||||
|
||||
mock_send_error.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,channel_setting,message",
|
||||
[
|
||||
(discord.send_error_message, "DISCORD_ERROR_CHANNEL", "Error!"),
|
||||
(discord.send_activity_message, "DISCORD_ACTIVITY_CHANNEL", "Activity!"),
|
||||
(discord.send_discovery_message, "DISCORD_DISCOVERY_CHANNEL", "Discovery!"),
|
||||
(discord.send_chat_message, "DISCORD_CHAT_CHANNEL", "Chat!"),
|
||||
],
|
||||
)
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
def test_convenience_functions_use_correct_channels(
|
||||
mock_broadcast, function, channel_setting, message
|
||||
):
|
||||
"""Test that convenience functions use the correct channel settings"""
|
||||
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
|
||||
function(message)
|
||||
mock_broadcast.assert_called_once_with("test-channel", message)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_send_dm_with_special_characters(mock_post, mock_api_url):
|
||||
"""Test sending DM with special characters"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": True}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
message_with_special_chars = "Hello! 🎉 <@123> #general"
|
||||
result = discord.send_dm("user123", message_with_special_chars)
|
||||
|
||||
assert result is True
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["message"] == message_with_special_chars
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_broadcast_message_with_long_message(mock_post, mock_api_url):
|
||||
"""Test broadcasting a long message"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": True}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
long_message = "A" * 2000
|
||||
result = discord.broadcast_message("general", long_message)
|
||||
|
||||
assert result is True
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["message"] == long_message
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
|
||||
"""Test health check when response doesn't have status key"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
|
||||
assert result is False
|
||||
|
||||
435
tests/memory/common/test_discord_integration.py
Normal file
435
tests/memory/common/test_discord_integration.py
Normal file
@ -0,0 +1,435 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
import requests
|
||||
|
||||
from memory.common import discord
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_url():
|
||||
"""Mock the API URL to avoid using actual settings"""
|
||||
with patch(
|
||||
"memory.common.discord.get_api_url", return_value="http://localhost:8000"
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_COLLECTOR_SERVER_URL", "testhost")
|
||||
@patch("memory.common.settings.DISCORD_COLLECTOR_PORT", 9999)
|
||||
def test_get_api_url():
|
||||
"""Test API URL construction"""
|
||||
assert discord.get_api_url() == "http://testhost:9999"
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_send_dm_success(mock_post, mock_api_url):
|
||||
"""Test successful DM sending"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": True}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_dm",
|
||||
json={"user_identifier": "user123", "message": "Hello!"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_send_dm_api_failure(mock_post, mock_api_url):
|
||||
"""Test DM sending when API returns failure"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": False}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_send_dm_request_exception(mock_post, mock_api_url):
|
||||
"""Test DM sending when request raises exception"""
|
||||
mock_post.side_effect = requests.RequestException("Network error")
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_send_dm_http_error(mock_post, mock_api_url):
|
||||
"""Test DM sending when HTTP error occurs"""
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_broadcast_message_success(mock_post, mock_api_url):
|
||||
"""Test successful channel message broadcast"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": True}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_channel",
|
||||
json={"channel_name": "general", "message": "Announcement!"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_broadcast_message_failure(mock_post, mock_api_url):
|
||||
"""Test channel message broadcast failure"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": False}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_broadcast_message_exception(mock_post, mock_api_url):
|
||||
"""Test channel message broadcast with exception"""
|
||||
mock_post.side_effect = requests.Timeout("Request timeout")
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_is_collector_healthy_true(mock_get, mock_api_url):
|
||||
"""Test health check when collector is healthy"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"status": "healthy"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
|
||||
assert result is True
|
||||
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
|
||||
"""Test health check when collector returns unhealthy status"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"status": "unhealthy"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_is_collector_healthy_exception(mock_get, mock_api_url):
|
||||
"""Test health check when request fails"""
|
||||
mock_get.side_effect = requests.ConnectionError("Connection refused")
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_refresh_discord_metadata_success(mock_post, mock_api_url):
|
||||
"""Test successful metadata refresh"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"servers": 5,
|
||||
"channels": 20,
|
||||
"users": 100,
|
||||
}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.refresh_discord_metadata()
|
||||
|
||||
assert result == {"servers": 5, "channels": 20, "users": 100}
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/refresh_metadata", timeout=30
|
||||
)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_refresh_discord_metadata_failure(mock_post, mock_api_url):
|
||||
"""Test metadata refresh failure"""
|
||||
mock_post.side_effect = requests.RequestException("Failed to connect")
|
||||
|
||||
result = discord.refresh_discord_metadata()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_refresh_discord_metadata_http_error(mock_post, mock_api_url):
|
||||
"""Test metadata refresh with HTTP error"""
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.refresh_discord_metadata()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "errors")
|
||||
def test_send_error_message(mock_broadcast):
|
||||
"""Test sending error message to error channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_error_message("Something broke")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("errors", "Something broke")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "activity")
|
||||
def test_send_activity_message(mock_broadcast):
|
||||
"""Test sending activity message to activity channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_activity_message("User logged in")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("activity", "User logged in")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "discoveries")
|
||||
def test_send_discovery_message(mock_broadcast):
|
||||
"""Test sending discovery message to discovery channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_discovery_message("Found interesting pattern")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "chat")
|
||||
def test_send_chat_message(mock_broadcast):
|
||||
"""Test sending chat message to chat channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_chat_message("Hello from bot")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_basic(mock_send_error):
|
||||
"""Test basic task failure notification"""
|
||||
discord.notify_task_failure("test_task", "Something went wrong")
|
||||
|
||||
mock_send_error.assert_called_once()
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
assert "🚨 **Task Failed: test_task**" in message
|
||||
assert "**Error:** Something went wrong" in message
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_with_args(mock_send_error):
|
||||
"""Test task failure notification with arguments"""
|
||||
discord.notify_task_failure(
|
||||
"test_task",
|
||||
"Error occurred",
|
||||
task_args=("arg1", 42),
|
||||
task_kwargs={"key": "value", "number": 123},
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
assert "**Args:** `('arg1', 42)" in message
|
||||
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_with_traceback(mock_send_error):
|
||||
"""Test task failure notification with traceback"""
|
||||
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
|
||||
|
||||
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
assert "**Traceback:**" in message
|
||||
assert "Exception: test" in message
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_truncates_long_error(mock_send_error):
|
||||
"""Test that long error messages are truncated"""
|
||||
long_error = "x" * 600
|
||||
|
||||
discord.notify_task_failure("test_task", long_error)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
# Error should be truncated to 500 chars - check that the full 600 char string is not there
|
||||
assert "**Error:** " + long_error[:500] in message
|
||||
# The full 600-char error should not be present
|
||||
error_section = message.split("**Error:** ")[1].split("\n")[0]
|
||||
assert len(error_section) == 500
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_truncates_long_traceback(mock_send_error):
|
||||
"""Test that long tracebacks are truncated"""
|
||||
long_traceback = "x" * 1000
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
# Traceback should show last 800 chars
|
||||
assert long_traceback[-800:] in message
|
||||
# The full 1000-char traceback should not be present
|
||||
traceback_section = message.split("**Traceback:**\n```\n")[1].split("\n```")[0]
|
||||
assert len(traceback_section) == 800
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_truncates_long_args(mock_send_error):
|
||||
"""Test that long task arguments are truncated"""
|
||||
long_args = ("x" * 300,)
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", task_args=long_args)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
# Args should be truncated to 200 chars
|
||||
assert (
|
||||
len(message.split("**Args:**")[1].split("\n")[0]) <= 210
|
||||
) # Some buffer for formatting
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
||||
"""Test that long task kwargs are truncated"""
|
||||
long_kwargs = {"key": "x" * 300}
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
# Kwargs should be truncated to 200 chars
|
||||
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
def test_notify_task_failure_disabled(mock_send_error):
|
||||
"""Test that notifications are not sent when disabled"""
|
||||
discord.notify_task_failure("test_task", "Error occurred")
|
||||
|
||||
mock_send_error.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_send_error_exception(mock_send_error):
|
||||
"""Test that exceptions in send_error_message don't propagate"""
|
||||
mock_send_error.side_effect = Exception("Failed to send")
|
||||
|
||||
# Should not raise
|
||||
discord.notify_task_failure("test_task", "Error occurred")
|
||||
|
||||
mock_send_error.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,channel_setting,message",
|
||||
[
|
||||
(discord.send_error_message, "DISCORD_ERROR_CHANNEL", "Error!"),
|
||||
(discord.send_activity_message, "DISCORD_ACTIVITY_CHANNEL", "Activity!"),
|
||||
(discord.send_discovery_message, "DISCORD_DISCOVERY_CHANNEL", "Discovery!"),
|
||||
(discord.send_chat_message, "DISCORD_CHAT_CHANNEL", "Chat!"),
|
||||
],
|
||||
)
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
def test_convenience_functions_use_correct_channels(
|
||||
mock_broadcast, function, channel_setting, message
|
||||
):
|
||||
"""Test that convenience functions use the correct channel settings"""
|
||||
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
|
||||
function(message)
|
||||
mock_broadcast.assert_called_once_with("test-channel", message)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_send_dm_with_special_characters(mock_post, mock_api_url):
|
||||
"""Test sending DM with special characters"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": True}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
message_with_special_chars = "Hello! 🎉 <@123> #general"
|
||||
result = discord.send_dm("user123", message_with_special_chars)
|
||||
|
||||
assert result is True
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["message"] == message_with_special_chars
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_broadcast_message_with_long_message(mock_post, mock_api_url):
|
||||
"""Test broadcasting a long message"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"success": True}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
long_message = "A" * 2000
|
||||
result = discord.broadcast_message("general", long_message)
|
||||
|
||||
assert result is True
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["message"] == long_message
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
|
||||
"""Test health check when response doesn't have status key"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
|
||||
assert result is False
|
||||
840
tests/memory/discord/test_collector.py
Normal file
840
tests/memory/discord/test_collector.py
Normal file
@ -0,0 +1,840 @@
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
|
||||
import discord
|
||||
|
||||
from memory.discord.collector import (
|
||||
create_or_update_server,
|
||||
determine_channel_metadata,
|
||||
create_or_update_channel,
|
||||
create_or_update_user,
|
||||
determine_message_metadata,
|
||||
should_track_message,
|
||||
should_collect_bot_message,
|
||||
sync_guild_metadata,
|
||||
MessageCollector,
|
||||
)
|
||||
from memory.common.db.models.sources import (
|
||||
DiscordServer,
|
||||
DiscordChannel,
|
||||
DiscordUser,
|
||||
)
|
||||
|
||||
|
||||
# Fixtures for Discord objects
|
||||
@pytest.fixture
|
||||
def mock_guild():
|
||||
"""Mock Discord Guild object"""
|
||||
guild = Mock(spec=discord.Guild)
|
||||
guild.id = 123456789
|
||||
guild.name = "Test Server"
|
||||
guild.description = "A test server"
|
||||
guild.member_count = 42
|
||||
return guild
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_text_channel():
|
||||
"""Mock Discord TextChannel object"""
|
||||
channel = Mock(spec=discord.TextChannel)
|
||||
channel.id = 987654321
|
||||
channel.name = "general"
|
||||
guild = Mock()
|
||||
guild.id = 123456789
|
||||
channel.guild = guild
|
||||
return channel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dm_channel():
|
||||
"""Mock Discord DMChannel object"""
|
||||
channel = Mock(spec=discord.DMChannel)
|
||||
channel.id = 111222333
|
||||
recipient = Mock()
|
||||
recipient.name = "TestUser"
|
||||
channel.recipient = recipient
|
||||
return channel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Mock Discord User object"""
|
||||
user = Mock(spec=discord.User)
|
||||
user.id = 444555666
|
||||
user.name = "testuser"
|
||||
user.display_name = "Test User"
|
||||
user.bot = False
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(mock_text_channel, mock_user):
|
||||
"""Mock Discord Message object"""
|
||||
message = Mock(spec=discord.Message)
|
||||
message.id = 777888999
|
||||
message.channel = mock_text_channel
|
||||
message.author = mock_user
|
||||
message.guild = mock_text_channel.guild
|
||||
message.content = "Test message"
|
||||
message.created_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
message.reference = None
|
||||
return message
|
||||
|
||||
|
||||
# Tests for create_or_update_server
|
||||
def test_create_or_update_server_creates_new(db_session, mock_guild):
|
||||
"""Test creating a new server record"""
|
||||
result = create_or_update_server(db_session, mock_guild)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == mock_guild.id
|
||||
assert result.name == mock_guild.name
|
||||
assert result.description == mock_guild.description
|
||||
assert result.member_count == mock_guild.member_count
|
||||
|
||||
|
||||
def test_create_or_update_server_updates_existing(db_session, mock_guild):
|
||||
"""Test updating an existing server record"""
|
||||
# Create initial server
|
||||
server = DiscordServer(
|
||||
id=mock_guild.id,
|
||||
name="Old Name",
|
||||
description="Old Description",
|
||||
member_count=10,
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
# Update with new data
|
||||
mock_guild.name = "New Name"
|
||||
mock_guild.description = "New Description"
|
||||
mock_guild.member_count = 50
|
||||
|
||||
result = create_or_update_server(db_session, mock_guild)
|
||||
|
||||
assert result.name == "New Name"
|
||||
assert result.description == "New Description"
|
||||
assert result.member_count == 50
|
||||
assert result.last_sync_at is not None
|
||||
|
||||
|
||||
def test_create_or_update_server_none_guild(db_session):
|
||||
"""Test with None guild"""
|
||||
result = create_or_update_server(db_session, None)
|
||||
assert result is None
|
||||
|
||||
|
||||
# Tests for determine_channel_metadata
|
||||
def test_determine_channel_metadata_dm():
|
||||
"""Test metadata for DM channel"""
|
||||
channel = Mock(spec=discord.DMChannel)
|
||||
channel.recipient = Mock()
|
||||
channel.recipient.name = "TestUser"
|
||||
|
||||
channel_type, server_id, name = determine_channel_metadata(channel)
|
||||
|
||||
assert channel_type == "dm"
|
||||
assert server_id is None
|
||||
assert "DM with TestUser" in name
|
||||
|
||||
|
||||
def test_determine_channel_metadata_dm_no_recipient():
|
||||
"""Test metadata for DM channel without recipient"""
|
||||
channel = Mock(spec=discord.DMChannel)
|
||||
channel.recipient = None
|
||||
|
||||
channel_type, server_id, name = determine_channel_metadata(channel)
|
||||
|
||||
assert channel_type == "dm"
|
||||
assert name == "Unknown DM"
|
||||
|
||||
|
||||
def test_determine_channel_metadata_group_dm():
|
||||
"""Test metadata for group DM channel"""
|
||||
channel = Mock(spec=discord.GroupChannel)
|
||||
channel.name = "Group Chat"
|
||||
|
||||
channel_type, server_id, name = determine_channel_metadata(channel)
|
||||
|
||||
assert channel_type == "group_dm"
|
||||
assert server_id is None
|
||||
assert name == "Group Chat"
|
||||
|
||||
|
||||
def test_determine_channel_metadata_group_dm_no_name():
|
||||
"""Test metadata for group DM without name"""
|
||||
channel = Mock(spec=discord.GroupChannel)
|
||||
channel.name = None
|
||||
|
||||
channel_type, server_id, name = determine_channel_metadata(channel)
|
||||
|
||||
assert name == "Group DM"
|
||||
|
||||
|
||||
def test_determine_channel_metadata_text_channel():
|
||||
"""Test metadata for text channel"""
|
||||
channel = Mock(spec=discord.TextChannel)
|
||||
channel.name = "general"
|
||||
channel.guild = Mock()
|
||||
channel.guild.id = 123
|
||||
|
||||
channel_type, server_id, name = determine_channel_metadata(channel)
|
||||
|
||||
assert channel_type == "text"
|
||||
assert server_id == 123
|
||||
assert name == "general"
|
||||
|
||||
|
||||
def test_determine_channel_metadata_voice_channel():
|
||||
"""Test metadata for voice channel"""
|
||||
channel = Mock(spec=discord.VoiceChannel)
|
||||
channel.name = "voice-chat"
|
||||
channel.guild = Mock()
|
||||
channel.guild.id = 456
|
||||
|
||||
channel_type, server_id, name = determine_channel_metadata(channel)
|
||||
|
||||
assert channel_type == "voice"
|
||||
assert server_id == 456
|
||||
assert name == "voice-chat"
|
||||
|
||||
|
||||
def test_determine_channel_metadata_thread():
|
||||
"""Test metadata for thread"""
|
||||
channel = Mock(spec=discord.Thread)
|
||||
channel.name = "thread-1"
|
||||
channel.guild = Mock()
|
||||
channel.guild.id = 789
|
||||
|
||||
channel_type, server_id, name = determine_channel_metadata(channel)
|
||||
|
||||
assert channel_type == "thread"
|
||||
assert server_id == 789
|
||||
assert name == "thread-1"
|
||||
|
||||
|
||||
def test_determine_channel_metadata_unknown():
|
||||
"""Test metadata for unknown channel type"""
|
||||
channel = Mock()
|
||||
channel.id = 999
|
||||
# Ensure the mock doesn't have a 'name' attribute
|
||||
del channel.name
|
||||
|
||||
channel_type, server_id, name = determine_channel_metadata(channel)
|
||||
|
||||
assert channel_type == "unknown"
|
||||
assert name == "Unknown-999"
|
||||
|
||||
|
||||
# Tests for create_or_update_channel
|
||||
def test_create_or_update_channel_creates_new(
|
||||
db_session, mock_text_channel, mock_guild
|
||||
):
|
||||
"""Test creating a new channel record"""
|
||||
# Create the server first to satisfy foreign key constraint
|
||||
create_or_update_server(db_session, mock_guild)
|
||||
|
||||
result = create_or_update_channel(db_session, mock_text_channel)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == mock_text_channel.id
|
||||
assert result.name == mock_text_channel.name
|
||||
assert result.channel_type == "text"
|
||||
|
||||
|
||||
def test_create_or_update_channel_updates_existing(db_session, mock_text_channel):
|
||||
"""Test updating an existing channel record"""
|
||||
# Create initial channel
|
||||
channel = DiscordChannel(
|
||||
id=mock_text_channel.id,
|
||||
name="old-name",
|
||||
channel_type="text",
|
||||
)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
|
||||
# Update with new name
|
||||
mock_text_channel.name = "new-name"
|
||||
|
||||
result = create_or_update_channel(db_session, mock_text_channel)
|
||||
|
||||
assert result.name == "new-name"
|
||||
|
||||
|
||||
def test_create_or_update_channel_none_channel(db_session):
|
||||
"""Test with None channel"""
|
||||
result = create_or_update_channel(db_session, None)
|
||||
assert result is None
|
||||
|
||||
|
||||
# Tests for create_or_update_user
|
||||
def test_create_or_update_user_creates_new(db_session, mock_user):
|
||||
"""Test creating a new user record"""
|
||||
result = create_or_update_user(db_session, mock_user)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == mock_user.id
|
||||
assert result.username == mock_user.name
|
||||
assert result.display_name == mock_user.display_name
|
||||
|
||||
|
||||
def test_create_or_update_user_updates_existing(db_session, mock_user):
|
||||
"""Test updating an existing user record"""
|
||||
# Create initial user
|
||||
user = DiscordUser(
|
||||
id=mock_user.id,
|
||||
username="oldname",
|
||||
display_name="Old Display Name",
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
# Update with new data
|
||||
mock_user.name = "newname"
|
||||
mock_user.display_name = "New Display Name"
|
||||
|
||||
result = create_or_update_user(db_session, mock_user)
|
||||
|
||||
assert result.username == "newname"
|
||||
assert result.display_name == "New Display Name"
|
||||
|
||||
|
||||
def test_create_or_update_user_none_user(db_session):
|
||||
"""Test with None user"""
|
||||
result = create_or_update_user(db_session, None)
|
||||
assert result is None
|
||||
|
||||
|
||||
# Tests for determine_message_metadata
|
||||
def test_determine_message_metadata_default():
|
||||
"""Test metadata for default message"""
|
||||
message = Mock()
|
||||
message.reference = None
|
||||
message.channel = Mock()
|
||||
# Ensure channel doesn't have parent attribute
|
||||
del message.channel.parent
|
||||
|
||||
message_type, reply_to_id, thread_id = determine_message_metadata(message)
|
||||
|
||||
assert message_type == "default"
|
||||
assert reply_to_id is None
|
||||
assert thread_id is None
|
||||
|
||||
|
||||
def test_determine_message_metadata_reply():
|
||||
"""Test metadata for reply message"""
|
||||
message = Mock()
|
||||
message.reference = Mock()
|
||||
message.reference.message_id = 123456
|
||||
message.channel = Mock()
|
||||
|
||||
message_type, reply_to_id, thread_id = determine_message_metadata(message)
|
||||
|
||||
assert message_type == "reply"
|
||||
assert reply_to_id == 123456
|
||||
|
||||
|
||||
def test_determine_message_metadata_thread():
|
||||
"""Test metadata for message in thread"""
|
||||
message = Mock()
|
||||
message.reference = None
|
||||
message.channel = Mock()
|
||||
message.channel.id = 999
|
||||
message.channel.parent = Mock() # Has parent means it's a thread
|
||||
|
||||
message_type, reply_to_id, thread_id = determine_message_metadata(message)
|
||||
|
||||
assert thread_id == 999
|
||||
|
||||
|
||||
# Tests for should_track_message
|
||||
def test_should_track_message_server_disabled(db_session):
|
||||
"""Test when server has tracking disabled"""
|
||||
server = DiscordServer(id=1, name="Server", track_messages=False)
|
||||
channel = DiscordChannel(id=2, name="Channel", channel_type="text")
|
||||
user = DiscordUser(id=3, username="User")
|
||||
|
||||
result = should_track_message(server, channel, user)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_should_track_message_channel_disabled(db_session):
|
||||
"""Test when channel has tracking disabled"""
|
||||
server = DiscordServer(id=1, name="Server", track_messages=True)
|
||||
channel = DiscordChannel(
|
||||
id=2, name="Channel", channel_type="text", track_messages=False
|
||||
)
|
||||
user = DiscordUser(id=3, username="User")
|
||||
|
||||
result = should_track_message(server, channel, user)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_should_track_message_dm_allowed(db_session):
|
||||
"""Test DM tracking when user allows it"""
|
||||
channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True)
|
||||
user = DiscordUser(id=3, username="User", track_messages=True)
|
||||
|
||||
result = should_track_message(None, channel, user)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_should_track_message_dm_not_allowed(db_session):
|
||||
"""Test DM tracking when user doesn't allow it"""
|
||||
channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True)
|
||||
user = DiscordUser(id=3, username="User", track_messages=False)
|
||||
|
||||
result = should_track_message(None, channel, user)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_should_track_message_default_true(db_session):
|
||||
"""Test default tracking behavior"""
|
||||
server = DiscordServer(id=1, name="Server", track_messages=True)
|
||||
channel = DiscordChannel(
|
||||
id=2, name="Channel", channel_type="text", track_messages=True
|
||||
)
|
||||
user = DiscordUser(id=3, username="User")
|
||||
|
||||
result = should_track_message(server, channel, user)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
# Tests for should_collect_bot_message
|
||||
@patch("memory.common.settings.DISCORD_COLLECT_BOTS", False)
|
||||
def test_should_collect_bot_message_bot_not_allowed():
|
||||
"""Test bot message collection when disabled"""
|
||||
message = Mock()
|
||||
message.author = Mock()
|
||||
message.author.bot = True
|
||||
|
||||
result = should_collect_bot_message(message)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_COLLECT_BOTS", True)
|
||||
def test_should_collect_bot_message_bot_allowed():
|
||||
"""Test bot message collection when enabled"""
|
||||
message = Mock()
|
||||
message.author = Mock()
|
||||
message.author.bot = True
|
||||
|
||||
result = should_collect_bot_message(message)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_should_collect_bot_message_human():
|
||||
"""Test human message collection"""
|
||||
message = Mock()
|
||||
message.author = Mock()
|
||||
message.author.bot = False
|
||||
|
||||
result = should_collect_bot_message(message)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
# Tests for sync_guild_metadata
|
||||
@patch("memory.discord.collector.make_session")
|
||||
def test_sync_guild_metadata(mock_make_session, mock_guild):
|
||||
"""Test syncing guild metadata"""
|
||||
mock_session = Mock()
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=mock_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
# Mock session.query().get() to return None (new server)
|
||||
mock_session.query.return_value.get.return_value = None
|
||||
|
||||
# Mock channels
|
||||
text_channel = Mock(spec=discord.TextChannel)
|
||||
text_channel.id = 1
|
||||
text_channel.name = "general"
|
||||
text_channel.guild = mock_guild
|
||||
|
||||
voice_channel = Mock(spec=discord.VoiceChannel)
|
||||
voice_channel.id = 2
|
||||
voice_channel.name = "voice"
|
||||
voice_channel.guild = mock_guild
|
||||
|
||||
mock_guild.channels = [text_channel, voice_channel]
|
||||
|
||||
sync_guild_metadata(mock_guild)
|
||||
|
||||
# Verify session.commit was called
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
# Tests for MessageCollector class
|
||||
def test_message_collector_init():
|
||||
"""Test MessageCollector initialization"""
|
||||
collector = MessageCollector()
|
||||
|
||||
assert collector.command_prefix == "!memory_"
|
||||
assert collector.help_command is None
|
||||
assert collector.intents.message_content is True
|
||||
assert collector.intents.guilds is True
|
||||
assert collector.intents.members is True
|
||||
assert collector.intents.dm_messages is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_ready():
|
||||
"""Test on_ready event handler"""
|
||||
collector = MessageCollector()
|
||||
collector.user = Mock()
|
||||
collector.user.name = "TestBot"
|
||||
collector.guilds = [Mock(), Mock()]
|
||||
collector.sync_servers_and_channels = AsyncMock()
|
||||
|
||||
await collector.on_ready()
|
||||
|
||||
collector.sync_servers_and_channels.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.discord.collector.make_session")
|
||||
@patch("memory.discord.collector.add_discord_message")
|
||||
async def test_on_message_success(mock_add_task, mock_make_session, mock_message):
|
||||
"""Test successful message handling"""
|
||||
mock_session = Mock()
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=mock_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=None)
|
||||
mock_session.query.return_value.get.return_value = None # New entities
|
||||
|
||||
collector = MessageCollector()
|
||||
await collector.on_message(mock_message)
|
||||
|
||||
# Verify task was queued
|
||||
mock_add_task.delay.assert_called_once()
|
||||
call_kwargs = mock_add_task.delay.call_args[1]
|
||||
assert call_kwargs["message_id"] == mock_message.id
|
||||
assert call_kwargs["channel_id"] == mock_message.channel.id
|
||||
assert call_kwargs["author_id"] == mock_message.author.id
|
||||
assert call_kwargs["content"] == mock_message.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.discord.collector.make_session")
|
||||
async def test_on_message_bot_message_filtered(mock_make_session, mock_message):
|
||||
"""Test bot message filtering"""
|
||||
mock_message.author.bot = True
|
||||
|
||||
with patch(
|
||||
"memory.discord.collector.should_collect_bot_message", return_value=False
|
||||
):
|
||||
collector = MessageCollector()
|
||||
await collector.on_message(mock_message)
|
||||
|
||||
# Should not create session or queue task
|
||||
mock_make_session.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.discord.collector.make_session")
|
||||
async def test_on_message_error_handling(mock_make_session, mock_message):
|
||||
"""Test error handling in on_message"""
|
||||
mock_make_session.side_effect = Exception("Database error")
|
||||
|
||||
collector = MessageCollector()
|
||||
# Should not raise
|
||||
await collector.on_message(mock_message)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.discord.collector.edit_discord_message")
|
||||
async def test_on_message_edit(mock_edit_task):
|
||||
"""Test message edit handler"""
|
||||
before = Mock()
|
||||
after = Mock()
|
||||
after.id = 123
|
||||
after.content = "Edited content"
|
||||
after.edited_at = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
collector = MessageCollector()
|
||||
await collector.on_message_edit(before, after)
|
||||
|
||||
mock_edit_task.delay.assert_called_once()
|
||||
call_kwargs = mock_edit_task.delay.call_args[1]
|
||||
assert call_kwargs["message_id"] == 123
|
||||
assert call_kwargs["content"] == "Edited content"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_edit_error_handling():
|
||||
"""Test error handling in on_message_edit"""
|
||||
before = Mock()
|
||||
after = Mock()
|
||||
after.id = 123
|
||||
after.content = "Edited"
|
||||
after.edited_at = None # Will trigger datetime.now
|
||||
|
||||
with patch("memory.discord.collector.edit_discord_message") as mock_edit:
|
||||
mock_edit.delay.side_effect = Exception("Task error")
|
||||
|
||||
collector = MessageCollector()
|
||||
# Should not raise
|
||||
await collector.on_message_edit(before, after)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_servers_and_channels():
|
||||
"""Test syncing servers and channels"""
|
||||
guild1 = Mock()
|
||||
guild2 = Mock()
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild1, guild2]
|
||||
|
||||
with patch("memory.discord.collector.sync_guild_metadata") as mock_sync:
|
||||
await collector.sync_servers_and_channels()
|
||||
|
||||
assert mock_sync.call_count == 2
|
||||
mock_sync.assert_any_call(guild1)
|
||||
mock_sync.assert_any_call(guild2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.discord.collector.make_session")
|
||||
async def test_refresh_metadata(mock_make_session):
|
||||
"""Test metadata refresh"""
|
||||
mock_session = Mock()
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=mock_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=None)
|
||||
mock_session.query.return_value.get.return_value = None
|
||||
|
||||
guild = Mock()
|
||||
guild.id = 123
|
||||
guild.name = "Test"
|
||||
guild.channels = []
|
||||
guild.members = []
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild]
|
||||
collector.intents = Mock()
|
||||
collector.intents.members = False
|
||||
|
||||
result = await collector.refresh_metadata()
|
||||
|
||||
assert result["servers_updated"] == 1
|
||||
assert result["channels_updated"] == 0
|
||||
assert result["users_updated"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_id():
|
||||
"""Test getting user by ID"""
|
||||
user = Mock()
|
||||
user.id = 123
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.get_user = Mock(return_value=user)
|
||||
|
||||
result = await collector.get_user(123)
|
||||
|
||||
assert result == user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_username():
|
||||
"""Test getting user by username"""
|
||||
member = Mock()
|
||||
member.name = "testuser"
|
||||
member.display_name = "Test User"
|
||||
member.discriminator = "1234"
|
||||
|
||||
guild = Mock()
|
||||
guild.members = [member]
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild]
|
||||
|
||||
result = await collector.get_user("testuser")
|
||||
|
||||
assert result == member
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_not_found():
|
||||
"""Test getting non-existent user"""
|
||||
collector = MessageCollector()
|
||||
collector.guilds = []
|
||||
|
||||
with patch.object(collector, "get_user", return_value=None):
|
||||
with patch.object(
|
||||
collector, "fetch_user", side_effect=discord.NotFound(Mock(), Mock())
|
||||
):
|
||||
result = await collector.get_user(999)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_channel_by_name():
|
||||
"""Test getting channel by name"""
|
||||
channel = Mock(spec=discord.TextChannel)
|
||||
channel.name = "general"
|
||||
|
||||
guild = Mock()
|
||||
guild.channels = [channel]
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild]
|
||||
|
||||
result = await collector.get_channel_by_name("general")
|
||||
|
||||
assert result == channel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_channel_by_name_not_found():
|
||||
"""Test getting non-existent channel"""
|
||||
guild = Mock()
|
||||
guild.channels = []
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild]
|
||||
|
||||
result = await collector.get_channel_by_name("nonexistent")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_channel():
|
||||
"""Test creating a channel"""
|
||||
guild = Mock()
|
||||
guild.name = "Test Server"
|
||||
new_channel = Mock()
|
||||
guild.create_text_channel = AsyncMock(return_value=new_channel)
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.get_guild = Mock(return_value=guild)
|
||||
|
||||
result = await collector.create_channel("new-channel", guild_id=123)
|
||||
|
||||
assert result == new_channel
|
||||
guild.create_text_channel.assert_called_once_with("new-channel")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_channel_no_guild():
|
||||
"""Test creating channel when no guild available"""
|
||||
collector = MessageCollector()
|
||||
collector.get_guild = Mock(return_value=None)
|
||||
collector.guilds = []
|
||||
|
||||
result = await collector.create_channel("new-channel")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_dm_success():
|
||||
"""Test sending DM successfully"""
|
||||
user = Mock()
|
||||
user.send = AsyncMock()
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.get_user = AsyncMock(return_value=user)
|
||||
|
||||
result = await collector.send_dm(123, "Hello!")
|
||||
|
||||
assert result is True
|
||||
user.send.assert_called_once_with("Hello!")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_dm_user_not_found():
|
||||
"""Test sending DM when user not found"""
|
||||
collector = MessageCollector()
|
||||
collector.get_user = AsyncMock(return_value=None)
|
||||
|
||||
result = await collector.send_dm(123, "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_dm_exception():
|
||||
"""Test sending DM with exception"""
|
||||
user = Mock()
|
||||
user.send = AsyncMock(side_effect=Exception("Send failed"))
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.get_user = AsyncMock(return_value=user)
|
||||
|
||||
result = await collector.send_dm(123, "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
async def test_send_to_channel_success():
|
||||
"""Test sending to channel successfully"""
|
||||
channel = Mock()
|
||||
channel.send = AsyncMock()
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.get_channel_by_name = AsyncMock(return_value=channel)
|
||||
|
||||
result = await collector.send_to_channel("general", "Announcement!")
|
||||
|
||||
assert result is True
|
||||
channel.send.assert_called_once_with("Announcement!")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
async def test_send_to_channel_notifications_disabled():
|
||||
"""Test sending to channel when notifications disabled"""
|
||||
collector = MessageCollector()
|
||||
|
||||
result = await collector.send_to_channel("general", "Announcement!")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
async def test_send_to_channel_not_found():
|
||||
"""Test sending to non-existent channel"""
|
||||
collector = MessageCollector()
|
||||
collector.get_channel_by_name = AsyncMock(return_value=None)
|
||||
|
||||
result = await collector.send_to_channel("nonexistent", "Message")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
|
||||
async def test_run_collector():
|
||||
"""Test running the collector"""
|
||||
from memory.discord.collector import run_collector
|
||||
|
||||
with patch("memory.discord.collector.MessageCollector") as mock_collector_class:
|
||||
mock_collector = Mock()
|
||||
mock_collector.start = AsyncMock()
|
||||
mock_collector_class.return_value = mock_collector
|
||||
|
||||
await run_collector()
|
||||
|
||||
mock_collector.start.assert_called_once_with("test_token")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
|
||||
async def test_run_collector_no_token():
|
||||
"""Test running collector without token"""
|
||||
from memory.discord.collector import run_collector
|
||||
|
||||
# Should return early without raising
|
||||
await run_collector()
|
||||
607
tests/memory/workers/tasks/test_discord_tasks.py
Normal file
607
tests/memory/workers/tasks/test_discord_tasks.py
Normal file
@ -0,0 +1,607 @@
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from memory.common.db.models import (
|
||||
DiscordMessage,
|
||||
DiscordUser,
|
||||
DiscordServer,
|
||||
DiscordChannel,
|
||||
)
|
||||
from memory.workers.tasks import discord
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_user(db_session):
|
||||
"""Create a Discord user for testing."""
|
||||
user = DiscordUser(
|
||||
id=123456789,
|
||||
username="testuser",
|
||||
ignore_messages=False,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_server(db_session):
|
||||
"""Create a Discord server for testing."""
|
||||
server = DiscordServer(
|
||||
id=987654321,
|
||||
name="Test Server",
|
||||
ignore_messages=False,
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_channel(db_session, mock_discord_server):
|
||||
"""Create a Discord channel for testing."""
|
||||
channel = DiscordChannel(
|
||||
id=111222333,
|
||||
name="test-channel",
|
||||
channel_type="text",
|
||||
server_id=mock_discord_server.id,
|
||||
ignore_messages=False,
|
||||
)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
return channel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_message_data(mock_discord_user, mock_discord_channel):
|
||||
"""Sample message data for testing."""
|
||||
return {
|
||||
"message_id": 999888777,
|
||||
"channel_id": mock_discord_channel.id,
|
||||
"author_id": mock_discord_user.id,
|
||||
"content": "This is a test Discord message with enough content to be processed.",
|
||||
"sent_at": "2024-01-01T12:00:00Z",
|
||||
"server_id": None,
|
||||
"message_reference_id": None,
|
||||
}
|
||||
|
||||
|
||||
def test_get_prev_returns_previous_messages(
|
||||
db_session, mock_discord_user, mock_discord_channel
|
||||
):
|
||||
"""Test that get_prev returns previous messages in order."""
|
||||
# Create previous messages
|
||||
msg1 = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
content="First message",
|
||||
sent_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash1" + bytes(26),
|
||||
)
|
||||
msg2 = DiscordMessage(
|
||||
message_id=2,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
content="Second message",
|
||||
sent_at=datetime(2024, 1, 1, 10, 5, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash2" + bytes(26),
|
||||
)
|
||||
msg3 = DiscordMessage(
|
||||
message_id=3,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
content="Third message",
|
||||
sent_at=datetime(2024, 1, 1, 10, 10, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash3" + bytes(26),
|
||||
)
|
||||
db_session.add_all([msg1, msg2, msg3])
|
||||
db_session.commit()
|
||||
|
||||
# Get previous messages before 10:15
|
||||
result = discord.get_prev(
|
||||
db_session,
|
||||
mock_discord_channel.id,
|
||||
datetime(2024, 1, 1, 10, 15, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
assert len(result) == 3
|
||||
assert result[0] == "testuser: First message"
|
||||
assert result[1] == "testuser: Second message"
|
||||
assert result[2] == "testuser: Third message"
|
||||
|
||||
|
||||
def test_get_prev_limits_context_window(
|
||||
db_session, mock_discord_user, mock_discord_channel
|
||||
):
|
||||
"""Test that get_prev respects DISCORD_CONTEXT_WINDOW setting."""
|
||||
# Create 15 messages (more than the default context window of 10)
|
||||
for i in range(15):
|
||||
msg = DiscordMessage(
|
||||
message_id=i,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
content=f"Message {i}",
|
||||
sent_at=datetime(2024, 1, 1, 10, i, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
sha256=f"hash{i}".encode() + bytes(27),
|
||||
)
|
||||
db_session.add(msg)
|
||||
db_session.commit()
|
||||
|
||||
result = discord.get_prev(
|
||||
db_session,
|
||||
mock_discord_channel.id,
|
||||
datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
# Should only return last 10 messages
|
||||
assert len(result) == 10
|
||||
assert result[0] == "testuser: Message 5" # Oldest in window
|
||||
assert result[-1] == "testuser: Message 14" # Most recent
|
||||
|
||||
|
||||
def test_get_prev_empty_channel(db_session, mock_discord_channel):
|
||||
"""Test get_prev with no previous messages."""
|
||||
result = discord.get_prev(
|
||||
db_session,
|
||||
mock_discord_channel.id,
|
||||
datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_should_process_normal_message(
|
||||
db_session, mock_discord_user, mock_discord_server, mock_discord_channel
|
||||
):
|
||||
"""Test should_process returns True for normal messages."""
|
||||
message = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
server_id=mock_discord_server.id,
|
||||
content="Test",
|
||||
sent_at=datetime.now(timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash" + bytes(27),
|
||||
)
|
||||
db_session.add(message)
|
||||
db_session.commit()
|
||||
db_session.refresh(message)
|
||||
|
||||
assert discord.should_process(message) is True
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
|
||||
def test_should_process_disabled():
|
||||
"""Test should_process returns False when processing is disabled."""
|
||||
message = Mock()
|
||||
assert discord.should_process(message) is False
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
def test_should_process_notifications_disabled():
|
||||
"""Test should_process returns False when notifications are disabled."""
|
||||
message = Mock()
|
||||
assert discord.should_process(message) is False
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_should_process_server_ignored(
|
||||
db_session, mock_discord_user, mock_discord_channel
|
||||
):
|
||||
"""Test should_process returns False when server has ignore_messages=True."""
|
||||
server = DiscordServer(
|
||||
id=123,
|
||||
name="Ignored Server",
|
||||
ignore_messages=True,
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
message = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
server_id=server.id,
|
||||
content="Test",
|
||||
sent_at=datetime.now(timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash" + bytes(27),
|
||||
)
|
||||
db_session.add(message)
|
||||
db_session.commit()
|
||||
db_session.refresh(message)
|
||||
|
||||
assert discord.should_process(message) is False
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_should_process_channel_ignored(
|
||||
db_session, mock_discord_user, mock_discord_server
|
||||
):
|
||||
"""Test should_process returns False when channel has ignore_messages=True."""
|
||||
channel = DiscordChannel(
|
||||
id=456,
|
||||
name="ignored-channel",
|
||||
channel_type="text",
|
||||
server_id=mock_discord_server.id,
|
||||
ignore_messages=True,
|
||||
)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
|
||||
message = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
server_id=mock_discord_server.id,
|
||||
content="Test",
|
||||
sent_at=datetime.now(timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash" + bytes(27),
|
||||
)
|
||||
db_session.add(message)
|
||||
db_session.commit()
|
||||
db_session.refresh(message)
|
||||
|
||||
assert discord.should_process(message) is False
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_should_process_user_ignored(
|
||||
db_session, mock_discord_server, mock_discord_channel
|
||||
):
|
||||
"""Test should_process returns False when user has ignore_messages=True."""
|
||||
user = DiscordUser(
|
||||
id=789,
|
||||
username="ignoreduser",
|
||||
ignore_messages=True,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
message = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=user.id,
|
||||
server_id=mock_discord_server.id,
|
||||
content="Test",
|
||||
sent_at=datetime.now(timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash" + bytes(27),
|
||||
)
|
||||
db_session.add(message)
|
||||
db_session.commit()
|
||||
db_session.refresh(message)
|
||||
|
||||
assert discord.should_process(message) is False
|
||||
|
||||
|
||||
def test_add_discord_message_success(db_session, sample_message_data, qdrant):
|
||||
"""Test successful Discord message addition."""
|
||||
result = discord.add_discord_message(**sample_message_data)
|
||||
|
||||
assert result["status"] == "processed"
|
||||
assert "discordmessage_id" in result
|
||||
|
||||
# Verify the message was created in the database
|
||||
message = (
|
||||
db_session.query(DiscordMessage)
|
||||
.filter_by(message_id=sample_message_data["message_id"])
|
||||
.first()
|
||||
)
|
||||
assert message is not None
|
||||
assert message.content == sample_message_data["content"]
|
||||
assert message.message_type == "default"
|
||||
assert message.reply_to_message_id is None
|
||||
|
||||
|
||||
def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant):
|
||||
"""Test adding a Discord message that is a reply."""
|
||||
sample_message_data["message_reference_id"] = 111222333
|
||||
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
message = (
|
||||
db_session.query(DiscordMessage)
|
||||
.filter_by(message_id=sample_message_data["message_id"])
|
||||
.first()
|
||||
)
|
||||
assert message.message_type == "reply"
|
||||
assert message.reply_to_message_id == 111222333
|
||||
|
||||
|
||||
def test_add_discord_message_already_exists(db_session, sample_message_data, qdrant):
|
||||
"""Test adding a message that already exists."""
|
||||
# Add the message once
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
# Try to add it again
|
||||
result = discord.add_discord_message(**sample_message_data)
|
||||
|
||||
assert result["status"] == "already_exists"
|
||||
assert result["message_id"] == sample_message_data["message_id"]
|
||||
|
||||
# Verify no duplicate was created
|
||||
messages = (
|
||||
db_session.query(DiscordMessage)
|
||||
.filter_by(message_id=sample_message_data["message_id"])
|
||||
.all()
|
||||
)
|
||||
assert len(messages) == 1
|
||||
|
||||
|
||||
def test_add_discord_message_with_context(
|
||||
db_session, sample_message_data, mock_discord_user, qdrant
|
||||
):
|
||||
"""Test that message is added successfully when previous messages exist."""
|
||||
# Add a previous message
|
||||
prev_msg = DiscordMessage(
|
||||
message_id=111111,
|
||||
channel_id=sample_message_data["channel_id"],
|
||||
discord_user_id=mock_discord_user.id,
|
||||
content="Previous message",
|
||||
sent_at=datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"prev" + bytes(28),
|
||||
)
|
||||
db_session.add(prev_msg)
|
||||
db_session.commit()
|
||||
|
||||
result = discord.add_discord_message(**sample_message_data)
|
||||
|
||||
message = (
|
||||
db_session.query(DiscordMessage)
|
||||
.filter_by(message_id=sample_message_data["message_id"])
|
||||
.first()
|
||||
)
|
||||
assert message is not None
|
||||
assert result["status"] == "processed"
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.discord.process_discord_message")
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_add_discord_message_triggers_processing(
|
||||
mock_process,
|
||||
db_session,
|
||||
sample_message_data,
|
||||
mock_discord_server,
|
||||
mock_discord_channel,
|
||||
qdrant,
|
||||
):
|
||||
"""Test that add_discord_message triggers process_discord_message when conditions are met."""
|
||||
mock_process.delay = Mock()
|
||||
sample_message_data["server_id"] = mock_discord_server.id
|
||||
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
# Verify process_discord_message.delay was called
|
||||
mock_process.delay.assert_called_once()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.discord.process_discord_message")
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
|
||||
def test_add_discord_message_no_processing_when_disabled(
|
||||
mock_process, db_session, sample_message_data, qdrant
|
||||
):
|
||||
"""Test that process_discord_message is not called when processing is disabled."""
|
||||
mock_process.delay = Mock()
|
||||
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
mock_process.delay.assert_not_called()
|
||||
|
||||
|
||||
def test_edit_discord_message_success(db_session, sample_message_data, qdrant):
|
||||
"""Test successful Discord message edit."""
|
||||
# First add the message
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
# Edit it
|
||||
new_content = (
|
||||
"This is the edited content with enough text to be meaningful and processed."
|
||||
)
|
||||
edited_at = "2024-01-01T13:00:00Z"
|
||||
|
||||
result = discord.edit_discord_message(
|
||||
sample_message_data["message_id"],
|
||||
new_content,
|
||||
edited_at,
|
||||
)
|
||||
|
||||
assert result["status"] == "processed"
|
||||
|
||||
# Verify the message was updated
|
||||
message = (
|
||||
db_session.query(DiscordMessage)
|
||||
.filter_by(message_id=sample_message_data["message_id"])
|
||||
.first()
|
||||
)
|
||||
assert message.content == new_content
|
||||
assert message.edited_at is not None
|
||||
|
||||
|
||||
def test_edit_discord_message_not_found(db_session):
|
||||
"""Test editing a message that doesn't exist."""
|
||||
result = discord.edit_discord_message(
|
||||
999999,
|
||||
"New content",
|
||||
"2024-01-01T13:00:00Z",
|
||||
)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert result["error"] == "Message not found"
|
||||
assert result["message_id"] == 999999
|
||||
|
||||
|
||||
def test_edit_discord_message_updates_context(
|
||||
db_session, sample_message_data, mock_discord_user, qdrant
|
||||
):
|
||||
"""Test that editing a message works correctly."""
|
||||
# Add previous message and the message to be edited
|
||||
prev_msg = DiscordMessage(
|
||||
message_id=111111,
|
||||
channel_id=sample_message_data["channel_id"],
|
||||
discord_user_id=mock_discord_user.id,
|
||||
content="Previous message",
|
||||
sent_at=datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"prev" + bytes(28),
|
||||
)
|
||||
db_session.add(prev_msg)
|
||||
db_session.commit()
|
||||
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
# Edit the message
|
||||
result = discord.edit_discord_message(
|
||||
sample_message_data["message_id"],
|
||||
"Edited content that should have context updated properly.",
|
||||
"2024-01-01T13:00:00Z",
|
||||
)
|
||||
|
||||
# Verify message was updated
|
||||
message = (
|
||||
db_session.query(DiscordMessage)
|
||||
.filter_by(message_id=sample_message_data["message_id"])
|
||||
.first()
|
||||
)
|
||||
assert (
|
||||
message.content == "Edited content that should have context updated properly."
|
||||
)
|
||||
assert result["status"] == "processed"
|
||||
|
||||
|
||||
def test_process_discord_message_success(db_session, sample_message_data, qdrant):
|
||||
"""Test processing a Discord message."""
|
||||
# Add a message first
|
||||
add_result = discord.add_discord_message(**sample_message_data)
|
||||
message_id = add_result["discordmessage_id"]
|
||||
|
||||
# Process it
|
||||
result = discord.process_discord_message(message_id)
|
||||
|
||||
assert result["status"] == "processed"
|
||||
assert result["message_id"] == message_id
|
||||
|
||||
|
||||
def test_process_discord_message_not_found(db_session):
|
||||
"""Test processing a message that doesn't exist."""
|
||||
result = discord.process_discord_message(999999)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert result["error"] == "Message not found"
|
||||
assert result["message_id"] == 999999
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sent_at_str,expected_hour",
|
||||
[
|
||||
("2024-01-01T12:00:00Z", 12),
|
||||
("2024-01-01T00:00:00+00:00", 0),
|
||||
("2024-01-01T23:59:59Z", 23),
|
||||
],
|
||||
)
|
||||
def test_add_discord_message_datetime_parsing(
|
||||
db_session, sample_message_data, sent_at_str, expected_hour, qdrant
|
||||
):
|
||||
"""Test that various datetime formats are parsed correctly."""
|
||||
sample_message_data["sent_at"] = sent_at_str
|
||||
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
message = (
|
||||
db_session.query(DiscordMessage)
|
||||
.filter_by(message_id=sample_message_data["message_id"])
|
||||
.first()
|
||||
)
|
||||
assert message.sent_at.hour == expected_hour
|
||||
|
||||
|
||||
def test_add_discord_message_unique_hash(db_session, sample_message_data, qdrant):
|
||||
"""Test that message hash includes message_id for uniqueness."""
|
||||
# Add first message
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
# Try to add another message with same content but different message_id
|
||||
sample_message_data["message_id"] = 888777666
|
||||
|
||||
result = discord.add_discord_message(**sample_message_data)
|
||||
|
||||
# Should succeed because hash includes message_id
|
||||
assert result["status"] == "processed"
|
||||
|
||||
# Verify both messages exist
|
||||
messages = (
|
||||
db_session.query(DiscordMessage)
|
||||
.filter_by(content=sample_message_data["content"])
|
||||
.all()
|
||||
)
|
||||
assert len(messages) == 2
|
||||
|
||||
|
||||
def test_get_prev_only_returns_messages_from_same_channel(
|
||||
db_session, mock_discord_user, mock_discord_server
|
||||
):
|
||||
"""Test that get_prev only returns messages from the specified channel."""
|
||||
# Create two channels
|
||||
channel1 = DiscordChannel(
|
||||
id=111,
|
||||
name="channel-1",
|
||||
channel_type="text",
|
||||
server_id=mock_discord_server.id,
|
||||
)
|
||||
channel2 = DiscordChannel(
|
||||
id=222,
|
||||
name="channel-2",
|
||||
channel_type="text",
|
||||
server_id=mock_discord_server.id,
|
||||
)
|
||||
db_session.add_all([channel1, channel2])
|
||||
db_session.commit()
|
||||
|
||||
# Add messages to both channels
|
||||
msg1 = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=channel1.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
content="Message in channel 1",
|
||||
sent_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash1" + bytes(26),
|
||||
)
|
||||
msg2 = DiscordMessage(
|
||||
message_id=2,
|
||||
channel_id=channel2.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
content="Message in channel 2",
|
||||
sent_at=datetime(2024, 1, 1, 10, 5, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash2" + bytes(26),
|
||||
)
|
||||
db_session.add_all([msg1, msg2])
|
||||
db_session.commit()
|
||||
|
||||
# Get previous messages for channel 1
|
||||
result = discord.get_prev(
|
||||
db_session,
|
||||
channel1.id, # type: ignore
|
||||
datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
# Should only return message from channel 1
|
||||
assert len(result) == 1
|
||||
assert "Message in channel 1" in result[0]
|
||||
assert "Message in channel 2" not in result[0]
|
||||
Loading…
x
Reference in New Issue
Block a user