mirror of
https://github.com/mruwnik/memory.git
synced 2025-12-16 09:01:17 +01:00
Compare commits
3 Commits
f454aa9afa
...
99d3843f47
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
99d3843f47 | ||
|
|
08d17c28dd | ||
|
|
e086b4a3a6 |
@ -9,7 +9,6 @@ Create Date: 2025-10-12 10:12:57.421009
|
|||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
@ -20,10 +19,8 @@ depends_on: Union[str, Sequence[str], None] = None
|
|||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
    """Rename the ``prompt`` column of ``scheduled_llm_calls`` to ``message``."""
    op.alter_column(
        table_name="scheduled_llm_calls",
        column_name="prompt",
        new_column_name="message",
    )
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
    """Rename the ``message`` column of ``scheduled_llm_calls`` back to ``prompt``."""
    op.alter_column(
        table_name="scheduled_llm_calls",
        column_name="message",
        new_column_name="prompt",
    )
|
||||||
|
|||||||
176
db/migrations/versions/20251012_222827_add_discord_models.py
Normal file
176
db/migrations/versions/20251012_222827_add_discord_models.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
"""add_discord_models
|
||||||
|
|
||||||
|
Revision ID: a8c8e8b17179
|
||||||
|
Revises: c86079073c1d
|
||||||
|
Create Date: 2025-10-12 22:28:27.856164
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "a8c8e8b17179"
|
||||||
|
down_revision: Union[str, None] = "c86079073c1d"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _message_processor_columns() -> Sequence[sa.Column]:
    """Return fresh copies of the message-handling columns shared by the
    servers, channels and users tables (a Column can only belong to one table)."""
    return [
        sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
        sa.Column(
            "ignore_messages", sa.Boolean(), server_default="false", nullable=True
        ),
        sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
        sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
    ]


def _timestamp_columns() -> Sequence[sa.Column]:
    """Return fresh ``created_at`` / ``updated_at`` columns defaulting to ``now()``."""
    return [
        sa.Column(
            column_name,
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        )
        for column_name in ("created_at", "updated_at")
    ]


def upgrade() -> None:
    """Create the Discord tables and their indexes.

    Tables are created parents-first so that every foreign key points at a
    table that already exists: ``discord_servers``, then ``discord_channels``
    and ``discord_users``, and finally ``discord_message``.  The ``users`` and
    ``source_item`` tables referenced below are assumed to exist already.

    NOTE(review): the ORM mixin (``MessageProcessor``) declares
    ``track_messages``, ``allowed_tools`` and ``disallowed_tools`` as NOT NULL
    (the arrays with a ``'{}'`` server default), while this migration creates
    them nullable and without an array default - confirm which is intended.
    """
    # --- discord_servers -------------------------------------------------
    op.create_table(
        "discord_servers",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("name", sa.Text(), nullable=False),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("member_count", sa.Integer(), nullable=True),
        *_message_processor_columns(),
        sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True),
        *_timestamp_columns(),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "discord_servers_active_idx",
        "discord_servers",
        ["track_messages", "last_sync_at"],
        unique=False,
    )

    # --- discord_channels ------------------------------------------------
    op.create_table(
        "discord_channels",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("server_id", sa.BigInteger(), nullable=True),
        sa.Column("name", sa.Text(), nullable=False),
        sa.Column("channel_type", sa.Text(), nullable=False),
        *_message_processor_columns(),
        *_timestamp_columns(),
        sa.ForeignKeyConstraint(["server_id"], ["discord_servers.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "discord_channels_server_idx", "discord_channels", ["server_id"], unique=False
    )

    # --- discord_users ---------------------------------------------------
    op.create_table(
        "discord_users",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("username", sa.Text(), nullable=False),
        sa.Column("display_name", sa.Text(), nullable=True),
        sa.Column("system_user_id", sa.Integer(), nullable=True),
        *_message_processor_columns(),
        *_timestamp_columns(),
        sa.ForeignKeyConstraint(["system_user_id"], ["users.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "discord_users_system_user_idx",
        "discord_users",
        ["system_user_id"],
        unique=False,
    )

    # --- discord_message -------------------------------------------------
    # ``id`` doubles as the foreign key to ``source_item``; deleting the
    # source item cascades to this row.
    op.create_table(
        "discord_message",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("sent_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("server_id", sa.BigInteger(), nullable=True),
        sa.Column("channel_id", sa.BigInteger(), nullable=False),
        sa.Column("discord_user_id", sa.BigInteger(), nullable=False),
        sa.Column("message_id", sa.BigInteger(), nullable=False),
        sa.Column("message_type", sa.Text(), server_default="default", nullable=True),
        sa.Column("reply_to_message_id", sa.BigInteger(), nullable=True),
        sa.Column("thread_id", sa.BigInteger(), nullable=True),
        sa.Column("edited_at", sa.DateTime(timezone=True), nullable=True),
        sa.ForeignKeyConstraint(["channel_id"], ["discord_channels.id"]),
        sa.ForeignKeyConstraint(["discord_user_id"], ["discord_users.id"]),
        sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["server_id"], ["discord_servers.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
    # One row per Discord message: the snowflake id is indexed uniquely.
    op.create_index(
        "discord_message_discord_id_idx", "discord_message", ["message_id"], unique=True
    )
    op.create_index(
        "discord_message_server_channel_idx",
        "discord_message",
        ["server_id", "channel_id"],
        unique=False,
    )
    op.create_index(
        "discord_message_user_idx", "discord_message", ["discord_user_id"], unique=False
    )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Drop the Discord tables and indexes created by ``upgrade``.

    Tables are removed children-first (messages, users, channels, servers);
    for each table its indexes are dropped before the table itself.
    """
    teardown_plan = (
        (
            "discord_message",
            (
                "discord_message_user_idx",
                "discord_message_server_channel_idx",
                "discord_message_discord_id_idx",
            ),
        ),
        ("discord_users", ("discord_users_system_user_idx",)),
        ("discord_channels", ("discord_channels_server_idx",)),
        ("discord_servers", ("discord_servers_active_idx",)),
    )
    for table, index_names in teardown_plan:
        for index_name in index_names:
            op.drop_index(index_name, table_name=table)
        op.drop_table(table)
|
||||||
@ -174,7 +174,7 @@ services:
|
|||||||
<<: *worker-base
|
<<: *worker-base
|
||||||
environment:
|
environment:
|
||||||
<<: *worker-env
|
<<: *worker-env
|
||||||
QUEUES: "email,ebooks,comic,blogs,forums,maintenance,notes,scheduler"
|
QUEUES: "email,ebooks,discord,comic,blogs,forums,maintenance,notes,scheduler"
|
||||||
|
|
||||||
ingest-hub:
|
ingest-hub:
|
||||||
<<: *worker-base
|
<<: *worker-base
|
||||||
@ -183,6 +183,10 @@ services:
|
|||||||
dockerfile: docker/ingest_hub/Dockerfile
|
dockerfile: docker/ingest_hub/Dockerfile
|
||||||
environment:
|
environment:
|
||||||
<<: *worker-env
|
<<: *worker-env
|
||||||
|
DISCORD_API_PORT: 8000
|
||||||
|
DISCORD_BOT_TOKEN: ${DISCORD_BOT_TOKEN}
|
||||||
|
DISCORD_NOTIFICATIONS_ENABLED: true
|
||||||
|
DISCORD_COLLECTOR_ENABLED: true
|
||||||
volumes:
|
volumes:
|
||||||
- ./memory_files:/app/memory_files:rw
|
- ./memory_files:/app/memory_files:rw
|
||||||
tmpfs:
|
tmpfs:
|
||||||
|
|||||||
@ -11,10 +11,10 @@ RUN apt-get update && apt-get install -y \
|
|||||||
COPY requirements ./requirements/
|
COPY requirements ./requirements/
|
||||||
COPY setup.py ./
|
COPY setup.py ./
|
||||||
RUN mkdir src
|
RUN mkdir src
|
||||||
RUN pip install -e ".[common]"
|
RUN pip install -e ".[ingesters]"
|
||||||
|
|
||||||
COPY src/ ./src/
|
COPY src/ ./src/
|
||||||
RUN pip install -e ".[common]"
|
RUN pip install -e ".[ingesters]"
|
||||||
|
|
||||||
# Create and copy entrypoint script
|
# Create and copy entrypoint script
|
||||||
COPY docker/workers/entry.sh ./entry.sh
|
COPY docker/workers/entry.sh ./entry.sh
|
||||||
|
|||||||
@ -14,3 +14,12 @@ stderr_logfile=/dev/stderr
|
|||||||
stderr_logfile_maxbytes=0
|
stderr_logfile_maxbytes=0
|
||||||
autorestart=true
|
autorestart=true
|
||||||
startsecs=10
|
startsecs=10
|
||||||
|
|
||||||
|
[program:discord-api]
|
||||||
|
command=uvicorn memory.discord.api:app --host 0.0.0.0 --port %(ENV_DISCORD_API_PORT)s
|
||||||
|
stdout_logfile=/dev/stdout
|
||||||
|
stdout_logfile_maxbytes=0
|
||||||
|
stderr_logfile=/dev/stderr
|
||||||
|
stderr_logfile_maxbytes=0
|
||||||
|
autorestart=true
|
||||||
|
startsecs=10
|
||||||
|
|||||||
@ -44,7 +44,7 @@ RUN git config --global user.email "${GIT_USER_EMAIL}" && \
|
|||||||
git config --global user.name "${GIT_USER_NAME}"
|
git config --global user.name "${GIT_USER_NAME}"
|
||||||
|
|
||||||
# Default queues to process
|
# Default queues to process
|
||||||
ENV QUEUES="ebooks,email,comic,blogs,forums,photo_embed,maintenance"
|
ENV QUEUES="ebooks,email,discord,comic,blogs,forums,photo_embed,maintenance"
|
||||||
ENV PYTHONPATH="/app"
|
ENV PYTHONPATH="/app"
|
||||||
|
|
||||||
ENTRYPOINT ["./entry.sh"]
|
ENTRYPOINT ["./entry.sh"]
|
||||||
@ -5,7 +5,7 @@ alembic==1.13.1
|
|||||||
dotenv==0.9.9
|
dotenv==0.9.9
|
||||||
voyageai==0.3.2
|
voyageai==0.3.2
|
||||||
qdrant-client==1.9.0
|
qdrant-client==1.9.0
|
||||||
anthropic==0.18.1
|
anthropic==0.69.0
|
||||||
openai==1.25.0
|
openai==1.25.0
|
||||||
# Pin the httpx version, as newer versions break the anthropic client
|
# Pin the httpx version, as newer versions break the anthropic client
|
||||||
httpx==0.27.0
|
httpx==0.27.0
|
||||||
|
|||||||
3
requirements/requirements-ingesters.txt
Normal file
3
requirements/requirements-ingesters.txt
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
discord.py==2.3.2
|
||||||
|
uvicorn==0.29.0
|
||||||
|
fastapi==0.112.2
|
||||||
8
setup.py
8
setup.py
@ -17,6 +17,7 @@ common_requires = read_requirements("requirements-common.txt")
|
|||||||
parsers_requires = read_requirements("requirements-parsers.txt")
|
parsers_requires = read_requirements("requirements-parsers.txt")
|
||||||
api_requires = read_requirements("requirements-api.txt")
|
api_requires = read_requirements("requirements-api.txt")
|
||||||
dev_requires = read_requirements("requirements-dev.txt")
|
dev_requires = read_requirements("requirements-dev.txt")
|
||||||
|
ingesters_requires = read_requirements("requirements-ingesters.txt")
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="memory",
|
name="memory",
|
||||||
@ -28,6 +29,11 @@ setup(
|
|||||||
"api": api_requires + common_requires + parsers_requires,
|
"api": api_requires + common_requires + parsers_requires,
|
||||||
"common": common_requires + parsers_requires,
|
"common": common_requires + parsers_requires,
|
||||||
"dev": dev_requires,
|
"dev": dev_requires,
|
||||||
"all": api_requires + common_requires + dev_requires + parsers_requires,
|
"ingesters": common_requires + parsers_requires + ingesters_requires,
|
||||||
|
"all": api_requires
|
||||||
|
+ common_requires
|
||||||
|
+ dev_requires
|
||||||
|
+ parsers_requires
|
||||||
|
+ ingesters_requires,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@ -40,7 +40,7 @@ async def score_chunk(query: str, chunk: Chunk) -> Chunk:
|
|||||||
prompt = SCORE_CHUNK_PROMPT.format(query=query, chunk=chunk_text)
|
prompt = SCORE_CHUNK_PROMPT.format(query=query, chunk=chunk_text)
|
||||||
try:
|
try:
|
||||||
response = await asyncio.to_thread(
|
response = await asyncio.to_thread(
|
||||||
llms.call,
|
llms.summarize,
|
||||||
prompt,
|
prompt,
|
||||||
settings.RANKER_MODEL,
|
settings.RANKER_MODEL,
|
||||||
images=images,
|
images=images,
|
||||||
|
|||||||
@ -12,6 +12,10 @@ MAINTENANCE_ROOT = "memory.workers.tasks.maintenance"
|
|||||||
NOTES_ROOT = "memory.workers.tasks.notes"
|
NOTES_ROOT = "memory.workers.tasks.notes"
|
||||||
OBSERVATIONS_ROOT = "memory.workers.tasks.observations"
|
OBSERVATIONS_ROOT = "memory.workers.tasks.observations"
|
||||||
SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
|
SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
|
||||||
|
DISCORD_ROOT = "memory.workers.tasks.discord"
|
||||||
|
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
|
||||||
|
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
|
||||||
|
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
|
||||||
|
|
||||||
SYNC_NOTES = f"{NOTES_ROOT}.sync_notes"
|
SYNC_NOTES = f"{NOTES_ROOT}.sync_notes"
|
||||||
SYNC_NOTE = f"{NOTES_ROOT}.sync_note"
|
SYNC_NOTE = f"{NOTES_ROOT}.sync_note"
|
||||||
@ -72,17 +76,18 @@ app.conf.update(
|
|||||||
task_reject_on_worker_lost=True,
|
task_reject_on_worker_lost=True,
|
||||||
worker_prefetch_multiplier=1,
|
worker_prefetch_multiplier=1,
|
||||||
task_routes={
|
task_routes={
|
||||||
f"{EMAIL_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-email"},
|
|
||||||
f"{PHOTO_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-photo-embed"},
|
|
||||||
f"{COMIC_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-comic"},
|
|
||||||
f"{EBOOK_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-ebooks"},
|
f"{EBOOK_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-ebooks"},
|
||||||
f"{BLOGS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-blogs"},
|
f"{BLOGS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-blogs"},
|
||||||
|
f"{COMIC_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-comic"},
|
||||||
|
f"{DISCORD_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-discord"},
|
||||||
|
f"{EMAIL_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-email"},
|
||||||
f"{FORUMS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-forums"},
|
f"{FORUMS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-forums"},
|
||||||
f"{MAINTENANCE_ROOT}.*": {
|
f"{MAINTENANCE_ROOT}.*": {
|
||||||
"queue": f"{settings.CELERY_QUEUE_PREFIX}-maintenance"
|
"queue": f"{settings.CELERY_QUEUE_PREFIX}-maintenance"
|
||||||
},
|
},
|
||||||
f"{NOTES_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
|
f"{NOTES_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
|
||||||
f"{OBSERVATIONS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
|
f"{OBSERVATIONS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
|
||||||
|
f"{PHOTO_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-photo-embed"},
|
||||||
f"{SCHEDULED_CALLS_ROOT}.*": {
|
f"{SCHEDULED_CALLS_ROOT}.*": {
|
||||||
"queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler"
|
"queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler"
|
||||||
},
|
},
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from memory.common.db.models.source_items import (
|
|||||||
EmailAttachment,
|
EmailAttachment,
|
||||||
AgentObservation,
|
AgentObservation,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
|
DiscordMessage,
|
||||||
BlogPost,
|
BlogPost,
|
||||||
Comic,
|
Comic,
|
||||||
BookSection,
|
BookSection,
|
||||||
@ -40,6 +41,9 @@ from memory.common.db.models.sources import (
|
|||||||
Book,
|
Book,
|
||||||
ArticleFeed,
|
ArticleFeed,
|
||||||
EmailAccount,
|
EmailAccount,
|
||||||
|
DiscordServer,
|
||||||
|
DiscordChannel,
|
||||||
|
DiscordUser,
|
||||||
)
|
)
|
||||||
from memory.common.db.models.users import (
|
from memory.common.db.models.users import (
|
||||||
User,
|
User,
|
||||||
@ -74,6 +78,7 @@ __all__ = [
|
|||||||
"EmailAttachment",
|
"EmailAttachment",
|
||||||
"AgentObservation",
|
"AgentObservation",
|
||||||
"ChatMessage",
|
"ChatMessage",
|
||||||
|
"DiscordMessage",
|
||||||
"BlogPost",
|
"BlogPost",
|
||||||
"Comic",
|
"Comic",
|
||||||
"BookSection",
|
"BookSection",
|
||||||
@ -93,6 +98,9 @@ __all__ = [
|
|||||||
"Book",
|
"Book",
|
||||||
"ArticleFeed",
|
"ArticleFeed",
|
||||||
"EmailAccount",
|
"EmailAccount",
|
||||||
|
"DiscordServer",
|
||||||
|
"DiscordChannel",
|
||||||
|
"DiscordUser",
|
||||||
# Users
|
# Users
|
||||||
"User",
|
"User",
|
||||||
"UserSession",
|
"UserSession",
|
||||||
|
|||||||
@ -70,7 +70,7 @@ class ScheduledLLMCall(Base):
|
|||||||
"created_at": print_datetime(cast(datetime, self.created_at)),
|
"created_at": print_datetime(cast(datetime, self.created_at)),
|
||||||
"executed_at": print_datetime(cast(datetime, self.executed_at)),
|
"executed_at": print_datetime(cast(datetime, self.executed_at)),
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"prompt": self.message,
|
"message": self.message,
|
||||||
"system_prompt": self.system_prompt,
|
"system_prompt": self.system_prompt,
|
||||||
"allowed_tools": self.allowed_tools,
|
"allowed_tools": self.allowed_tools,
|
||||||
"discord_channel": self.discord_channel,
|
"discord_channel": self.discord_channel,
|
||||||
|
|||||||
@ -262,7 +262,7 @@ class ChatMessage(SourceItem):
|
|||||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
|
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
|
||||||
)
|
)
|
||||||
platform = Column(Text)
|
platform = Column(Text)
|
||||||
channel_id = Column(Text)
|
channel_id = Column(Text) # Keep as Text for cross-platform compatibility
|
||||||
author = Column(Text)
|
author = Column(Text)
|
||||||
sent_at = Column(DateTime(timezone=True))
|
sent_at = Column(DateTime(timezone=True))
|
||||||
|
|
||||||
@ -274,6 +274,64 @@ class ChatMessage(SourceItem):
|
|||||||
__table_args__ = (Index("chat_channel_idx", "platform", "channel_id"),)
|
__table_args__ = (Index("chat_channel_idx", "platform", "channel_id"),)
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordMessage(SourceItem):
    """Discord-specific chat message with rich metadata.

    Joined to ``source_item`` through ``id``; rows are removed when the
    parent source item is deleted (``ondelete="CASCADE"``).
    """

    __tablename__ = "discord_message"

    id = Column(
        BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
    )

    sent_at = Column(DateTime(timezone=True), nullable=False)
    # server_id is nullable while channel_id is not - presumably so that
    # messages outside a guild (e.g. DMs) can be stored; confirm with callers.
    server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
    channel_id = Column(BigInteger, ForeignKey("discord_channels.id"), nullable=False)
    discord_user_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False)
    message_id = Column(BigInteger, nullable=False)  # Discord message snowflake ID

    # Discord-specific metadata
    message_type = Column(
        Text, server_default="default"
    )  # "default", "reply", "thread_starter"
    reply_to_message_id = Column(
        BigInteger, nullable=True
    )  # Discord message snowflake ID if replying
    thread_id = Column(
        BigInteger, nullable=True
    )  # Discord thread snowflake ID if in thread
    edited_at = Column(DateTime(timezone=True), nullable=True)

    channel = relationship("DiscordChannel", foreign_keys=[channel_id])
    server = relationship("DiscordServer", foreign_keys=[server_id])
    discord_user = relationship("DiscordUser", foreign_keys=[discord_user_id])

    @property
    def title(self) -> str:
        """Return the message rendered as ``"<username>: <content>"``."""
        return f"{self.discord_user.username}: {self.content}"

    __mapper_args__ = {
        "polymorphic_identity": "discord_message",
    }

    __table_args__ = (
        # Each Discord message (snowflake id) may be stored only once.
        Index("discord_message_discord_id_idx", "message_id", unique=True),
        Index(
            "discord_message_server_channel_idx",
            "server_id",
            "channel_id",
        ),
        Index("discord_message_user_idx", "discord_user_id"),
    )

    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
        """Build the text chunks for this message.

        Returns an empty list when the message has no content.  Otherwise the
        text is any preceding messages followed by this message's ``title``,
        separated by blank lines.
        """
        if not cast(str | None, self.content):
            return []
        # ``messages_before`` is not a mapped column; it is read defensively
        # and presumably attached by the caller as a list of strings.
        preceding = getattr(self, "messages_before", [])
        # Join context and title in one pass so that a message without any
        # preceding context does not start with a dangling blank-line separator.
        text = "\n\n".join([*preceding, self.title])
        return extract.extract_text(text)
|
||||||
|
|
||||||
|
|
||||||
class GitCommit(SourceItem):
|
class GitCommit(SourceItem):
|
||||||
__tablename__ = "git_commit"
|
__tablename__ = "git_commit"
|
||||||
|
|
||||||
|
|||||||
@ -10,12 +10,14 @@ from sqlalchemy import (
|
|||||||
Boolean,
|
Boolean,
|
||||||
Column,
|
Column,
|
||||||
DateTime,
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
Index,
|
Index,
|
||||||
Integer,
|
Integer,
|
||||||
Text,
|
Text,
|
||||||
func,
|
func,
|
||||||
)
|
)
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
from memory.common.db.models.base import Base
|
from memory.common.db.models.base import Base
|
||||||
|
|
||||||
@ -123,3 +125,74 @@ class EmailAccount(Base):
|
|||||||
Index("email_accounts_active_idx", "active", "last_sync_at"),
|
Index("email_accounts_active_idx", "active", "last_sync_at"),
|
||||||
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
|
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageProcessor:
    """Mixin contributing message-handling settings columns to a mapped class.

    Mixed into ``DiscordServer``, ``DiscordChannel`` and ``DiscordUser``.

    NOTE(review): the ``add_discord_models`` migration creates
    ``track_messages``, ``allowed_tools`` and ``disallowed_tools`` as nullable
    (the arrays without a server default), and gives ``ignore_messages`` a
    ``"false"`` server default - the declarations below differ; confirm which
    side is intended.
    """

    # Boolean switches; ``track_messages`` defaults in the database,
    # ``ignore_messages`` only has a Python-side default.
    track_messages = Column(
        Boolean,
        server_default="true",
        nullable=False,
    )
    ignore_messages = Column(
        Boolean,
        default=False,
        nullable=True,
    )

    # Tool name lists, defaulting to an empty Postgres array.
    allowed_tools = Column(
        ARRAY(Text),
        server_default="{}",
        nullable=False,
    )
    disallowed_tools = Column(
        ARRAY(Text),
        server_default="{}",
        nullable=False,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordServer(Base, MessageProcessor):
    """Discord server (guild) metadata plus the settings inherited from
    ``MessageProcessor``."""

    __tablename__ = "discord_servers"

    # Identity and descriptive metadata
    id = Column(BigInteger, primary_key=True)  # Discord guild snowflake ID
    name = Column(Text, nullable=False)
    description = Column(Text)
    member_count = Column(Integer)

    # Sync bookkeeping and row timestamps
    last_sync_at = Column(DateTime(timezone=True))
    created_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    updated_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
    )

    # Channels belong to their server: removing a channel from this
    # collection, or deleting the server, deletes the channel too.
    channels = relationship(
        "DiscordChannel",
        cascade="all, delete-orphan",
        back_populates="server",
    )

    __table_args__ = (
        Index(
            "discord_servers_active_idx",
            "track_messages",
            "last_sync_at",
        ),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordChannel(Base, MessageProcessor):
    """Discord channel metadata plus the settings inherited from
    ``MessageProcessor``."""

    __tablename__ = "discord_channels"

    id = Column(BigInteger, primary_key=True)  # Discord channel snowflake ID
    # Nullable: a channel does not have to belong to a server.
    server_id = Column(
        BigInteger,
        ForeignKey("discord_servers.id"),
        nullable=True,
    )
    name = Column(Text, nullable=False)
    channel_type = Column(Text, nullable=False)  # "text", "voice", "dm", "group_dm"

    # Row timestamps, filled in by the database on insert
    created_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    updated_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
    )

    # Inverse side of ``DiscordServer.channels``
    server = relationship("DiscordServer", back_populates="channels")

    __table_args__ = (
        Index("discord_channels_server_idx", "server_id"),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordUser(Base, MessageProcessor):
    """Discord user metadata plus the settings inherited from
    ``MessageProcessor``."""

    __tablename__ = "discord_users"

    id = Column(BigInteger, primary_key=True)  # Discord user snowflake ID
    username = Column(Text, nullable=False)
    display_name = Column(Text)

    # Link to system user if registered
    system_user_id = Column(
        Integer,
        ForeignKey("users.id"),
        nullable=True,
    )

    # Row timestamps, filled in by the database on insert
    created_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    updated_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
    )

    # Inverse side of ``User.discord_users``
    system_user = relationship("User", back_populates="discord_users")

    __table_args__ = (
        Index("discord_users_system_user_idx", "system_user_id"),
    )
|
||||||
|
|||||||
@ -50,6 +50,7 @@ class User(Base):
|
|||||||
oauth_states = relationship(
|
oauth_states = relationship(
|
||||||
"OAuthState", back_populates="user", cascade="all, delete-orphan"
|
"OAuthState", back_populates="user", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
|
discord_users = relationship("DiscordUser", back_populates="system_user")
|
||||||
|
|
||||||
def serialize(self) -> dict:
|
def serialize(self) -> dict:
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -1,221 +1,101 @@
|
|||||||
|
"""
|
||||||
|
Discord integration.
|
||||||
|
|
||||||
|
Simple HTTP client that communicates with the Discord collector's API server.
|
||||||
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import requests
|
import requests
|
||||||
import re
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ERROR_CHANNEL = "memory-errors"
|
|
||||||
ACTIVITY_CHANNEL = "memory-activity"
|
def get_api_url() -> str:
|
||||||
DISCOVERY_CHANNEL = "memory-discoveries"
|
"""Get the Discord API server URL"""
|
||||||
CHAT_CHANNEL = "memory-chat"
|
host = settings.DISCORD_COLLECTOR_SERVER_URL
|
||||||
|
port = settings.DISCORD_COLLECTOR_PORT
|
||||||
|
return f"http://{host}:{port}"
|
||||||
|
|
||||||
|
|
||||||
class DiscordServer(requests.Session):
|
def send_dm(user_identifier: str, message: str) -> bool:
|
||||||
def __init__(self, server_id: str, server_name: str, *args, **kwargs):
|
"""Send a DM via the Discord collector API"""
|
||||||
self.server_id = server_id
|
|
||||||
self.server_name = server_name
|
|
||||||
self.channels = {}
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.setup_channels()
|
|
||||||
self.members = self.fetch_all_members()
|
|
||||||
|
|
||||||
def setup_channels(self):
|
|
||||||
resp = self.get(self.channels_url)
|
|
||||||
resp.raise_for_status()
|
|
||||||
channels = {channel["name"]: channel["id"] for channel in resp.json()}
|
|
||||||
|
|
||||||
if not (error_channel := channels.get(settings.DISCORD_ERROR_CHANNEL)):
|
|
||||||
error_channel = self.create_channel(settings.DISCORD_ERROR_CHANNEL)
|
|
||||||
self.channels[ERROR_CHANNEL] = error_channel
|
|
||||||
|
|
||||||
if not (activity_channel := channels.get(settings.DISCORD_ACTIVITY_CHANNEL)):
|
|
||||||
activity_channel = self.create_channel(settings.DISCORD_ACTIVITY_CHANNEL)
|
|
||||||
self.channels[ACTIVITY_CHANNEL] = activity_channel
|
|
||||||
|
|
||||||
if not (discovery_channel := channels.get(settings.DISCORD_DISCOVERY_CHANNEL)):
|
|
||||||
discovery_channel = self.create_channel(settings.DISCORD_DISCOVERY_CHANNEL)
|
|
||||||
self.channels[DISCOVERY_CHANNEL] = discovery_channel
|
|
||||||
|
|
||||||
if not (chat_channel := channels.get(settings.DISCORD_CHAT_CHANNEL)):
|
|
||||||
chat_channel = self.create_channel(settings.DISCORD_CHAT_CHANNEL)
|
|
||||||
self.channels[CHAT_CHANNEL] = chat_channel
|
|
||||||
|
|
||||||
@property
|
|
||||||
def error_channel(self) -> str:
|
|
||||||
return self.channels[ERROR_CHANNEL]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def activity_channel(self) -> str:
|
|
||||||
return self.channels[ACTIVITY_CHANNEL]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def discovery_channel(self) -> str:
|
|
||||||
return self.channels[DISCOVERY_CHANNEL]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def chat_channel(self) -> str:
|
|
||||||
return self.channels[CHAT_CHANNEL]
|
|
||||||
|
|
||||||
def channel_id(self, channel_name: str) -> str:
|
|
||||||
if not (channel_id := self.channels.get(channel_name)):
|
|
||||||
raise ValueError(f"Channel {channel_name} not found")
|
|
||||||
return channel_id
|
|
||||||
|
|
||||||
def send_message(self, channel_id: str, content: str):
|
|
||||||
payload: dict[str, Any] = {"content": content}
|
|
||||||
mentions = re.findall(r"@(\S*)", content)
|
|
||||||
users = {u: i for u, i in self.members.items() if u in mentions}
|
|
||||||
if users:
|
|
||||||
for u, i in users.items():
|
|
||||||
payload["content"] = payload["content"].replace(f"@{u}", f"<@{i}>")
|
|
||||||
payload["allowed_mentions"] = {
|
|
||||||
"parse": [],
|
|
||||||
"users": list(users.values()),
|
|
||||||
}
|
|
||||||
|
|
||||||
return self.post(
|
|
||||||
f"https://discord.com/api/v10/channels/{channel_id}/messages",
|
|
||||||
json=payload,
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_channel(self, channel_name: str, channel_type: int = 0) -> str | None:
|
|
||||||
resp = self.post(
|
|
||||||
self.channels_url, json={"name": channel_name, "type": channel_type}
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
|
||||||
return resp.json()["id"]
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return (
|
|
||||||
f"DiscordServer(server_id={self.server_id}, server_name={self.server_name})"
|
|
||||||
)
|
|
||||||
|
|
||||||
def request(self, method: str, url: str, **kwargs):
|
|
||||||
headers = kwargs.get("headers", {})
|
|
||||||
headers["Authorization"] = f"Bot {settings.DISCORD_BOT_TOKEN}"
|
|
||||||
headers["Content-Type"] = "application/json"
|
|
||||||
kwargs["headers"] = headers
|
|
||||||
return super().request(method, url, **kwargs)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def channels_url(self) -> str:
|
|
||||||
return f"https://discord.com/api/v10/guilds/{self.server_id}/channels"
|
|
||||||
|
|
||||||
@property
def members_url(self) -> str:
    """REST endpoint for this guild's member collection."""
    guild = self.server_id
    return f"https://discord.com/api/v10/guilds/{guild}/members"
|
|
||||||
|
|
||||||
@property
def dm_create_url(self) -> str:
    """REST endpoint used to open a DM channel; it does not depend on the guild."""
    endpoint = "https://discord.com/api/v10/users/@me/channels"
    return endpoint
|
|
||||||
|
|
||||||
def list_members(
|
|
||||||
self, limit: int = 1000, after: str | None = None
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""List up to `limit` members in this guild, starting after a user ID.
|
|
||||||
|
|
||||||
Requires the bot to have the Server Members Intent enabled in the Discord developer portal.
|
|
||||||
"""
|
|
||||||
params: dict[str, Any] = {"limit": limit}
|
|
||||||
if after:
|
|
||||||
params["after"] = after
|
|
||||||
resp = self.get(self.members_url, params=params)
|
|
||||||
resp.raise_for_status()
|
|
||||||
return resp.json()
|
|
||||||
|
|
||||||
def fetch_all_members(self, page_size: int = 1000) -> dict[str, str]:
    """Retrieve all guild members by paginating `list_members`.

    Returns a mapping of display name (`global_name`, falling back to
    `username`) to user ID. Entries without an ID or without any name are
    skipped rather than stored under an empty key.

    Pagination stops when a page is empty, shorter than `page_size`, or
    fails to advance the cursor. The last case guards against restarting
    from the first page forever when the final entry of a page has no ID.

    Note: large guilds take multiple requests, and this method does no
    rate-limit handling of its own.
    """
    members: dict[str, str] = {}
    after: str | None = None
    while batch := self.list_members(limit=page_size, after=after):
        last_id: str | None = None
        for member in batch:
            user = member.get("user") or {}
            user_id = user.get("id")
            if not user_id:
                continue
            last_id = user_id
            name = user.get("global_name") or user.get("username")
            if name:
                members[name] = user_id

        # A short page is the final page; an unchanged cursor would loop.
        if not last_id or last_id == after or len(batch) < page_size:
            break
        after = last_id
    return members
|
|
||||||
|
|
||||||
def create_dm_channel(self, user_id: str) -> str:
    """Open (or fetch) the DM channel with `user_id` and return its channel ID.

    The bot must share a guild with the user, and the user's privacy settings must allow DMs from server members.
    """
    response = self.post(self.dm_create_url, json={"recipient_id": user_id})
    response.raise_for_status()
    return response.json()["id"]
|
|
||||||
|
|
||||||
def send_dm(self, user_id: str, content: str):
    """Send a direct message to a user.

    `user_id` is first looked up in `self.members`; when that yields a
    truthy value it is used as the recipient, otherwise `user_id` itself is.
    Returns the result of the message POST.
    """
    recipient = self.members.get(user_id) or user_id
    dm_channel = self.create_dm_channel(recipient)
    message_url = f"https://discord.com/api/v10/channels/{dm_channel}/messages"
    return self.post(message_url, json={"content": content})
|
|
||||||
|
|
||||||
|
|
||||||
def get_bot_servers() -> list[dict[str, Any]]:
|
|
||||||
"""Get list of servers the bot is in."""
|
|
||||||
if not settings.DISCORD_BOT_TOKEN:
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
headers = {"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}"}
|
response = requests.post(
|
||||||
response = requests.get(
|
f"{get_api_url()}/send_dm",
|
||||||
"https://discord.com/api/v10/users/@me/guilds", headers=headers
|
json={"user_identifier": user_identifier, "message": message},
|
||||||
|
timeout=10,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
return result.get("success", False)
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.error(f"Failed to send DM to {user_identifier}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def broadcast_message(channel_name: str, message: str) -> bool:
|
||||||
|
"""Send a message to a channel via the Discord collector API"""
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
f"{get_api_url()}/send_channel",
|
||||||
|
json={"channel_name": channel_name, "message": message},
|
||||||
|
timeout=10,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
return result.get("success", False)
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.error(f"Failed to send message to channel {channel_name}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_collector_healthy() -> bool:
|
||||||
|
"""Check if the Discord collector is running and healthy"""
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{get_api_url()}/health", timeout=5)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
return result.get("status") == "healthy"
|
||||||
|
|
||||||
|
except requests.RequestException:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_discord_metadata() -> dict[str, int] | None:
|
||||||
|
"""Refresh Discord server/channel/user metadata from Discord API"""
|
||||||
|
try:
|
||||||
|
response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30)
|
||||||
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
except Exception as e:
|
except requests.RequestException as e:
|
||||||
logger.error(f"Failed to get bot servers: {e}")
|
logger.error(f"Failed to refresh Discord metadata: {e}")
|
||||||
return []
|
return None
|
||||||
|
|
||||||
|
|
||||||
servers: dict[str, DiscordServer] = {}
|
# Convenience functions
|
||||||
|
def send_error_message(message: str) -> bool:
|
||||||
|
"""Send an error message to the error channel"""
|
||||||
|
return broadcast_message(settings.DISCORD_ERROR_CHANNEL, message)
|
||||||
|
|
||||||
|
|
||||||
def load_servers():
|
def send_activity_message(message: str) -> bool:
|
||||||
for server in get_bot_servers():
|
"""Send an activity message to the activity channel"""
|
||||||
servers[server["id"]] = DiscordServer(server["id"], server["name"])
|
return broadcast_message(settings.DISCORD_ACTIVITY_CHANNEL, message)
|
||||||
|
|
||||||
|
|
||||||
def broadcast_message(channel: str, message: str):
|
def send_discovery_message(message: str) -> bool:
|
||||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
"""Send a discovery message to the discovery channel"""
|
||||||
return
|
return broadcast_message(settings.DISCORD_DISCOVERY_CHANNEL, message)
|
||||||
|
|
||||||
for server in servers.values():
|
|
||||||
server.send_message(server.channel_id(channel), message)
|
|
||||||
|
|
||||||
|
|
||||||
def send_error_message(message: str):
|
def send_chat_message(message: str) -> bool:
|
||||||
broadcast_message(ERROR_CHANNEL, message)
|
"""Send a chat message to the chat channel"""
|
||||||
|
return broadcast_message(settings.DISCORD_CHAT_CHANNEL, message)
|
||||||
|
|
||||||
def send_activity_message(message: str):
|
|
||||||
broadcast_message(ACTIVITY_CHANNEL, message)
|
|
||||||
|
|
||||||
|
|
||||||
def send_discovery_message(message: str):
|
|
||||||
broadcast_message(DISCOVERY_CHANNEL, message)
|
|
||||||
|
|
||||||
|
|
||||||
def send_chat_message(message: str):
|
|
||||||
broadcast_message(CHAT_CHANNEL, message)
|
|
||||||
|
|
||||||
|
|
||||||
def send_dm(user_id: str, message: str):
|
|
||||||
for server in servers.values():
|
|
||||||
if not server.members.get(user_id) and user_id not in server.members.values():
|
|
||||||
continue
|
|
||||||
|
|
||||||
server.send_dm(user_id, message)
|
|
||||||
|
|
||||||
|
|
||||||
def notify_task_failure(
|
def notify_task_failure(
|
||||||
@ -234,9 +114,6 @@ def notify_task_failure(
|
|||||||
task_args: Task arguments
|
task_args: Task arguments
|
||||||
task_kwargs: Task keyword arguments
|
task_kwargs: Task keyword arguments
|
||||||
traceback_str: Full traceback string
|
traceback_str: Full traceback string
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if notification sent successfully
|
|
||||||
"""
|
"""
|
||||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||||
logger.debug("Discord notifications disabled")
|
logger.debug("Discord notifications disabled")
|
||||||
|
|||||||
@ -1,122 +0,0 @@
|
|||||||
import logging
|
|
||||||
import base64
|
|
||||||
import io
|
|
||||||
from typing import Any
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from memory.common import settings, tokens
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
SYSTEM_PROMPT = """
|
|
||||||
You are a helpful assistant that creates concise summaries and identifies key topics.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def encode_image(image: Image.Image) -> str:
    """Return the image as a base64-encoded JPEG string.

    Images whose mode is not "RGB" are converted to RGB before saving.
    """
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    buffer = io.BytesIO()
    rgb_image.save(buffer, format="JPEG")
    jpeg_bytes = buffer.getvalue()
    return base64.b64encode(jpeg_bytes).decode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def call_openai(
    prompt: str,
    model: str,
    images: list[Image.Image] | None = None,
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Call the OpenAI chat completions API and return the reply text.

    Args:
        prompt: User prompt text.
        model: Model name, normally "<provider>/<model>"; the segment after
            the first "/" is sent to the API. A name without "/" is sent
            as-is instead of raising IndexError.
        images: Optional images attached as base64 JPEG data URLs.
        system_prompt: System message content.

    Returns:
        The first choice's message content, or "" when it is empty.

    Raises:
        Exception: any error from the OpenAI client is logged and re-raised.
    """
    import openai

    client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
    parts = model.split("/")
    model_name = parts[1] if len(parts) > 1 else model
    try:
        user_content: Any = [{"type": "text", "text": prompt}]
        for image in images or []:
            encoded_image = encode_image(image)
            user_content.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
                }
            )

        response = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
            ],
            temperature=0.3,
            max_tokens=2048,
        )
        return response.choices[0].message.content or ""
    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        raise
|
|
||||||
|
|
||||||
|
|
||||||
def call_anthropic(
    prompt: str,
    model: str,
    images: list[Image.Image] | None = None,
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Call the Anthropic messages API and return the reply text.

    Args:
        prompt: User prompt text.
        model: Model name, normally "<provider>/<model>"; the segment after
            the first "/" is sent to the API. A name without "/" is sent
            as-is instead of raising IndexError.
        images: Optional images attached as base64 JPEG blocks.
        system_prompt: System prompt passed via `system`.

    Returns:
        The concatenated text of all text blocks in the response, so a
        non-text first block or an empty content list no longer fails.

    Raises:
        Exception: any error from the Anthropic client is logged and re-raised.
    """
    import anthropic

    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
    parts = model.split("/")
    model_name = parts[1] if len(parts) > 1 else model
    try:
        content: Any = [{"type": "text", "text": prompt}]
        for image in images or []:
            encoded_image = encode_image(image)
            content.append(
                {  # type: ignore
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": encoded_image,
                    },
                }
            )

        response = client.messages.create(
            model=model_name,
            messages=[{"role": "user", "content": content}],  # type: ignore
            system=system_prompt,
            temperature=0.3,
            max_tokens=2048,
        )
        return "".join(
            block.text
            for block in response.content
            if getattr(block, "type", None) == "text"
        )
    except Exception as e:
        logger.error(f"Anthropic API error: {e}")
        raise
|
|
||||||
|
|
||||||
|
|
||||||
def call(
    prompt: str,
    model: str,
    images: list[Image.Image] | None = None,
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Dispatch to the Anthropic or OpenAI helper based on the model prefix.

    Models whose name starts with "anthropic" go to `call_anthropic`;
    everything else goes to `call_openai`. `images` defaults to None rather
    than a shared mutable list; an empty list is passed on in that case.
    """
    attachments = images if images is not None else []
    if model.startswith("anthropic"):
        return call_anthropic(prompt, model, attachments, system_prompt)
    return call_openai(prompt, model, attachments, system_prompt)
|
|
||||||
|
|
||||||
|
|
||||||
def truncate(content: str, target_tokens: int) -> str:
    """Shorten `content` to roughly `target_tokens` tokens.

    The character budget is `target_tokens * tokens.CHARS_PER_TOKEN`. Text
    over budget is cut there, trimmed back to the last space inside the
    cut, and suffixed with "...". Shorter text is returned unchanged.
    """
    limit = target_tokens * tokens.CHARS_PER_TOKEN
    if len(content) <= limit:
        return content
    clipped = content[:limit]
    head = clipped.rsplit(" ", 1)[0]
    return head + "..."
|
|
||||||
79
src/memory/common/llms/__init__.py
Normal file
79
src/memory/common/llms/__init__.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
"""LLM provider module for unified LLM access."""
|
||||||
|
|
||||||
|
# Legacy imports for backwards compatibility
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
# New provider system
|
||||||
|
from memory.common.llms.base import (
|
||||||
|
BaseLLMProvider,
|
||||||
|
ImageContent,
|
||||||
|
LLMSettings,
|
||||||
|
Message,
|
||||||
|
MessageContent,
|
||||||
|
MessageRole,
|
||||||
|
StreamEvent,
|
||||||
|
TextContent,
|
||||||
|
ThinkingContent,
|
||||||
|
ToolDefinition,
|
||||||
|
ToolResultContent,
|
||||||
|
ToolUseContent,
|
||||||
|
create_provider,
|
||||||
|
)
|
||||||
|
from memory.common import tokens
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseLLMProvider",
|
||||||
|
"Message",
|
||||||
|
"MessageRole",
|
||||||
|
"MessageContent",
|
||||||
|
"TextContent",
|
||||||
|
"ImageContent",
|
||||||
|
"ToolUseContent",
|
||||||
|
"ToolResultContent",
|
||||||
|
"ThinkingContent",
|
||||||
|
"ToolDefinition",
|
||||||
|
"StreamEvent",
|
||||||
|
"LLMSettings",
|
||||||
|
"create_provider",
|
||||||
|
]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def summarize(
    prompt: str,
    model: str,
    images: list[Image.Image] | None = None,
    system_prompt: str = "",
) -> str:
    """Run a single summarization request through the provider for `model`.

    Args:
        prompt: User prompt text.
        model: Model identifier passed to `create_provider`.
        images: Optional images appended to the user message.
        system_prompt: System prompt; a default summarizer prompt is used
            when empty.

    Returns:
        The provider's response text, or "" when it is empty.

    Raises:
        Exception: any provider error is logged (with the model name, since
            this is not Anthropic-specific) and re-raised.
    """
    provider = create_provider(model=model)
    try:
        content: list[MessageContent] = [TextContent(text=prompt)]
        content.extend(ImageContent(image=image) for image in images or [])

        messages = [Message(role=MessageRole.USER, content=content)]
        settings_obj = LLMSettings(temperature=0.3, max_tokens=2048)

        res = provider.run_with_tools(
            messages=messages,
            system_prompt=system_prompt
            or "You are a helpful assistant that creates concise summaries and identifies key topics.",
            settings=settings_obj,
            tools={},
        )
        return res.response or ""
    except Exception as e:
        logger.error(f"LLM API error ({model}): {e}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
def truncate(content: str, target_tokens: int) -> str:
    """Shorten `content` to roughly `target_tokens` tokens.

    The character budget is `target_tokens * tokens.CHARS_PER_TOKEN`. Text
    over budget is cut there, trimmed back to the last space inside the
    cut, and suffixed with "...". Shorter text is returned unchanged.
    """
    budget = target_tokens * tokens.CHARS_PER_TOKEN
    if len(content) <= budget:
        return content
    prefix = content[:budget]
    kept = prefix.rsplit(" ", 1)[0]
    return kept + "..."
|
||||||
451
src/memory/common/llms/anthropic_provider.py
Normal file
451
src/memory/common/llms/anthropic_provider.py
Normal file
@ -0,0 +1,451 @@
|
|||||||
|
"""Anthropic LLM provider implementation."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, AsyncIterator, Iterator, Optional
|
||||||
|
|
||||||
|
import anthropic
|
||||||
|
|
||||||
|
from memory.common.llms.base import (
|
||||||
|
BaseLLMProvider,
|
||||||
|
ImageContent,
|
||||||
|
LLMSettings,
|
||||||
|
Message,
|
||||||
|
MessageRole,
|
||||||
|
StreamEvent,
|
||||||
|
ToolDefinition,
|
||||||
|
ToolUseContent,
|
||||||
|
ThinkingContent,
|
||||||
|
TextContent,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicProvider(BaseLLMProvider):
|
||||||
|
"""Anthropic LLM provider with streaming, tool support, and extended thinking."""
|
||||||
|
|
||||||
|
# Models that support extended thinking
|
||||||
|
# Substrings matched against the lower-cased model name. The dated API IDs
# look like "claude-3-7-sonnet-20250219" and "claude-sonnet-4-20250514",
# which the original "claude-sonnet-3-7" / "claude-sonnet-4-0" entries never
# matched; those are kept only so existing alias-style names still match.
THINKING_MODELS = {
    "claude-opus-4",
    "claude-opus-4-1",
    "claude-sonnet-4",
    "claude-sonnet-4-0",
    "claude-sonnet-4-5",
    "claude-3-7-sonnet",
    "claude-sonnet-3-7",
}
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, model: str, enable_thinking: bool = False):
    """Set up the provider.

    Args:
        api_key: Anthropic API key
        model: Model identifier
        enable_thinking: Enable extended thinking for supported models
    """
    super().__init__(api_key, model)
    # The async client is created on first use by the `async_client` property.
    self._async_client: Optional[anthropic.AsyncAnthropic] = None
    self.enable_thinking = enable_thinking
|
||||||
|
|
||||||
|
def _initialize_client(self) -> anthropic.Anthropic:
    """Build the synchronous Anthropic client from `self.api_key`."""
    sync_client = anthropic.Anthropic(api_key=self.api_key)
    return sync_client
|
||||||
|
|
||||||
|
@property
def async_client(self) -> anthropic.AsyncAnthropic:
    """Return the async Anthropic client, creating and caching it on first access."""
    cached = self._async_client
    if cached is None:
        cached = anthropic.AsyncAnthropic(api_key=self.api_key)
        self._async_client = cached
    return cached
|
||||||
|
|
||||||
|
def _convert_image_content(self, content: ImageContent) -> dict[str, Any]:
    """Build an Anthropic base64 image block from an ImageContent.

    The image is encoded with `self.encode_image` and labelled as JPEG.
    """
    source = {
        "type": "base64",
        "media_type": "image/jpeg",
        "data": self.encode_image(content.image),
    }
    return {"type": "image", "source": source}
|
||||||
|
|
||||||
|
def _convert_message(self, message: Message) -> dict[str, Any]:
    """Convert a Message to a dict, moving thinking blocks to the front.

    Only assistant messages with list content are reordered. The sort is
    stable, so the relative order within the thinking blocks and within
    the other blocks is kept.
    """
    payload = message.to_dict()
    if payload["role"] != MessageRole.ASSISTANT:
        return payload
    blocks = payload["content"]
    if not isinstance(blocks, list):
        return payload
    thinking_first = sorted(blocks, key=lambda block: block["type"] != "thinking")
    return payload | {"content": thinking_first}
|
||||||
|
|
||||||
|
def _should_include_message(self, message: Message) -> bool:
    """Return False for system-role messages and True for every other role."""
    role = message.role
    return role != MessageRole.SYSTEM
|
||||||
|
|
||||||
|
def _supports_thinking(self) -> bool:
    """Return True when any THINKING_MODELS entry is a substring of the lower-cased model name."""
    name = self.model.lower()
    for candidate in self.THINKING_MODELS:
        if candidate in name:
            return True
    return False
|
||||||
|
|
||||||
|
def _build_request_kwargs(
    self,
    messages: list[Message],
    system_prompt: str | None,
    tools: list[ToolDefinition] | None,
    settings: LLMSettings,
) -> dict[str, Any]:
    """Assemble the keyword arguments shared by all Anthropic API calls.

    Always sets model, messages, temperature and max_tokens. Adds `top_p`
    (when not None), `system`, `stop_sequences` and `tools` only when
    provided. When thinking is enabled and supported, adds a thinking
    budget of min(10000, max_tokens - 1024) if that is at least 1024, and
    in that case forces temperature to 1.0 and drops `top_p`.
    """
    converted = self._convert_messages(messages)
    request: dict[str, Any] = {
        "model": self.model,
        "messages": converted,
        "temperature": settings.temperature,
        "max_tokens": settings.max_tokens,
    }

    # Optional fields are only sent when the caller actually set them.
    if settings.top_p is not None:
        request["top_p"] = settings.top_p
    if system_prompt:
        request["system"] = system_prompt
    if settings.stop_sequences:
        request["stop_sequences"] = settings.stop_sequences
    if tools:
        request["tools"] = self._convert_tools(tools)

    if not (self.enable_thinking and self._supports_thinking()):
        return request

    budget = min(10000, settings.max_tokens - 1024)
    if budget < 1024:
        return request

    request["thinking"] = {"type": "enabled", "budget_tokens": budget}
    # When thinking is enabled: temperature must be 1, can't use top_p
    request["temperature"] = 1.0
    request.pop("top_p", None)
    return request
|
||||||
|
|
||||||
|
def _handle_stream_event(
    self, event: Any, current_tool_use: dict[str, Any] | None
) -> tuple[StreamEvent | None, dict[str, Any] | None]:
    """Translate one Anthropic stream event into a StreamEvent.

    Tool calls arrive as a `content_block_start`, a series of
    `input_json_delta` fragments, and a `content_block_stop`. The partially
    built call is threaded through `current_tool_use` between invocations.

    A tool block that starts with an empty or missing `input` (a plain
    tool_use start carries `{}`) is tracked as an empty string, so the JSON
    fragments that follow are accumulated. Previously the empty dict was
    kept, every fragment was dropped, and the call was emitted with `{}`
    unless the stop event carried a parsed block.

    Args:
        event: Raw event object from the Anthropic stream.
        current_tool_use: State of the tool call being assembled, or None.

    Returns:
        Tuple of (StreamEvent or None, updated current_tool_use or None).
    """
    event_type = getattr(event, "type", None)

    if event_type == "error":
        error = getattr(event, "error", None)
        error_msg = str(error) if error else "Unknown error"
        return StreamEvent(type="error", data=error_msg), current_tool_use

    if event_type == "content_block_start":
        block = getattr(event, "content_block", None)
        if not block:
            return None, current_tool_use

        block_type = getattr(block, "type", None)

        # Handle various tool types (tool_use, mcp_tool_use, server_tool_use)
        if block_type in ("tool_use", "mcp_tool_use", "server_tool_use"):
            block_input = getattr(block, "input", None)
            current_tool_use = {
                "id": getattr(block, "id", ""),
                "name": getattr(block, "name", ""),
                # Keep a pre-filled input as-is; otherwise start an empty
                # string buffer for the input_json_delta fragments.
                "input": block_input if block_input else "",
                "server_name": getattr(block, "server_name", None),
                "is_server_call": block_type != "tool_use",
            }

        # Handle tool result blocks
        elif hasattr(block, "tool_use_id"):
            tool_result = {
                "id": getattr(block, "tool_use_id", ""),
                "result": getattr(block, "content", ""),
            }
            return StreamEvent(type="tool_result", data=tool_result), current_tool_use

        # For non-tool blocks (text, thinking), we don't need to track state
        return None, current_tool_use

    if event_type == "content_block_delta":
        delta = getattr(event, "delta", None)
        if not delta:
            return None, current_tool_use

        delta_type = getattr(delta, "type", None)

        if delta_type == "text_delta":
            text = getattr(delta, "text", "")
            return StreamEvent(type="text", data=text), current_tool_use

        if delta_type == "thinking_delta":
            thinking = getattr(delta, "thinking", "")
            return StreamEvent(type="thinking", data=thinking), current_tool_use

        if delta_type == "signature_delta":
            # Signature that accompanies extended-thinking output
            signature = getattr(delta, "signature", "")
            return StreamEvent(type="thinking", signature=signature), current_tool_use

        if delta_type == "input_json_delta":
            if current_tool_use is None:
                # Edge case: received input_json_delta without tool_use start
                logger.warning("Received input_json_delta without tool_use context")
                return None, None

            # Only accumulate while input is still a string buffer; a
            # pre-filled (non-empty) input from content_block_start is final.
            if isinstance(current_tool_use.get("input"), str):
                current_tool_use["input"] += getattr(delta, "partial_json", "")

        return None, current_tool_use

    if event_type == "content_block_stop" and current_tool_use:
        # Prefer the parsed input on the stop event's content block when
        # the event provides one.
        content_block = getattr(event, "content_block", None)
        if content_block and hasattr(content_block, "input"):
            current_tool_use["input"] = content_block.input
        else:
            # Fallback: parse the accumulated JSON string
            input_str = current_tool_use.get("input", "")
            if isinstance(input_str, str):
                if not input_str or input_str.isspace():
                    current_tool_use["input"] = {}
                else:
                    try:
                        current_tool_use["input"] = json.loads(input_str)
                    except json.JSONDecodeError as e:
                        logger.warning(
                            f"Failed to parse tool input '{input_str}': {e}"
                        )
                        current_tool_use["input"] = {}

        tool_data = {
            "id": current_tool_use.get("id", ""),
            "name": current_tool_use.get("name", ""),
            "input": current_tool_use.get("input", {}),
        }
        # Include server info if present
        if current_tool_use.get("server_name"):
            tool_data["server_name"] = current_tool_use["server_name"]
        if current_tool_use.get("is_server_call"):
            tool_data["is_server_call"] = current_tool_use["is_server_call"]

        return StreamEvent(type="tool_use", data=tool_data), None

    if event_type == "message_delta":
        delta = getattr(event, "delta", None)
        if delta and getattr(delta, "stop_reason", None) == "max_tokens":
            return StreamEvent(type="error", data="Max tokens reached"), current_tool_use

        # Token usage is only logged, not emitted as an event
        usage = getattr(event, "usage", None)
        if usage:
            usage_data = {
                "input_tokens": getattr(usage, "input_tokens", 0),
                "output_tokens": getattr(usage, "output_tokens", 0),
                "cache_creation_input_tokens": getattr(
                    usage, "cache_creation_input_tokens", None
                ),
                "cache_read_input_tokens": getattr(
                    usage, "cache_read_input_tokens", None
                ),
            }
            logger.debug(f"Token usage: {usage_data}")

        return None, current_tool_use

    if event_type == "message_stop":
        # Final event - clean up any pending state
        if current_tool_use:
            logger.warning(
                f"Message ended with incomplete tool use: {current_tool_use}"
            )
        return StreamEvent(type="done"), None

    # Unknown event type - log but don't fail
    if event_type and event_type not in (
        "message_start",
        "message_delta",
        "content_block_start",
        "content_block_delta",
        "content_block_stop",
        "message_stop",
    ):
        logger.debug(f"Unknown event type: {event_type}")

    return None, current_tool_use
|
||||||
|
|
||||||
|
def generate(
    self,
    messages: list[Message],
    system_prompt: str | None = None,
    tools: list[ToolDefinition] | None = None,
    settings: LLMSettings | None = None,
) -> str:
    """Make a blocking, non-streaming request and return the reply text.

    Returns the concatenation of every "text" block in the response.
    Any exception is logged and re-raised.
    """
    effective = settings or LLMSettings()
    request = self._build_request_kwargs(messages, system_prompt, tools, effective)

    try:
        reply = self.client.messages.create(**request)
        text_parts = [block.text for block in reply.content if block.type == "text"]
        return "".join(text_parts)
    except Exception as e:
        logger.error(f"Anthropic API error: {e}")
        raise
|
||||||
|
|
||||||
|
def stream(
    self,
    messages: list[Message],
    system_prompt: str | None = None,
    tools: list[ToolDefinition] | None = None,
    settings: LLMSettings | None = None,
) -> Iterator[StreamEvent]:
    """Stream a response, yielding a StreamEvent for each translated event.

    Raw events are passed through `_handle_stream_event`; falsy results are
    skipped. Exceptions are logged and surfaced as a final "error" event
    instead of being raised.
    """
    effective = settings or LLMSettings()
    request = self._build_request_kwargs(messages, system_prompt, tools, effective)

    try:
        with self.client.messages.stream(**request) as raw_events:
            # State of the tool call currently being assembled, if any.
            pending_tool: dict[str, Any] | None = None
            for raw in raw_events:
                translated, pending_tool = self._handle_stream_event(raw, pending_tool)
                if translated:
                    yield translated
    except Exception as e:
        logger.error(f"Anthropic streaming error: {e}")
        yield StreamEvent(type="error", data=str(e))
|
||||||
|
|
||||||
|
async def agenerate(
    self,
    messages: list[Message],
    system_prompt: str | None = None,
    tools: list[ToolDefinition] | None = None,
    settings: LLMSettings | None = None,
) -> str:
    """Async counterpart of `generate`, using `self.async_client`.

    Returns the concatenation of every "text" block in the response.
    Any exception is logged and re-raised.
    """
    effective = settings or LLMSettings()
    request = self._build_request_kwargs(messages, system_prompt, tools, effective)

    try:
        reply = await self.async_client.messages.create(**request)
        text_parts = [block.text for block in reply.content if block.type == "text"]
        return "".join(text_parts)
    except Exception as e:
        logger.error(f"Anthropic API error: {e}")
        raise
|
||||||
|
|
||||||
|
async def astream(
    self,
    messages: list[Message],
    system_prompt: str | None = None,
    tools: list[ToolDefinition] | None = None,
    settings: LLMSettings | None = None,
) -> AsyncIterator[StreamEvent]:
    """Async counterpart of `stream`, using `self.async_client`.

    Raw events are passed through `_handle_stream_event`; falsy results are
    skipped. Exceptions are logged and surfaced as a final "error" event
    instead of being raised.
    """
    effective = settings or LLMSettings()
    request = self._build_request_kwargs(messages, system_prompt, tools, effective)

    try:
        async with self.async_client.messages.stream(**request) as raw_events:
            # State of the tool call currently being assembled, if any.
            pending_tool: dict[str, Any] | None = None
            async for raw in raw_events:
                translated, pending_tool = self._handle_stream_event(raw, pending_tool)
                if translated:
                    yield translated
    except Exception as e:
        logger.error(f"Anthropic streaming error: {e}")
        yield StreamEvent(type="error", data=str(e))
|
||||||
|
|
||||||
|
def stream_with_tools(
    self,
    messages: list[Message],
    tools: dict[str, ToolDefinition],
    settings: LLMSettings | None = None,
    system_prompt: str | None = None,
    max_iterations: int = 10,
) -> Iterator[StreamEvent]:
    """Stream a response, executing requested tools and continuing the chat.

    Text and thinking deltas are accumulated and re-yielded. On each
    "tool_use" event the tool is run via `self.execute_tool`, the assistant
    turn and the tool result are appended to `messages` (the caller's list
    is mutated), and the method recurses with one fewer iteration. An
    "error" event is logged and raised as RuntimeError. Yields nothing when
    `max_iterations` is not positive.
    """
    if max_iterations <= 0:
        return

    response = TextContent(text="")
    thinking = ThinkingContent(thinking="", signature="")

    events = self.stream(
        messages=messages,
        system_prompt=system_prompt,
        tools=list(tools.values()),
        settings=settings,
    )
    for event in events:
        kind = event.type

        if kind == "text":
            response.text += event.data
            yield event
            continue

        if kind == "thinking":
            if event.signature:
                # Signature events only update state; they are not re-yielded.
                thinking.signature = event.signature
            else:
                thinking.thinking += event.data
                yield event
            continue

        if kind == "tool_use":
            yield event
            tool_result = self.execute_tool(event.data, tools)
            yield StreamEvent(type="tool_result", data=tool_result.to_dict())

            requested = ToolUseContent(
                id=event.data["id"],
                name=event.data["name"],
                input=event.data["input"],
            )
            messages.append(Message.assistant(response, thinking, requested))
            messages.append(Message.user(tool_result=tool_result))
            yield from self.stream_with_tools(
                messages, tools, settings, system_prompt, max_iterations - 1
            )
            continue

        if kind == "tool_result":
            yield event
            continue

        if kind == "error":
            logger.error(f"LLM error: {event.data}")
            raise RuntimeError(f"LLM error: {event.data}")
|
||||||
561
src/memory/common/llms/base.py
Normal file
561
src/memory/common/llms/base.py
Normal file
@ -0,0 +1,561 @@
|
|||||||
|
"""Base classes and types for LLM providers."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, AsyncIterator, Iterator, Literal, Optional, Union
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageRole(str, Enum):
    """Message roles for chat history.

    Mixes in ``str`` so each member is also a plain string equal to its value.
    """

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TextContent:
    """A plain-text part of a message."""

    type: Literal["text"] = "text"
    text: str = ""

    def to_dict(self) -> dict[str, Any]:
        """Return this part as a ``{"type": "text", "text": ...}`` mapping."""
        payload: dict[str, Any] = {"type": "text"}
        payload["text"] = self.text
        return payload
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ImageContent:
    """An image part of a message."""

    type: Literal["image"] = "image"
    image: Image.Image = None  # type: ignore
    detail: Optional[str] = None  # For OpenAI: "low", "high", "auto"

    def to_dict(self) -> dict[str, Any]:
        """Return this part as a mapping holding the raw image object.

        The image is not encoded here; providers encode it themselves.
        ``detail`` is not part of the returned mapping.
        """
        payload: dict[str, Any] = {"type": "image"}
        payload["image"] = self.image
        return payload
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ToolUseContent:
    """Tool use request from the assistant."""

    type: Literal["tool_use"] = "tool_use"
    id: str = ""
    name: str = ""
    # Dataclasses reject a mutable ``{}`` default, so the field is declared
    # with None and __post_init__ replaces it with a fresh dict per instance.
    input: dict[str, Any] = None  # type: ignore

    def __post_init__(self) -> None:
        """Replace a missing ``input`` with an empty dict.

        Without this a default-constructed instance would serialise
        ``"input": None`` even though the field is declared as a dict.
        """
        if self.input is None:
            self.input = {}

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary format."""
        return {
            "type": "tool_use",
            "id": self.id,
            "name": self.name,
            "input": self.input,
        }
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ToolResultContent:
    """The outcome of running a tool, linked to the request by its id."""

    type: Literal["tool_result"] = "tool_result"
    tool_use_id: str = ""
    content: str = ""
    is_error: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Return the result as a plain mapping of its four fields."""
        payload: dict[str, Any] = {"type": "tool_result"}
        payload["tool_use_id"] = self.tool_use_id
        payload["content"] = self.content
        payload["is_error"] = self.is_error
        return payload
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ThinkingContent:
    """Reasoning text produced by the assistant (extended thinking)."""

    type: Literal["thinking"] = "thinking"
    thinking: str = ""
    signature: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return the thinking text and its signature as a plain mapping."""
        payload: dict[str, Any] = {"type": "thinking"}
        payload["thinking"] = self.thinking
        payload["signature"] = self.signature
        return payload
|
||||||
|
|
||||||
|
|
||||||
|
# Any single part that may appear in a Message's content list.
MessageContent = Union[
    TextContent, ImageContent, ToolUseContent, ToolResultContent, ThinkingContent
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Turn:
    """A turn in the conversation.

    This is what ``BaseLLMProvider.run_with_tools`` returns; each field is
    None when nothing of that kind was produced.
    """

    # Concatenated text chunks.
    response: str | None
    # Concatenated thinking chunks.
    thinking: str | None
    # Tool calls made during the turn, keyed by tool-use id.
    tool_calls: dict[str, ToolResult] | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Message:
    """A message in the conversation history."""

    role: MessageRole
    # Either plain text or an ordered list of typed content parts.
    content: Union[str, list[MessageContent]]

    def to_dict(self) -> dict[str, Any]:
        """Convert message to dictionary format.

        String content is passed through unchanged; a list of parts is
        converted by calling ``to_dict`` on each part.
        """
        if isinstance(self.content, str):
            return {"role": self.role.value, "content": self.content}
        content_list = [item.to_dict() for item in self.content]
        return {"role": self.role.value, "content": content_list}

    @staticmethod
    def assistant(
        text: TextContent | None = None,
        thinking: ThinkingContent | None = None,
        tool_use: ToolUseContent | None = None,
    ) -> "Message":
        """Build an assistant message from the parts the model produced.

        Parts that carry no content are left out: dataclass instances are
        always truthy, so the fields are inspected explicitly instead of
        testing the objects themselves. Parts are ordered thinking, text,
        tool use; Anthropic rejects tool-use turns whose assistant message
        does not start with the thinking block.

        Args:
            text: Optional text part; skipped when its text is empty
            thinking: Optional thinking part; skipped when it has neither
                thinking text nor a signature
            tool_use: Optional tool use request

        Returns:
            A Message with role ASSISTANT and the non-empty parts
        """
        parts: list[MessageContent] = []
        if thinking is not None and (thinking.thinking or thinking.signature):
            parts.append(thinking)
        if text is not None and text.text:
            parts.append(text)
        if tool_use is not None:
            parts.append(tool_use)
        return Message(role=MessageRole.ASSISTANT, content=parts)

    @staticmethod
    def user(
        text: str | None = None, tool_result: ToolResultContent | None = None
    ) -> "Message":
        """Build a user message from optional text and/or a tool result.

        Args:
            text: Optional text; skipped when empty
            tool_result: Optional tool result to return to the model

        Returns:
            A Message with role USER and the supplied parts
        """
        parts: list[MessageContent] = []
        if text:
            parts.append(TextContent(text=text))
        if tool_result is not None:
            parts.append(tool_result)
        return Message(role=MessageRole.USER, content=parts)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class StreamEvent:
    """An event from the streaming response."""

    # Kind of event; determines how ``data`` should be interpreted.
    type: Literal["text", "tool_use", "tool_result", "thinking", "error", "done"]
    # Payload: a text chunk for "text"/"thinking", a dict for
    # "tool_use"/"tool_result", an error message string for "error".
    data: Any = None
    # Set on "thinking" events that carry a signature.
    signature: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LLMSettings:
    """Settings for LLM API calls."""

    # Sampling temperature passed to the provider.
    temperature: float = 0.7
    # Upper bound on generated tokens.
    max_tokens: int = 2048
    # Don't set by default - some models don't allow both temp and top_p
    top_p: float | None = None
    # Optional stop sequences.
    stop_sequences: list[str] | None = None
    # NOTE(review): the providers visible here select streaming by which
    # method is called rather than by this flag - confirm whether it is used.
    stream: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class BaseLLMProvider(ABC):
    """Base class for LLM providers.

    Subclasses must supply the provider client (``_initialize_client``) and
    the four generation entry points (``generate``, ``stream``, ``agenerate``,
    ``astream``). The conversion hooks and the tool loop
    (``stream_with_tools`` / ``run_with_tools``) have overridable defaults.
    """

    def __init__(self, api_key: str, model: str):
        """
        Initialize the LLM provider.

        Args:
            api_key: API key for the provider
            model: Model identifier
        """
        self.api_key = api_key
        self.model = model
        # Created on first access through the ``client`` property.
        self._client: Any = None

    @abstractmethod
    def _initialize_client(self) -> Any:
        """Initialize the provider-specific client."""
        pass

    @property
    def client(self) -> Any:
        """Lazy-load the client."""
        if self._client is None:
            self._client = self._initialize_client()
        return self._client

    def execute_tool(
        self,
        tool_call: ToolCall,
        tool_handlers: dict[str, ToolDefinition],
    ) -> ToolResultContent:
        """
        Execute a tool call.

        Failures never propagate: a missing name, an unknown tool, or an
        exception raised by the tool is reported as a result with
        ``is_error=True``.

        Args:
            tool_call: Tool call
            tool_handlers: Dict mapping tool names to handler functions

        Returns:
            ToolResultContent with result or error
        """
        name = tool_call.get("name")
        tool_use_id = tool_call.get("id")
        # Named ``tool_input`` so the builtin ``input`` is not shadowed.
        tool_input = tool_call.get("input")

        if not name:
            return ToolResultContent(
                tool_use_id=tool_use_id,
                content="Tool name missing",
                is_error=True,
            )

        if not (tool := tool_handlers.get(name)):
            return ToolResultContent(
                tool_use_id=tool_use_id,
                content=f"Tool '{name}' not found",
                is_error=True,
            )

        try:
            return ToolResultContent(
                tool_use_id=tool_use_id,
                content=tool(tool_input),
                is_error=False,
            )
        except Exception as e:
            # Deliberately broad: the error is reported back to the model
            # instead of aborting the conversation.
            logger.error(f"Tool '{name}' failed: {e}", exc_info=True)
            return ToolResultContent(
                tool_use_id=tool_use_id,
                content=str(e),
                is_error=True,
            )

    def encode_image(self, image: Image.Image) -> str:
        """
        Encode PIL Image to base64 string.

        The image is saved as JPEG, so non-RGB modes are converted first.

        Args:
            image: PIL Image to encode

        Returns:
            Base64 encoded string
        """
        buffer = io.BytesIO()
        # Convert to RGB if necessary (for RGBA, etc.)
        if image.mode != "RGB":
            image = image.convert("RGB")
        image.save(buffer, format="JPEG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def _convert_text_content(self, content: TextContent) -> dict[str, Any]:
        """Convert TextContent to provider format. Override for custom format."""
        return content.to_dict()

    def _convert_image_content(self, content: ImageContent) -> dict[str, Any]:
        """Convert ImageContent to provider format. Override for custom format."""
        return content.to_dict()

    def _convert_tool_use_content(self, content: ToolUseContent) -> dict[str, Any]:
        """Convert ToolUseContent to provider format. Override for custom format."""
        return content.to_dict()

    def _convert_tool_result_content(
        self, content: ToolResultContent
    ) -> dict[str, Any]:
        """Convert ToolResultContent to provider format. Override for custom format."""
        return content.to_dict()

    def _convert_thinking_content(self, content: ThinkingContent) -> dict[str, Any]:
        """Convert ThinkingContent to provider format. Override for custom format."""
        return content.to_dict()

    def _convert_message_content(self, content: MessageContent) -> dict[str, Any]:
        """
        Convert a MessageContent item to provider format.

        Dispatches to type-specific converters that can be overridden.

        Raises:
            ValueError: If the content is not one of the known content types
        """
        if isinstance(content, TextContent):
            return self._convert_text_content(content)
        elif isinstance(content, ImageContent):
            return self._convert_image_content(content)
        elif isinstance(content, ToolUseContent):
            return self._convert_tool_use_content(content)
        elif isinstance(content, ToolResultContent):
            return self._convert_tool_result_content(content)
        elif isinstance(content, ThinkingContent):
            return self._convert_thinking_content(content)
        else:
            raise ValueError(f"Unknown content type: {type(content)}")

    def _convert_message(self, message: Message) -> dict[str, Any]:
        """
        Convert a Message to provider format.

        Can be overridden for provider-specific handling (e.g., filtering system messages).
        """
        return message.to_dict()

    def _should_include_message(self, message: Message) -> bool:
        """
        Determine if a message should be included in the request.

        Override to filter messages (e.g., Anthropic filters SYSTEM messages).

        Args:
            message: Message to check

        Returns:
            True if message should be included
        """
        return True

    def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
        """
        Convert a list of messages to provider format.

        Uses _should_include_message for filtering and _convert_message for conversion.
        """
        return [
            self._convert_message(msg)
            for msg in messages
            if self._should_include_message(msg)
        ]

    def _convert_tool(self, tool: ToolDefinition) -> dict[str, Any]:
        """
        Convert a single ToolDefinition to provider format.

        Default format matches Anthropic. Override for other providers (e.g., OpenAI uses functions).
        """
        return {
            "name": tool.name,
            "description": tool.description,
            "input_schema": tool.input_schema,
        }

    def _convert_tools(
        self, tools: list[ToolDefinition] | None
    ) -> Optional[list[dict[str, Any]]]:
        """Convert tool definitions to provider format.

        Returns None (rather than an empty list) when there are no tools.
        """
        if not tools:
            return None
        return [self._convert_tool(tool) for tool in tools]

    @abstractmethod
    def generate(
        self,
        messages: list[Message],
        system_prompt: str | None = None,
        tools: list[ToolDefinition] | None = None,
        settings: LLMSettings | None = None,
    ) -> str:
        """
        Generate a non-streaming response.

        Args:
            messages: Conversation history
            system_prompt: Optional system prompt
            tools: Optional list of tools the LLM can use
            settings: Optional settings for the generation

        Returns:
            Generated text response
        """
        pass

    @abstractmethod
    def stream(
        self,
        messages: list[Message],
        system_prompt: str | None = None,
        tools: list[ToolDefinition] | None = None,
        settings: LLMSettings | None = None,
    ) -> Iterator[StreamEvent]:
        """
        Generate a streaming response.

        Args:
            messages: Conversation history
            system_prompt: Optional system prompt
            tools: Optional list of tools the LLM can use
            settings: Optional settings for the generation

        Yields:
            StreamEvent objects containing text chunks, tool uses, or errors
        """
        pass

    @abstractmethod
    async def agenerate(
        self,
        messages: list[Message],
        system_prompt: str | None = None,
        tools: list[ToolDefinition] | None = None,
        settings: LLMSettings | None = None,
    ) -> str:
        """
        Generate a non-streaming response asynchronously.

        Args:
            messages: Conversation history
            system_prompt: Optional system prompt
            tools: Optional list of tools the LLM can use
            settings: Optional settings for the generation

        Returns:
            Generated text response
        """
        pass

    @abstractmethod
    async def astream(
        self,
        messages: list[Message],
        system_prompt: str | None = None,
        tools: list[ToolDefinition] | None = None,
        settings: LLMSettings | None = None,
    ) -> AsyncIterator[StreamEvent]:
        """
        Generate a streaming response asynchronously.

        Args:
            messages: Conversation history
            system_prompt: Optional system prompt
            tools: Optional list of tools the LLM can use
            settings: Optional settings for the generation

        Yields:
            StreamEvent objects containing text chunks, tool uses, or errors
        """
        pass

    def stream_with_tools(
        self,
        messages: list[Message],
        tools: dict[str, ToolDefinition],
        settings: LLMSettings | None = None,
        system_prompt: str | None = None,
        max_iterations: int = 10,
    ) -> Iterator[StreamEvent]:
        """
        Stream a response, executing requested tools and feeding results back.

        This is a concrete default built only on ``stream`` and
        ``execute_tool``, so providers that do not need special handling can
        be instantiated without overriding it. Providers may still override.

        Text and thinking chunks are forwarded as they arrive. On a tool
        request the tool is run, the request and its result are appended to
        ``messages`` (mutated in place), and a follow-up stream is started
        with one fewer iteration available.

        Args:
            messages: Conversation history; extended with tool exchanges
            tools: Available tools, keyed by tool name
            settings: Optional settings for the generation
            system_prompt: Optional system prompt
            max_iterations: Remaining model calls; nothing is yielded at zero

        Yields:
            StreamEvent objects for text, thinking, tool_use and tool_result

        Raises:
            RuntimeError: If the underlying stream reports an error event
        """
        if max_iterations <= 0:
            return

        response = TextContent(text="")
        thinking = ThinkingContent(thinking="", signature="")

        for event in self.stream(
            messages=messages,
            system_prompt=system_prompt,
            tools=list(tools.values()),
            settings=settings,
        ):
            if event.type == "text":
                response.text += event.data
                yield event
            elif event.type == "thinking" and event.signature:
                # Signature-bearing events are recorded but not forwarded.
                thinking.signature = event.signature
            elif event.type == "thinking":
                thinking.thinking += event.data
                yield event
            elif event.type == "tool_use":
                yield event
                tool_result = self.execute_tool(event.data, tools)
                yield StreamEvent(type="tool_result", data=tool_result.to_dict())
                messages.append(
                    Message.assistant(
                        response,
                        thinking,
                        ToolUseContent(
                            id=event.data["id"],
                            name=event.data["name"],
                            input=event.data["input"],
                        ),
                    )
                )
                messages.append(Message.user(tool_result=tool_result))
                # Let the model continue now that it can see the result.
                yield from self.stream_with_tools(
                    messages, tools, settings, system_prompt, max_iterations - 1
                )
            elif event.type == "tool_result":
                yield event
            elif event.type == "error":
                logger.error(f"LLM error: {event.data}")
                raise RuntimeError(f"LLM error: {event.data}")

    def run_with_tools(
        self,
        messages: list[Message],
        tools: dict[str, ToolDefinition],
        settings: LLMSettings | None = None,
        system_prompt: str | None = None,
        max_iterations: int = 10,
    ) -> Turn:
        """
        Run the tool loop to completion and collect everything into a Turn.

        Consumes ``stream_with_tools`` and accumulates text, thinking and the
        tool calls (keyed by tool-use id) together with their outputs.

        Args:
            messages: Conversation history; extended with tool exchanges
            tools: Available tools, keyed by tool name
            settings: Optional settings for the generation
            system_prompt: Optional system prompt
            max_iterations: Maximum number of model calls

        Returns:
            A Turn whose fields are None when nothing of that kind was produced
        """
        thinking, response, tool_calls = "", "", {}
        for event in self.stream_with_tools(
            messages=messages,
            tools=tools,
            settings=settings,
            system_prompt=system_prompt,
            max_iterations=max_iterations,
        ):
            if event.type == "thinking":
                thinking += event.data
            elif event.type == "tool_use":
                tool_calls[event.data["id"]] = {
                    "name": event.data["name"],
                    "input": event.data["input"],
                    "output": "",
                }
            elif event.type == "text":
                response += event.data
            elif event.type == "tool_result":
                # Results carry only the id, so name/input fall back to what
                # the matching tool_use event recorded.
                current = tool_calls.get(event.data["tool_use_id"]) or {}
                tool_calls[event.data["tool_use_id"]] = {
                    "name": event.data.get("name") or current.get("name"),
                    "input": event.data.get("input") or current.get("input"),
                    "output": event.data.get("content"),
                }
        return Turn(
            thinking=thinking or None,
            response=response or None,
            tool_calls=tool_calls or None,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def create_provider(
    model: str | None = None,
    api_key: str | None = None,
    enable_thinking: bool = False,
) -> BaseLLMProvider:
    """
    Create an LLM provider based on the model name.

    Args:
        model: Model identifier in the form "<provider>/<model>"
            (e.g., "anthropic/claude-3-opus-20240229").
            If not provided, uses SUMMARIZER_MODEL from settings.
        api_key: Optional API key. If not provided, uses keys from settings.
        enable_thinking: Enable extended thinking for supported models (Claude Opus 4+, Sonnet 4+, Sonnet 3.7)

    Returns:
        An initialized LLM provider

    Raises:
        ValueError: If the identifier has no "<provider>/" prefix, the
            provider is not supported, or no API key is available
    """
    # Use default model from settings if not provided
    if model is None:
        model = settings.SUMMARIZER_MODEL

    # partition() never raises, so a missing prefix can be reported clearly
    # instead of surfacing as a tuple-unpacking error.
    provider, separator, model_name = model.partition("/")
    if not (separator and provider and model_name):
        raise ValueError(
            f"Invalid model identifier: {model!r}. "
            "Expected the form '<provider>/<model>', "
            "e.g. 'anthropic/claude-3-opus-20240229'."
        )

    if provider == "anthropic":
        if api_key is None:
            api_key = settings.ANTHROPIC_API_KEY

        if not api_key:
            raise ValueError(
                "ANTHROPIC_API_KEY not found in settings. "
                "Please set it in your .env file."
            )

        # Imported lazily, as in the original, so the provider module is
        # only loaded when an Anthropic model is actually requested.
        from memory.common.llms.anthropic_provider import AnthropicProvider

        return AnthropicProvider(
            api_key=api_key, model=model_name, enable_thinking=enable_thinking
        )

    # Could add OpenAI support here in the future

    raise ValueError(
        f"Unknown provider {provider!r} for model: {model}. "
        "Supported providers: anthropic (anthropic/<model>)"
    )
|
||||||
388
src/memory/common/llms/openai_provider.py
Normal file
388
src/memory/common/llms/openai_provider.py
Normal file
@ -0,0 +1,388 @@
|
|||||||
|
"""OpenAI LLM provider implementation."""
|
||||||
|
|
||||||
|
import json
import logging
from typing import Any, AsyncIterator, Iterator, Optional
|
||||||
|
|
||||||
|
import openai
|
||||||
|
|
||||||
|
from memory.common.llms.base import (
|
||||||
|
BaseLLMProvider,
|
||||||
|
ImageContent,
|
||||||
|
LLMSettings,
|
||||||
|
Message,
|
||||||
|
MessageContent,
|
||||||
|
MessageRole,
|
||||||
|
StreamEvent,
|
||||||
|
TextContent,
|
||||||
|
ThinkingContent,
|
||||||
|
ToolDefinition,
|
||||||
|
ToolResultContent,
|
||||||
|
ToolUseContent,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIProvider(BaseLLMProvider):
    """OpenAI LLM provider with streaming and tool support."""

    def _initialize_client(self) -> openai.OpenAI:
        """Initialize the OpenAI client."""
        return openai.OpenAI(api_key=self.api_key)

    def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
        """
        Convert our Message format to OpenAI format.

        Tool results become separate "tool" role messages and tool use
        requests become the ``tool_calls`` field of their message. A message
        left with neither content nor tool calls (for example one holding
        only tool results) is not emitted, because a role-only message is
        not a valid request entry.

        Args:
            messages: List of messages in our format

        Returns:
            List of messages in OpenAI format
        """
        openai_messages: list[dict[str, Any]] = []

        for msg in messages:
            if isinstance(msg.content, str):
                openai_messages.append({"role": msg.role.value, "content": msg.content})
                continue

            # Handle multi-part content
            content_parts: list[dict[str, Any]] = []
            tool_calls: list[dict[str, Any]] = []
            for item in msg.content:
                if isinstance(item, TextContent):
                    content_parts.append({"type": "text", "text": item.text})
                elif isinstance(item, ImageContent):
                    encoded_image = self.encode_image(item.image)
                    image_part: dict[str, Any] = {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{encoded_image}"
                        },
                    }
                    if item.detail:
                        image_part["image_url"]["detail"] = item.detail
                    content_parts.append(image_part)
                elif isinstance(item, ToolUseContent):
                    # OpenAI expects the arguments as a JSON string; str()
                    # would produce a Python repr that is not valid JSON.
                    arguments = item.input if item.input is not None else {}
                    tool_calls.append(
                        {
                            "id": item.id,
                            "type": "function",
                            "function": {
                                "name": item.name,
                                "arguments": json.dumps(arguments),
                            },
                        }
                    )
                elif isinstance(item, ToolResultContent):
                    # OpenAI handles tool results as separate "tool" role messages
                    openai_messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": item.tool_use_id,
                            "content": item.content,
                        }
                    )
                elif isinstance(item, ThinkingContent):
                    # OpenAI doesn't have native thinking support in most models,
                    # so it is added as marked text. Empty thinking is skipped.
                    if item.thinking:
                        content_parts.append(
                            {
                                "type": "text",
                                "text": f"[Thinking: {item.thinking}]",
                            }
                        )

            if not content_parts and not tool_calls:
                continue

            message_dict: dict[str, Any] = {"role": msg.role.value}
            if content_parts:
                message_dict["content"] = content_parts
            if tool_calls:
                message_dict["tool_calls"] = tool_calls
            openai_messages.append(message_dict)

        return openai_messages

    def _convert_tools(
        self, tools: Optional[list[ToolDefinition]]
    ) -> Optional[list[dict[str, Any]]]:
        """
        Convert our tool definitions to OpenAI format.

        Args:
            tools: List of tool definitions

        Returns:
            List of tools in OpenAI format, or None when there are no tools
        """
        if not tools:
            return None

        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.input_schema,
                },
            }
            for tool in tools
        ]

    def _build_request(
        self,
        messages: list[Message],
        system_prompt: Optional[str],
        tools: Optional[list[ToolDefinition]],
        settings: Optional[LLMSettings],
        stream: bool = False,
    ) -> dict[str, Any]:
        """Build the keyword arguments for a chat completion request.

        Shared by the sync/async and streaming/non-streaming entry points so
        they cannot drift apart. ``top_p`` is only sent when it was set.
        """
        settings = settings or LLMSettings()
        openai_messages = self._convert_messages(messages)

        # Add system prompt as first message if provided
        if system_prompt:
            openai_messages.insert(0, {"role": "system", "content": system_prompt})

        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": openai_messages,
            "temperature": settings.temperature,
            "max_tokens": settings.max_tokens,
        }
        if settings.top_p is not None:
            kwargs["top_p"] = settings.top_p
        if stream:
            kwargs["stream"] = True
        if settings.stop_sequences:
            kwargs["stop"] = settings.stop_sequences
        if tools:
            kwargs["tools"] = self._convert_tools(tools)
            kwargs["tool_choice"] = "auto"
        return kwargs

    @staticmethod
    def _tool_use_event(call: dict[str, Any]) -> StreamEvent:
        """Turn an accumulated tool call into a ``tool_use`` event.

        The streamed JSON argument text is parsed into ``input``, the key
        that ``execute_tool`` and ``run_with_tools`` read. The raw
        ``arguments`` string is kept alongside it. If the text is not valid
        JSON the raw string is passed through as the input.
        """
        raw = call.get("arguments") or ""
        if raw.strip():
            try:
                parsed: Any = json.loads(raw)
            except json.JSONDecodeError:
                logger.warning(f"Tool call arguments are not valid JSON: {raw!r}")
                parsed = raw
        else:
            parsed = {}
        return StreamEvent(type="tool_use", data={**call, "input": parsed})

    def _process_chunk(
        self, chunk: Any, pending: Optional[dict[str, Any]]
    ) -> tuple[list[StreamEvent], Optional[dict[str, Any]]]:
        """Translate one streamed chunk into events.

        Args:
            chunk: A chunk from the OpenAI streaming response
            pending: The tool call currently being accumulated, if any

        Returns:
            The events to emit for this chunk and the updated pending call
        """
        events: list[StreamEvent] = []
        if not chunk.choices:
            return events, pending

        choice = chunk.choices[0]
        delta = choice.delta

        # Handle text content
        if delta.content:
            events.append(StreamEvent(type="text", data=delta.content))

        # Handle tool calls
        for tool_call in delta.tool_calls or []:
            if tool_call.id:
                # New tool call starting; flush the previous one first.
                if pending:
                    events.append(self._tool_use_event(pending))
                pending = {
                    "id": tool_call.id,
                    "name": tool_call.function.name or "",
                    "arguments": tool_call.function.arguments or "",
                }
            elif pending and tool_call.function.arguments:
                # Continue building the current tool call
                pending["arguments"] += tool_call.function.arguments

        # Check if stream is finished
        if choice.finish_reason and pending:
            events.append(self._tool_use_event(pending))
            pending = None

        return events, pending

    def generate(
        self,
        messages: list[Message],
        system_prompt: Optional[str] = None,
        tools: Optional[list[ToolDefinition]] = None,
        settings: Optional[LLMSettings] = None,
    ) -> str:
        """Generate a non-streaming response.

        Returns the text of the first choice, or "" when it has no text.
        API errors are logged and re-raised.
        """
        kwargs = self._build_request(messages, system_prompt, tools, settings)
        try:
            response = self.client.chat.completions.create(**kwargs)
            return response.choices[0].message.content or ""
        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
            raise

    def stream(
        self,
        messages: list[Message],
        system_prompt: Optional[str] = None,
        tools: Optional[list[ToolDefinition]] = None,
        settings: Optional[LLMSettings] = None,
    ) -> Iterator[StreamEvent]:
        """Generate a streaming response.

        Yields text and tool_use events followed by a final "done" event.
        API errors are logged and reported as an "error" event rather than
        raised.
        """
        kwargs = self._build_request(
            messages, system_prompt, tools, settings, stream=True
        )
        try:
            pending: Optional[dict[str, Any]] = None
            for chunk in self.client.chat.completions.create(**kwargs):
                events, pending = self._process_chunk(chunk, pending)
                yield from events
            yield StreamEvent(type="done")
        except Exception as e:
            logger.error(f"OpenAI streaming error: {e}")
            yield StreamEvent(type="error", data=str(e))

    async def agenerate(
        self,
        messages: list[Message],
        system_prompt: Optional[str] = None,
        tools: Optional[list[ToolDefinition]] = None,
        settings: Optional[LLMSettings] = None,
    ) -> str:
        """Generate a non-streaming response asynchronously.

        Returns the text of the first choice, or "" when it has no text.
        API errors are logged and re-raised.
        """
        # Use async client
        async_client = openai.AsyncOpenAI(api_key=self.api_key)
        kwargs = self._build_request(messages, system_prompt, tools, settings)
        try:
            response = await async_client.chat.completions.create(**kwargs)
            return response.choices[0].message.content or ""
        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
            raise

    async def astream(
        self,
        messages: list[Message],
        system_prompt: Optional[str] = None,
        tools: Optional[list[ToolDefinition]] = None,
        settings: Optional[LLMSettings] = None,
    ) -> AsyncIterator[StreamEvent]:
        """Generate a streaming response asynchronously.

        Yields text and tool_use events followed by a final "done" event.
        API errors are logged and reported as an "error" event rather than
        raised.
        """
        # Use async client
        async_client = openai.AsyncOpenAI(api_key=self.api_key)
        kwargs = self._build_request(
            messages, system_prompt, tools, settings, stream=True
        )
        try:
            stream = await async_client.chat.completions.create(**kwargs)
            pending: Optional[dict[str, Any]] = None
            async for chunk in stream:
                events, pending = self._process_chunk(chunk, pending)
                for event in events:
                    yield event
            yield StreamEvent(type="done")
        except Exception as e:
            logger.error(f"OpenAI streaming error: {e}")
            yield StreamEvent(type="error", data=str(e))
|
||||||
36
src/memory/common/llms/tools/__init__.py
Normal file
36
src/memory/common/llms/tools/__init__.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Callable, TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
# Payload handed to a tool: raw text, a mapping of arguments, or nothing.
ToolInput = str | dict[str, Any] | None
# A tool implementation: takes one ToolInput and returns a string result.
ToolHandler = Callable[[ToolInput], str]
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCall(TypedDict):
    """A call to a tool."""

    name: str  # name of the tool being invoked
    id: str  # identifier of this particular call
    input: ToolInput  # argument payload for the call (text, mapping or None)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolResult(TypedDict):
    """A result from a tool call."""

    id: str  # identifier of the call this result belongs to
    name: str  # name of the tool that was invoked
    input: ToolInput  # the payload the tool was called with
    output: str  # textual output of the tool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ToolDefinition:
    """Definition of a tool that can be called by the LLM.

    Instances are callable: calling one forwards the input to ``function``.
    """

    name: str
    description: str
    input_schema: dict[str, Any]  # JSON Schema for the tool's parameters
    function: ToolHandler  # handler invoked by __call__

    def __call__(self, input: ToolInput) -> str:
        """Run the tool's handler on ``input`` and return its string result."""
        return self.function(input)
|
||||||
42
src/memory/common/llms/tools/ping.py
Normal file
42
src/memory/common/llms/tools/ping.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
"""Ping tool for testing LLM tool integration."""
|
||||||
|
|
||||||
|
from memory.common.llms.tools import ToolDefinition, ToolInput
|
||||||
|
|
||||||
|
|
||||||
|
def handle_ping_call(message: ToolInput = None) -> str:
    """
    Handle a ping tool call.

    Args:
        message: Optional message to include in the response. May be the
            message text itself, or a dict shaped like the ping tool's input
            schema (``{"message": "..."}``), in which case the ``message``
            entry is used.

    Returns:
        ``"pong: <message>"`` when a non-empty message was supplied,
        otherwise ``"pong"``.
    """
    # The ping schema declares an object with an optional "message" property,
    # so schema-conforming calls arrive as a dict. Unwrap it so the reply
    # echoes the text rather than the dict's repr.
    if isinstance(message, dict):
        message = message.get("message")

    if message:
        return f"pong: {message}"
    return "pong"
|
||||||
|
|
||||||
|
|
||||||
|
def get_ping_tool() -> ToolDefinition:
    """
    Build the ping tool definition used to test tool calls.

    The tool accepts a single optional ``message`` string and is handled by
    ``handle_ping_call``; it exists only to verify that tool calling works.
    """
    message_property = {
        "type": "string",
        "description": "Optional message to echo back",
    }
    schema = {
        "type": "object",
        "properties": {"message": message_property},
        "required": [],
    }
    return ToolDefinition(
        name="ping",
        description="A simple test tool that returns 'pong'. Use this to verify tool calling is working.",
        input_schema=schema,
        function=handle_ping_call,
    )
|
||||||
9
src/memory/common/messages.py
Normal file
9
src/memory/common/messages.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
def process_message(
    msg: str,
    history: list[str],
    model: str | None = None,
    system_prompt: str | None = None,
    allowed_tools: list[str] | None = None,
    disallowed_tools: list[str] | None = None,
) -> str:
    """Placeholder message processor.

    None of the arguments are used yet; a fixed string is returned
    regardless of input.
    """
    placeholder_reply = "asd"
    return placeholder_reply
|
||||||
@ -172,3 +172,12 @@ DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat")
|
|||||||
DISCORD_NOTIFICATIONS_ENABLED = bool(
|
DISCORD_NOTIFICATIONS_ENABLED = bool(
|
||||||
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
|
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
|
||||||
)
|
)
|
||||||
|
DISCORD_PROCESS_MESSAGES = boolean_env("DISCORD_PROCESS_MESSAGES", True)
|
||||||
|
|
||||||
|
# Discord collector settings
|
||||||
|
DISCORD_COLLECTOR_ENABLED = boolean_env("DISCORD_COLLECTOR_ENABLED", True)
|
||||||
|
DISCORD_COLLECT_DMS = boolean_env("DISCORD_COLLECT_DMS", True)
|
||||||
|
DISCORD_COLLECT_BOTS = boolean_env("DISCORD_COLLECT_BOTS", True)
|
||||||
|
DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8000))
|
||||||
|
DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "127.0.0.1")
|
||||||
|
DISCORD_CONTEXT_WINDOW = int(os.getenv("DISCORD_CONTEXT_WINDOW", 10))
|
||||||
|
|||||||
@ -105,7 +105,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
|
|||||||
prompt = llms.truncate(prompt, MAX_TOKENS - 20)
|
prompt = llms.truncate(prompt, MAX_TOKENS - 20)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = llms.call(prompt, settings.SUMMARIZER_MODEL)
|
response = llms.summarize(prompt, settings.SUMMARIZER_MODEL)
|
||||||
result = parse_response(response)
|
result = parse_response(response)
|
||||||
|
|
||||||
summary = result.get("summary", "")
|
summary = result.get("summary", "")
|
||||||
|
|||||||
166
src/memory/discord/api.py
Normal file
166
src/memory/discord/api.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
"""
|
||||||
|
Discord API server.
|
||||||
|
|
||||||
|
FastAPI server that owns and manages a Discord collector instance,
|
||||||
|
providing HTTP endpoints for sending Discord messages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.discord.collector import MessageCollector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Request body for POST /send_dm.
class SendDMRequest(BaseModel):
    user: str  # Discord user ID or username
    message: str  # text to deliver as the DM
|
||||||
|
|
||||||
|
|
||||||
|
# Request body for POST /send_channel.
class SendChannelRequest(BaseModel):
    channel_name: str  # Channel name (e.g., "memory-errors")
    message: str  # text to post in the channel
|
||||||
|
|
||||||
|
|
||||||
|
# Application state
|
||||||
|
class AppState:
    """Holder for the collector instance and the task running it.

    Both attributes start as ``None`` and are filled in by ``lifespan``.
    """

    def __init__(self):
        # Background task driving the collector's connection.
        self.collector_task: asyncio.Task | None = None
        # The Discord client itself; None until the server has started it.
        self.collector: MessageCollector | None = None
|
||||||
|
|
||||||
|
|
||||||
|
app_state = AppState()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage Discord collector lifecycle.

    On startup, creates a ``MessageCollector`` and runs it in a background
    task. On shutdown, closes the collector and cancels the task.

    If ``DISCORD_BOT_TOKEN`` is not configured the application still starts,
    but without a collector; the endpoints then answer 503 because
    ``app_state.collector`` stays ``None``.
    """
    if not settings.DISCORD_BOT_TOKEN:
        logger.error("DISCORD_BOT_TOKEN not configured")
        # An async context manager must yield exactly once; returning before
        # the yield raises "generator didn't yield" and aborts startup.
        yield
        return

    # Create and start the collector
    app_state.collector = MessageCollector()
    app_state.collector_task = asyncio.create_task(
        app_state.collector.start(settings.DISCORD_BOT_TOKEN)
    )
    logger.info("Discord collector started")

    try:
        yield
    finally:
        # Cleanup runs even if the application exits with an error.
        if app_state.collector and not app_state.collector.is_closed():
            await app_state.collector.close()

        if app_state.collector_task:
            app_state.collector_task.cancel()
            try:
                await app_state.collector_task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                # The task may already have died (e.g. login failure);
                # awaiting it re-raises that error, which must not break
                # shutdown.
                logger.error(f"Discord collector task ended with error: {e}")

        logger.info("Discord collector stopped")
|
||||||
|
|
||||||
|
|
||||||
|
# FastAPI app with lifespan management
|
||||||
|
app = FastAPI(title="Discord Collector API", version="1.0.0", lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/send_dm")
async def send_dm_endpoint(request: SendDMRequest):
    """Send a DM via the collector's Discord client.

    Responses:
        200 with a success payload when the DM was sent.
        400 when the collector reports the DM could not be sent.
        500 when sending raised an unexpected exception.
        503 when no collector is running.
    """
    if not app_state.collector:
        raise HTTPException(status_code=503, detail="Discord collector not running")

    # Only the send itself is guarded, so the 400 raised below is not caught
    # by the generic handler and rewritten as a 500.
    try:
        success = await app_state.collector.send_dm(request.user, request.message)
    except Exception as e:
        logger.error(f"Failed to send DM: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not success:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to send DM to {request.user}",
        )

    return {
        "success": True,
        "message": f"DM sent to {request.user}",
        "user": request.user,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/send_channel")
async def send_channel_endpoint(request: SendChannelRequest):
    """Send a message to a channel via the collector's Discord client.

    Responses:
        200 with a success payload when the message was sent.
        400 when the collector reports the message could not be sent.
        500 when sending raised an unexpected exception.
        503 when no collector is running.
    """
    if not app_state.collector:
        raise HTTPException(status_code=503, detail="Discord collector not running")

    # Only the send itself is guarded, so the 400 raised below is not caught
    # by the generic handler and rewritten as a 500.
    try:
        success = await app_state.collector.send_to_channel(
            request.channel_name, request.message
        )
    except Exception as e:
        logger.error(f"Failed to send channel message: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not success:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to send message to channel {request.channel_name}",
        )

    return {
        "success": True,
        "message": f"Message sent to channel {request.channel_name}",
        "channel": request.channel_name,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
async def health_check():
    """Report the collector's connection state.

    Answers 503 when no collector exists; otherwise returns the connection
    flag, the bot user (as a string) and the number of guilds.
    """
    collector = app_state.collector
    if not collector:
        raise HTTPException(status_code=503, detail="Discord collector not running")

    connected = not collector.is_closed()
    bot_user = collector.user
    guilds = collector.guilds
    return {
        "status": "healthy",
        "connected": connected,
        "user": str(bot_user) if bot_user else None,
        "guilds": len(guilds) if guilds else 0,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/refresh_metadata")
async def refresh_metadata():
    """Trigger a server/channel/user metadata refresh on the collector.

    Answers 503 when no collector exists and 500 when the refresh raises;
    on success the collector's counters are merged into the response.
    """
    collector = app_state.collector
    if not collector:
        raise HTTPException(status_code=503, detail="Discord collector not running")

    try:
        counts = await collector.refresh_metadata()
        response = {"success": True, "message": "Metadata refreshed successfully"}
        response.update(counts)
        return response
    except Exception as e:
        logger.error(f"Failed to refresh metadata: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001):
    """Serve the FastAPI ``app`` with uvicorn on ``host``:``port`` (debug logging)."""
    server_options = {"host": host, "port": port, "log_level": "debug"}
    uvicorn.run(app, **server_options)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Standalone entry point: bind to the address and port from settings.
    host, port = (
        settings.DISCORD_COLLECTOR_SERVER_URL,
        settings.DISCORD_COLLECTOR_PORT,
    )
    run_discord_api_server(host, port)
|
||||||
410
src/memory/discord/collector.py
Normal file
410
src/memory/discord/collector.py
Normal file
@ -0,0 +1,410 @@
|
|||||||
|
"""
|
||||||
|
Discord message collector.
|
||||||
|
|
||||||
|
Core message collection functionality - stores Discord messages to database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
|
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models.sources import (
|
||||||
|
DiscordServer,
|
||||||
|
DiscordChannel,
|
||||||
|
DiscordUser,
|
||||||
|
)
|
||||||
|
from memory.workers.tasks.discord import add_discord_message, edit_discord_message
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Pure functions for Discord entity creation/updates
|
||||||
|
def create_or_update_server(
    session: Session | scoped_session, guild: discord.Guild | None
) -> DiscordServer | None:
    """Fetch the DiscordServer row for ``guild``, creating it if missing.

    An existing row has its name, description and member count refreshed and
    ``last_sync_at`` stamped with the current UTC time. A new row is added
    and flushed. Returns ``None`` when ``guild`` is falsy. The caller is
    responsible for committing.
    """
    if not guild:
        return None

    record = session.query(DiscordServer).get(guild.id)
    if record:
        # Known server: refresh the mutable metadata.
        record.name = guild.name
        record.description = guild.description
        record.member_count = guild.member_count
        record.last_sync_at = datetime.now(timezone.utc)
        return record

    record = DiscordServer(
        id=guild.id,
        name=guild.name,
        description=guild.description,
        member_count=guild.member_count,
    )
    session.add(record)
    session.flush()  # push the INSERT now, before the caller commits
    logger.info(f"Created server record for {guild.name} ({guild.id})")
    return record
|
||||||
|
|
||||||
|
|
||||||
|
def determine_channel_metadata(channel) -> tuple[str, int | None, str]:
    """Classify ``channel`` as a ``(channel_type, server_id, name)`` triple.

    DMs and group DMs have no server. Text/voice/thread channels use their
    lower-cased class name with "channel" removed as the type. Anything else
    is reported as "unknown" with whatever guild/name attributes it exposes.
    """
    if isinstance(channel, discord.DMChannel):
        recipient = channel.recipient
        label = f"DM with {recipient.name}" if recipient else "Unknown DM"
        return "dm", None, label

    if isinstance(channel, discord.GroupChannel):
        return "group_dm", None, channel.name or "Group DM"

    guild_channel_types = (discord.TextChannel, discord.VoiceChannel, discord.Thread)
    if isinstance(channel, guild_channel_types):
        kind = channel.__class__.__name__.lower().replace("channel", "")
        return kind, channel.guild.id, channel.name

    # Unrecognised channel class: fall back to duck typing.
    guild = getattr(channel, "guild", None)
    server_id = guild.id if guild else None
    name = getattr(channel, "name", f"Unknown-{channel.id}")
    return "unknown", server_id, name
|
||||||
|
|
||||||
|
|
||||||
|
def create_or_update_channel(
    session: Session | scoped_session, channel
) -> DiscordChannel | None:
    """Fetch the DiscordChannel row for ``channel``, creating it if missing.

    An existing row only has its name refreshed (when the channel exposes
    one). A new row is classified via ``determine_channel_metadata``, added
    and flushed. Returns ``None`` when ``channel`` is falsy. The caller is
    responsible for committing.
    """
    if not channel:
        return None

    record = session.query(DiscordChannel).get(channel.id)
    if record:
        if hasattr(channel, "name"):
            record.name = channel.name
        return record

    channel_type, server_id, name = determine_channel_metadata(channel)
    record = DiscordChannel(
        id=channel.id,
        server_id=server_id,
        name=name,
        channel_type=channel_type,
    )
    session.add(record)
    session.flush()
    logger.debug(f"Created channel: {name}")
    return record
|
||||||
|
|
||||||
|
|
||||||
|
def create_or_update_user(
    session: Session | scoped_session, user: discord.User | discord.Member | None
) -> DiscordUser | None:
    """Get or create the DiscordUser record for ``user`` (pure DB operation).

    An existing row has its username and display name refreshed; a new row
    is added and flushed. The caller is responsible for committing.

    Returns:
        The DiscordUser row, or ``None`` when ``user`` is falsy. The
        annotation now reflects that ``None`` path, matching the server and
        channel helpers.
    """
    if not user:
        return None

    discord_user = session.query(DiscordUser).get(user.id)

    if not discord_user:
        discord_user = DiscordUser(
            id=user.id,
            username=user.name,
            display_name=user.display_name,
        )
        session.add(discord_user)
        session.flush()
        logger.debug(f"Created user: {user.name}")
    else:
        # Update user info in case it changed
        discord_user.username = user.name
        discord_user.display_name = user.display_name

    return discord_user
|
||||||
|
|
||||||
|
|
||||||
|
def determine_message_metadata(
    message: discord.Message,
) -> tuple[str, int | None, int | None]:
    """Derive ``(message_type, reply_to_id, thread_id)`` for ``message``.

    A message referencing another message id is a "reply" to that id;
    everything else is "default". ``thread_id`` is the channel's id when the
    channel has a truthy ``parent`` attribute, otherwise ``None``.
    """
    reference = message.reference
    if reference and reference.message_id:
        message_type, reply_to_id = "reply", reference.message_id
    else:
        message_type, reply_to_id = "default", None

    in_thread = hasattr(message.channel, "parent") and message.channel.parent
    thread_id = message.channel.id if in_thread else None

    return message_type, reply_to_id, thread_id
|
||||||
|
|
||||||
|
|
||||||
|
def should_track_message(
    server: DiscordServer | None,
    channel: DiscordChannel,
    user: DiscordUser,
) -> bool:
    """Decide whether a message should be tracked.

    Tracking is refused when the server (if any) or the channel has
    ``track_messages`` switched off. In DMs and group DMs the user's own
    ``track_messages`` flag decides; everywhere else the message is tracked.
    """
    server_opted_out = server and not server.track_messages  # type: ignore
    if server_opted_out or not channel.track_messages:
        return False

    is_direct = channel.channel_type in ("dm", "group_dm")
    return bool(user.track_messages) if is_direct else True
|
||||||
|
|
||||||
|
|
||||||
|
def should_collect_bot_message(message: discord.Message) -> bool:
    """Return whether ``message`` passes the bot filter.

    Messages from non-bot authors always pass; messages from bots pass only
    when ``settings.DISCORD_COLLECT_BOTS`` is enabled.
    """
    if message.author.bot:
        return settings.DISCORD_COLLECT_BOTS
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def sync_guild_metadata(guild: discord.Guild) -> None:
    """Store ``guild`` and its text/voice channels, then commit.

    Runs in its own session: the server row is created or updated first,
    followed by one row per text or voice channel of the guild.
    """
    synced_channel_types = (discord.TextChannel, discord.VoiceChannel)

    with make_session() as session:
        create_or_update_server(session, guild)

        for channel in guild.channels:
            if not isinstance(channel, synced_channel_types):
                continue
            create_or_update_channel(session, channel)

        session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
class MessageCollector(commands.Bot):
    """Discord bot that collects and stores messages (thin event handler).

    Incoming messages and edits are handed to Celery tasks; server, channel
    and user metadata is written to the database directly. The class also
    offers helpers for sending DMs and channel messages.

    NOTE(review): ``get_user`` below is a coroutine that replaces the
    synchronous ``get_user`` inherited from the discord.py client. Callers
    must ``await`` it; confirm nothing relies on the synchronous variant.
    """

    def __init__(self):
        intents = discord.Intents.default()
        intents.message_content = True
        intents.guilds = True
        intents.members = True
        intents.dm_messages = True

        super().__init__(
            command_prefix="!memory_",  # Prefix to avoid conflicts
            intents=intents,
            help_command=None,  # Disable default help
        )

    async def on_ready(self):
        """Called when bot connects to Discord"""
        logger.info(f"Discord collector connected as {self.user}")
        logger.info(f"Connected to {len(self.guilds)} servers")

        # Sync server and channel metadata
        await self.sync_servers_and_channels()

        logger.info("Discord message collector ready")

    async def on_message(self, message: discord.Message):
        """Queue incoming message for database storage.

        Messages rejected by ``should_collect_bot_message`` are ignored.
        Errors are logged rather than raised so one bad message cannot break
        the event handler.
        """
        try:
            if not should_collect_bot_message(message):
                return

            # Ensure Discord entities exist in database first. The server is
            # written before the channel (same order as the metadata sync
            # paths) because a new channel row carries the server's id and
            # is flushed immediately.
            with make_session() as session:
                if message.guild:
                    create_or_update_server(session, message.guild)
                create_or_update_user(session, message.author)
                create_or_update_channel(session, message.channel)

                session.commit()

            # Queue the message for processing
            add_discord_message.delay(
                message_id=message.id,
                channel_id=message.channel.id,
                author_id=message.author.id,
                server_id=message.guild.id if message.guild else None,
                content=message.content or "",
                sent_at=message.created_at.isoformat(),
                message_reference_id=message.reference.message_id
                if message.reference
                else None,
            )
        except Exception as e:
            logger.error(f"Error queuing message {message.id}: {e}")

    async def on_message_edit(self, before: discord.Message, after: discord.Message):
        """Queue message edit for database update.

        Uses the edit timestamp from ``after`` when present, otherwise the
        current UTC time.
        """
        try:
            edit_time = after.edited_at or datetime.now(timezone.utc)
            edit_discord_message.delay(
                message_id=after.id,
                content=after.content,
                edited_at=edit_time.isoformat(),
            )
        except Exception as e:
            logger.error(f"Error queuing message edit {after.id}: {e}")

    async def sync_servers_and_channels(self):
        """Sync server and channel metadata on startup"""
        for guild in self.guilds:
            sync_guild_metadata(guild)

        logger.info(f"Synced {len(self.guilds)} servers and their channels")

    async def refresh_metadata(self) -> dict[str, int]:
        """Refresh server, channel and user metadata and update the database.

        Returns:
            Counters with the keys ``servers_updated``, ``channels_updated``
            and ``users_updated``.
        """
        # Progress goes through the logger only, not stdout.
        logger.info("Refreshing Discord metadata...")

        servers_updated = 0
        channels_updated = 0
        users_updated = 0

        with make_session() as session:
            # Refresh all servers
            for guild in self.guilds:
                create_or_update_server(session, guild)
                servers_updated += 1

                # Refresh all channels in this server
                for channel in guild.channels:
                    if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)):
                        create_or_update_channel(session, channel)
                        channels_updated += 1

                # Refresh all members in this server (if members intent is enabled)
                if self.intents.members:
                    for member in guild.members:
                        create_or_update_user(session, member)
                        users_updated += 1

            session.commit()

        result = {
            "servers_updated": servers_updated,
            "channels_updated": channels_updated,
            "users_updated": users_updated,
        }

        logger.info(f"Metadata refresh complete: {result}")

        return result

    async def get_user(self, user_identifier: int | str) -> discord.User | None:
        """Get a Discord user by ID or username.

        Args:
            user_identifier: A numeric user ID (``int`` or a string of
                decimal digits), or a username, display name or legacy
                ``name#discriminator`` string.

        Returns:
            The matching user or member, or ``None`` when nothing matches.
        """
        if isinstance(user_identifier, int):
            # Direct user ID lookup: cache first, then the API.
            if user := super().get_user(user_identifier):
                return user
            try:
                return await self.fetch_user(user_identifier)
            except discord.NotFound:
                return None

        # IDs often arrive as strings (the HTTP API passes ``str``). Try the
        # identifier as an ID first, then fall back to a name search in case
        # it is an all-digit username.
        if user_identifier.isdecimal():
            numeric_id = int(user_identifier)
            if user := super().get_user(numeric_id):
                return user
            try:
                return await self.fetch_user(numeric_id)
            except discord.HTTPException:
                pass  # not a known ID; continue with the name search

        # Username lookup - search through all guilds
        for guild in self.guilds:
            for member in guild.members:
                if (
                    member.name == user_identifier
                    or member.display_name == user_identifier
                    or f"{member.name}#{member.discriminator}" == user_identifier
                ):
                    return member
        return None

    async def get_channel_by_name(
        self, channel_name: str
    ) -> discord.TextChannel | None:
        """Get a Discord channel by name (does not create if missing).

        Returns the first text channel with that exact name across all
        guilds, or ``None``.
        """
        # Search all guilds for the channel
        for guild in self.guilds:
            for ch in guild.channels:
                if isinstance(ch, discord.TextChannel) and ch.name == channel_name:
                    return ch
        return None

    async def create_channel(
        self, channel_name: str, guild_id: int | None = None
    ) -> discord.TextChannel | None:
        """Create a Discord channel in the specified guild (or first guild if none specified).

        Returns the new channel, or ``None`` when no guild is available or
        creation fails (the failure is logged).
        """
        target_guild = None

        if guild_id:
            target_guild = self.get_guild(guild_id)
        elif self.guilds:
            target_guild = self.guilds[0]  # Default to first guild

        if not target_guild:
            logger.error(f"No guild available to create channel {channel_name}")
            return None

        try:
            channel = await target_guild.create_text_channel(channel_name)
            logger.info(f"Created channel {channel_name} in {target_guild.name}")
            return channel
        except Exception as e:
            logger.error(
                f"Failed to create channel {channel_name} in {target_guild.name}: {e}"
            )
            return None

    async def send_dm(self, user_identifier: int | str, message: str) -> bool:
        """Send a DM using this collector's Discord client.

        Returns ``True`` on success; ``False`` when the user cannot be found
        or sending fails (both cases are logged).
        """
        try:
            user = await self.get_user(user_identifier)
            if not user:
                logger.error(f"User {user_identifier} not found")
                return False

            await user.send(message)
            logger.info(f"Sent DM to {user_identifier}")
            return True

        except Exception as e:
            logger.error(f"Failed to send DM to {user_identifier}: {e}")
            return False

    async def send_to_channel(self, channel_name: str, message: str) -> bool:
        """Send a message to a channel by name across all guilds.

        Returns ``False`` without sending when Discord notifications are
        disabled in settings, when the channel is not found, or when sending
        fails; ``True`` otherwise.
        """
        if not settings.DISCORD_NOTIFICATIONS_ENABLED:
            return False

        try:
            channel = await self.get_channel_by_name(channel_name)
            if not channel:
                logger.error(f"Channel {channel_name} not found")
                return False

            await channel.send(message)
            logger.info(f"Sent message to channel {channel_name}")
            return True

        except Exception as e:
            logger.error(f"Failed to send message to channel {channel_name}: {e}")
            return False
|
||||||
|
|
||||||
|
|
||||||
|
async def run_collector():
    """Start a MessageCollector with the configured bot token.

    Logs an error and returns immediately when no token is configured. Any
    exception raised while the collector runs is logged and re-raised.
    """
    token = settings.DISCORD_BOT_TOKEN
    if not token:
        logger.error("DISCORD_BOT_TOKEN not configured")
        return

    bot = MessageCollector()
    try:
        await bot.start(token)
    except Exception as e:
        logger.error(f"Discord collector failed: {e}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Standalone entry point: run the collector until it exits or fails.
    # asyncio is imported here because only this script path needs it.
    import asyncio

    asyncio.run(run_collector())
|
||||||
@ -6,6 +6,7 @@ from memory.workers.tasks import (
|
|||||||
email,
|
email,
|
||||||
comic,
|
comic,
|
||||||
blogs,
|
blogs,
|
||||||
|
discord,
|
||||||
ebook,
|
ebook,
|
||||||
forums,
|
forums,
|
||||||
maintenance,
|
maintenance,
|
||||||
@ -20,6 +21,7 @@ __all__ = [
|
|||||||
"comic",
|
"comic",
|
||||||
"blogs",
|
"blogs",
|
||||||
"ebook",
|
"ebook",
|
||||||
|
"discord",
|
||||||
"forums",
|
"forums",
|
||||||
"maintenance",
|
"maintenance",
|
||||||
"notes",
|
"notes",
|
||||||
|
|||||||
@ -115,6 +115,9 @@ def sync_article_feed(feed_id: int) -> dict:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
for feed_item in parser.parse_feed():
|
for feed_item in parser.parse_feed():
|
||||||
|
if not feed_item.url:
|
||||||
|
continue
|
||||||
|
|
||||||
articles_found += 1
|
articles_found += 1
|
||||||
|
|
||||||
existing = check_content_exists(session, BlogPost, url=feed_item.url)
|
existing = check_content_exists(session, BlogPost, url=feed_item.url)
|
||||||
|
|||||||
@ -10,9 +10,9 @@ from collections import defaultdict
|
|||||||
import hashlib
|
import hashlib
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Iterable, Sequence, cast
|
from typing import Any, Callable, Sequence, cast
|
||||||
|
|
||||||
from memory.common import embedding, qdrant, settings
|
from memory.common import embedding, qdrant
|
||||||
from memory.common.db.models import SourceItem, Chunk
|
from memory.common.db.models import SourceItem, Chunk
|
||||||
from memory.common.discord import notify_task_failure
|
from memory.common.discord import notify_task_failure
|
||||||
|
|
||||||
@ -38,19 +38,12 @@ def check_content_exists(
|
|||||||
Returns:
|
Returns:
|
||||||
Existing SourceItem if found, None otherwise
|
Existing SourceItem if found, None otherwise
|
||||||
"""
|
"""
|
||||||
|
query = session.query(model_class)
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if not hasattr(model_class, key):
|
if hasattr(model_class, key):
|
||||||
continue
|
query = query.filter(getattr(model_class, key) == value)
|
||||||
|
|
||||||
existing = (
|
return query.first()
|
||||||
session.query(model_class)
|
|
||||||
.filter(getattr(model_class, key) == value)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if existing:
|
|
||||||
return existing
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def create_content_hash(content: str, *additional_data: str) -> bytes:
|
def create_content_hash(content: str, *additional_data: str) -> bytes:
|
||||||
@ -286,6 +279,6 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
|
|||||||
traceback_str=traceback_str,
|
traceback_str=traceback_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"status": "error", "error": str(e)}
|
return {"status": "error", "error": str(e), "traceback": traceback_str}
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|||||||
166
src/memory/workers/tasks/discord.py
Normal file
166
src/memory/workers/tasks/discord.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
"""
|
||||||
|
Celery tasks for Discord message processing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from memory.common.celery_app import app
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models import DiscordMessage, DiscordUser
|
||||||
|
from memory.workers.tasks.content_processing import (
|
||||||
|
safe_task_execution,
|
||||||
|
check_content_exists,
|
||||||
|
create_task_result,
|
||||||
|
process_content_item,
|
||||||
|
)
|
||||||
|
from memory.common.celery_app import (
|
||||||
|
ADD_DISCORD_MESSAGE,
|
||||||
|
EDIT_DISCORD_MESSAGE,
|
||||||
|
PROCESS_DISCORD_MESSAGE,
|
||||||
|
)
|
||||||
|
from memory.common import settings
|
||||||
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_prev(
    session: Session | scoped_session, channel_id: int, sent_at: datetime
) -> list[str]:
    """Return the messages sent in a channel just before ``sent_at``.

    At most ``settings.DISCORD_CONTEXT_WINDOW`` messages are fetched, newest
    first, and returned oldest first as ``"username: content"`` strings.
    """
    query = (
        session.query(DiscordUser.username, DiscordMessage.content)
        .join(DiscordUser, DiscordMessage.discord_user_id == DiscordUser.id)
        .filter(
            DiscordMessage.channel_id == channel_id,
            DiscordMessage.sent_at < sent_at,
        )
        .order_by(DiscordMessage.sent_at.desc())
        .limit(settings.DISCORD_CONTEXT_WINDOW)
    )
    newest_first = query.all()
    # Flip back to chronological order for the caller.
    return [f"{row.username}: {row.content}" for row in reversed(newest_first)]
|
||||||
|
|
||||||
|
|
||||||
|
def should_process(message: DiscordMessage) -> bool:
    """Decide whether ``message`` should be processed.

    Requires both ``DISCORD_PROCESS_MESSAGES`` and
    ``DISCORD_NOTIFICATIONS_ENABLED`` to be on, and none of the message's
    server, channel or user to have ``ignore_messages`` set.
    """
    return (
        settings.DISCORD_PROCESS_MESSAGES
        and settings.DISCORD_NOTIFICATIONS_ENABLED
        and not (message.server and message.server.ignore_messages)
        and not (message.channel and message.channel.ignore_messages)
        and not (message.discord_user and message.discord_user.ignore_messages)
    )
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(name=PROCESS_DISCORD_MESSAGE)
@safe_task_execution
def process_discord_message(message_id: int) -> dict[str, Any]:
    """Load a stored Discord message and report its processing status.

    Returns an error result when no message with ``message_id`` exists;
    otherwise prints the message and returns a "processed" result.
    """
    logger.info(f"Processing Discord message {message_id}")

    with make_session() as session:
        discord_message = session.query(DiscordMessage).get(message_id)
        if discord_message:
            print("Processing message", discord_message.id, discord_message.content)
            return {
                "status": "processed",
                "message_id": message_id,
            }

        logger.info(f"Discord message not found: {message_id}")
        return {
            "status": "error",
            "error": "Message not found",
            "message_id": message_id,
        }
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(name=ADD_DISCORD_MESSAGE)
@safe_task_execution
def add_discord_message(
    message_id: int,
    channel_id: int,
    author_id: int,
    content: str,
    sent_at: str,
    server_id: int | None = None,
    message_reference_id: int | None = None,
) -> dict[str, Any]:
    """
    Store a newly received Discord message.

    The Discord collector queues this task whenever a message arrives.
    ``sent_at`` is an ISO-8601 timestamp string; a trailing ``Z`` is accepted.
    A message carrying ``message_reference_id`` is stored as a ``"reply"``,
    anything else as ``"default"``.

    Returns the ``create_task_result`` payload with status ``"already_exists"``
    when ``check_content_exists`` finds a match, otherwise whatever
    ``process_content_item`` returns for the new row. When ``should_process``
    approves the stored message, ``process_discord_message`` is queued for it.
    """
    logger.info(f"Adding Discord message {message_id}: {content}")
    # The message id is part of the hashed text so that two messages with
    # identical content still get distinct hashes.
    digest = hashlib.sha256(f"{message_id}:{content}".encode()).digest()
    # Spell a trailing "Z" as an explicit UTC offset before parsing.
    sent_time = datetime.fromisoformat(sent_at.replace("Z", "+00:00"))
    kind = "reply" if message_reference_id else "default"

    with make_session() as session:
        record = DiscordMessage(
            message_id=message_id,
            message_type=kind,
            reply_to_message_id=message_reference_id,
            discord_user_id=author_id,
            channel_id=channel_id,
            server_id=server_id,
            sent_at=sent_time,
            content=content,
            sha256=digest,
            modality="text",
        )

        duplicate = check_content_exists(
            session, DiscordMessage, message_id=message_id, sha256=digest
        )
        if duplicate:
            logger.info(f"Discord message already exists: {duplicate.message_id}")
            return create_task_result(
                duplicate, "already_exists", message_id=message_id
            )

        if channel_id:
            # Attach the preceding channel messages as context.
            record.messages_before = get_prev(session, channel_id, sent_time)

        outcome = process_content_item(record, session)
        if should_process(record):
            process_discord_message.delay(record.id)

        return outcome
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(name=EDIT_DISCORD_MESSAGE)
@safe_task_execution
def edit_discord_message(
    message_id: int, content: str, edited_at: str
) -> dict[str, Any]:
    """
    Apply an edit to a Discord message that is already stored.

    The Discord collector queues this task whenever a message is edited.
    ``edited_at`` is an ISO-8601 timestamp string; a trailing ``Z`` is accepted.

    Returns an ``"error"`` dict when no stored message matches ``message_id``,
    otherwise the result of running ``process_content_item`` on the updated row.
    """
    logger.info(f"Editing Discord message {message_id}: {content}")

    with make_session() as session:
        stored = check_content_exists(session, DiscordMessage, message_id=message_id)
        if not stored:
            return {
                "status": "error",
                "error": "Message not found",
                "message_id": message_id,
            }

        # NOTE(review): the stored sha256 is not recomputed here, although
        # add_discord_message derives it from the content - confirm whether
        # process_content_item takes care of that.
        stored.content = content  # type: ignore

        if stored.channel_id:
            # Rebuild the context from the messages preceding the original send time.
            stored.messages_before = get_prev(
                session, stored.channel_id, stored.sent_at
            )

        # Spell a trailing "Z" as an explicit UTC offset before parsing.
        edit_time = datetime.fromisoformat(edited_at.replace("Z", "+00:00"))
        stored.edited_at = edit_time

        return process_content_item(stored, session)
|
||||||
@ -88,11 +88,10 @@ def execute_scheduled_call(self, scheduled_call_id: str):
|
|||||||
|
|
||||||
# Make the LLM call
|
# Make the LLM call
|
||||||
if scheduled_call.model:
|
if scheduled_call.model:
|
||||||
response = llms.call(
|
response = llms.summarize(
|
||||||
prompt=cast(str, scheduled_call.message),
|
prompt=cast(str, scheduled_call.message),
|
||||||
model=cast(str, scheduled_call.model),
|
model=cast(str, scheduled_call.model),
|
||||||
system_prompt=cast(str, scheduled_call.system_prompt)
|
system_prompt=cast(str, scheduled_call.system_prompt),
|
||||||
or llms.SYSTEM_PROMPT,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = cast(str, scheduled_call.message)
|
response = cast(str, scheduled_call.message)
|
||||||
|
|||||||
@ -273,6 +273,27 @@ def mock_anthropic_client():
|
|||||||
with patch.object(anthropic, "Anthropic", autospec=True) as mock_client:
|
with patch.object(anthropic, "Anthropic", autospec=True) as mock_client:
|
||||||
client = mock_client()
|
client = mock_client()
|
||||||
client.messages = Mock()
|
client.messages = Mock()
|
||||||
|
|
||||||
|
# Mock stream as a context manager
|
||||||
|
mock_stream = Mock()
|
||||||
|
mock_stream.__enter__ = Mock(
|
||||||
|
return_value=Mock(
|
||||||
|
__iter__=lambda self: iter(
|
||||||
|
[
|
||||||
|
Mock(
|
||||||
|
type="content_block_delta",
|
||||||
|
delta=Mock(
|
||||||
|
type="text_delta",
|
||||||
|
text="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mock_stream.__exit__ = Mock(return_value=False)
|
||||||
|
client.messages.stream = Mock(return_value=mock_stream)
|
||||||
|
|
||||||
client.messages.create = Mock(
|
client.messages.create = Mock(
|
||||||
return_value=Mock(
|
return_value=Mock(
|
||||||
content=[
|
content=[
|
||||||
|
|||||||
@ -2,318 +2,250 @@ import pytest
|
|||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from memory.common import discord, settings
|
from memory.common import discord
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_session_request():
|
def mock_api_url():
|
||||||
with patch("requests.Session.request") as mock:
|
"""Mock the API URL to avoid using actual settings"""
|
||||||
yield mock
|
with patch(
|
||||||
|
"memory.common.discord.get_api_url", return_value="http://localhost:8000"
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@patch("memory.common.settings.DISCORD_COLLECTOR_SERVER_URL", "testhost")
|
||||||
def mock_get_channels_response():
|
@patch("memory.common.settings.DISCORD_COLLECTOR_PORT", 9999)
|
||||||
return [
|
def test_get_api_url():
|
||||||
{"name": "memory-errors", "id": "error_channel_id"},
|
"""Test API URL construction"""
|
||||||
{"name": "memory-activity", "id": "activity_channel_id"},
|
assert discord.get_api_url() == "http://testhost:9999"
|
||||||
{"name": "memory-discoveries", "id": "discovery_channel_id"},
|
|
||||||
{"name": "memory-chat", "id": "chat_channel_id"},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_discord_server_init(mock_session_request, mock_get_channels_response):
|
@patch("requests.post")
|
||||||
# Mock the channels API call
|
def test_send_dm_success(mock_post, mock_api_url):
|
||||||
|
"""Test successful DM sending"""
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.json.return_value = mock_get_channels_response
|
mock_response.json.return_value = {"success": True}
|
||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_session_request.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
server = discord.DiscordServer("server123", "Test Server")
|
result = discord.send_dm("user123", "Hello!")
|
||||||
|
|
||||||
assert server.server_id == "server123"
|
assert result is True
|
||||||
assert server.server_name == "Test Server"
|
mock_post.assert_called_once_with(
|
||||||
assert hasattr(server, "channels")
|
"http://localhost:8000/send_dm",
|
||||||
|
json={"user_identifier": "user123", "message": "Hello!"},
|
||||||
|
timeout=10,
|
||||||
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "memory-errors")
|
|
||||||
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "memory-activity")
|
|
||||||
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "memory-discoveries")
|
|
||||||
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "memory-chat")
|
|
||||||
def test_setup_channels_existing(mock_session_request, mock_get_channels_response):
|
|
||||||
# Mock the channels API call
|
|
||||||
mock_response = Mock()
|
|
||||||
mock_response.json.return_value = mock_get_channels_response
|
|
||||||
mock_response.raise_for_status.return_value = None
|
|
||||||
mock_session_request.return_value = mock_response
|
|
||||||
|
|
||||||
server = discord.DiscordServer("server123", "Test Server")
|
|
||||||
|
|
||||||
assert server.channels[discord.ERROR_CHANNEL] == "error_channel_id"
|
|
||||||
assert server.channels[discord.ACTIVITY_CHANNEL] == "activity_channel_id"
|
|
||||||
assert server.channels[discord.DISCOVERY_CHANNEL] == "discovery_channel_id"
|
|
||||||
assert server.channels[discord.CHAT_CHANNEL] == "chat_channel_id"
|
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "new-error-channel")
|
|
||||||
def test_setup_channels_create_missing(mock_session_request):
|
|
||||||
# Mock get channels (empty) and create channel calls
|
|
||||||
get_response = Mock()
|
|
||||||
get_response.json.return_value = []
|
|
||||||
get_response.raise_for_status.return_value = None
|
|
||||||
|
|
||||||
create_response = Mock()
|
|
||||||
create_response.json.return_value = {"id": "new_channel_id"}
|
|
||||||
create_response.raise_for_status.return_value = None
|
|
||||||
|
|
||||||
mock_session_request.side_effect = [
|
|
||||||
get_response,
|
|
||||||
create_response,
|
|
||||||
create_response,
|
|
||||||
create_response,
|
|
||||||
create_response,
|
|
||||||
]
|
|
||||||
|
|
||||||
server = discord.DiscordServer("server123", "Test Server")
|
|
||||||
|
|
||||||
assert server.channels[discord.ERROR_CHANNEL] == "new_channel_id"
|
|
||||||
|
|
||||||
|
|
||||||
def test_channel_properties():
|
|
||||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
|
||||||
server.channels = {
|
|
||||||
discord.ERROR_CHANNEL: "error_id",
|
|
||||||
discord.ACTIVITY_CHANNEL: "activity_id",
|
|
||||||
discord.DISCOVERY_CHANNEL: "discovery_id",
|
|
||||||
discord.CHAT_CHANNEL: "chat_id",
|
|
||||||
}
|
|
||||||
|
|
||||||
assert server.error_channel == "error_id"
|
|
||||||
assert server.activity_channel == "activity_id"
|
|
||||||
assert server.discovery_channel == "discovery_id"
|
|
||||||
assert server.chat_channel == "chat_id"
|
|
||||||
|
|
||||||
|
|
||||||
def test_channel_id_exists():
|
|
||||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
|
||||||
server.channels = {"test-channel": "channel123"}
|
|
||||||
|
|
||||||
assert server.channel_id("test-channel") == "channel123"
|
|
||||||
|
|
||||||
|
|
||||||
def test_channel_id_not_found():
|
|
||||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
|
||||||
server.channels = {}
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Channel nonexistent not found"):
|
|
||||||
server.channel_id("nonexistent")
|
|
||||||
|
|
||||||
|
|
||||||
def test_send_message(mock_session_request):
|
|
||||||
mock_response = Mock()
|
|
||||||
mock_response.raise_for_status.return_value = None
|
|
||||||
mock_session_request.return_value = mock_response
|
|
||||||
|
|
||||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
|
||||||
|
|
||||||
server.send_message("channel123", "Hello World")
|
|
||||||
|
|
||||||
mock_session_request.assert_called_with(
|
|
||||||
"POST",
|
|
||||||
"https://discord.com/api/v10/channels/channel123/messages",
|
|
||||||
data=None,
|
|
||||||
json={"content": "Hello World"},
|
|
||||||
headers={
|
|
||||||
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_create_channel(mock_session_request):
|
@patch("requests.post")
|
||||||
|
def test_send_dm_api_failure(mock_post, mock_api_url):
|
||||||
|
"""Test DM sending when API returns failure"""
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.json.return_value = {"id": "new_channel_id"}
|
mock_response.json.return_value = {"success": False}
|
||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_session_request.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
result = discord.send_dm("user123", "Hello!")
|
||||||
server.server_id = "server123"
|
|
||||||
|
|
||||||
channel_id = server.create_channel("new-channel")
|
assert result is False
|
||||||
|
|
||||||
assert channel_id == "new_channel_id"
|
|
||||||
mock_session_request.assert_called_with(
|
|
||||||
"POST",
|
|
||||||
"https://discord.com/api/v10/guilds/server123/channels",
|
|
||||||
data=None,
|
|
||||||
json={"name": "new-channel", "type": 0},
|
|
||||||
headers={
|
|
||||||
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_channel_custom_type(mock_session_request):
|
@patch("requests.post")
|
||||||
|
def test_send_dm_request_exception(mock_post, mock_api_url):
|
||||||
|
"""Test DM sending when request raises exception"""
|
||||||
|
mock_post.side_effect = requests.RequestException("Network error")
|
||||||
|
|
||||||
|
result = discord.send_dm("user123", "Hello!")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
|
||||||
|
def test_send_dm_http_error(mock_post, mock_api_url):
|
||||||
|
"""Test DM sending when HTTP error occurs"""
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.json.return_value = {"id": "voice_channel_id"}
|
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
result = discord.send_dm("user123", "Hello!")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
|
||||||
|
def test_broadcast_message_success(mock_post, mock_api_url):
|
||||||
|
"""Test successful channel message broadcast"""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {"success": True}
|
||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_session_request.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
result = discord.broadcast_message("general", "Announcement!")
|
||||||
server.server_id = "server123"
|
|
||||||
|
|
||||||
channel_id = server.create_channel("voice-channel", channel_type=2)
|
assert result is True
|
||||||
|
mock_post.assert_called_once_with(
|
||||||
assert channel_id == "voice_channel_id"
|
"http://localhost:8000/send_channel",
|
||||||
mock_session_request.assert_called_with(
|
json={"channel_name": "general", "message": "Announcement!"},
|
||||||
"POST",
|
timeout=10,
|
||||||
"https://discord.com/api/v10/guilds/server123/channels",
|
|
||||||
data=None,
|
|
||||||
json={"name": "voice-channel", "type": 2},
|
|
||||||
headers={
|
|
||||||
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_str_representation():
|
@patch("requests.post")
|
||||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
def test_broadcast_message_failure(mock_post, mock_api_url):
|
||||||
server.server_id = "server123"
|
"""Test channel message broadcast failure"""
|
||||||
server.server_name = "Test Server"
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {"success": False}
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
assert str(server) == "DiscordServer(server_id=server123, server_name=Test Server)"
|
result = discord.broadcast_message("general", "Announcement!")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token_123")
|
@patch("requests.post")
|
||||||
def test_request_adds_headers(mock_session_request):
|
def test_broadcast_message_exception(mock_post, mock_api_url):
|
||||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
"""Test channel message broadcast with exception"""
|
||||||
|
mock_post.side_effect = requests.Timeout("Request timeout")
|
||||||
|
|
||||||
server.request("GET", "https://example.com", headers={"Custom": "header"})
|
result = discord.broadcast_message("general", "Announcement!")
|
||||||
|
|
||||||
expected_headers = {
|
assert result is False
|
||||||
"Custom": "header",
|
|
||||||
"Authorization": "Bot test_token_123",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
mock_session_request.assert_called_once_with(
|
|
||||||
"GET", "https://example.com", headers=expected_headers
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_channels_url():
|
|
||||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
|
||||||
server.server_id = "server123"
|
|
||||||
|
|
||||||
assert (
|
|
||||||
server.channels_url == "https://discord.com/api/v10/guilds/server123/channels"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
|
|
||||||
@patch("requests.get")
|
@patch("requests.get")
|
||||||
def test_get_bot_servers_success(mock_get):
|
def test_is_collector_healthy_true(mock_get, mock_api_url):
|
||||||
|
"""Test health check when collector is healthy"""
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.json.return_value = [
|
mock_response.json.return_value = {"status": "healthy"}
|
||||||
{"id": "server1", "name": "Server 1"},
|
|
||||||
{"id": "server2", "name": "Server 2"},
|
|
||||||
]
|
|
||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
servers = discord.get_bot_servers()
|
result = discord.is_collector_healthy()
|
||||||
|
|
||||||
assert len(servers) == 2
|
assert result is True
|
||||||
assert servers[0] == {"id": "server1", "name": "Server 1"}
|
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
||||||
mock_get.assert_called_once_with(
|
|
||||||
"https://discord.com/api/v10/users/@me/guilds",
|
|
||||||
headers={"Authorization": "Bot test_token"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
|
|
||||||
def test_get_bot_servers_no_token():
|
|
||||||
assert discord.get_bot_servers() == []
|
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
|
|
||||||
@patch("requests.get")
|
@patch("requests.get")
|
||||||
def test_get_bot_servers_exception(mock_get):
|
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
|
||||||
mock_get.side_effect = requests.RequestException("API Error")
|
"""Test health check when collector returns unhealthy status"""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {"status": "unhealthy"}
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
servers = discord.get_bot_servers()
|
result = discord.is_collector_healthy()
|
||||||
|
|
||||||
assert servers == []
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.get_bot_servers")
|
@patch("requests.get")
|
||||||
@patch("memory.common.discord.DiscordServer")
|
def test_is_collector_healthy_exception(mock_get, mock_api_url):
|
||||||
def test_load_servers(mock_discord_server_class, mock_get_servers):
|
"""Test health check when request fails"""
|
||||||
mock_get_servers.return_value = [
|
mock_get.side_effect = requests.ConnectionError("Connection refused")
|
||||||
{"id": "server1", "name": "Server 1"},
|
|
||||||
{"id": "server2", "name": "Server 2"},
|
|
||||||
]
|
|
||||||
|
|
||||||
discord.load_servers()
|
result = discord.is_collector_healthy()
|
||||||
|
|
||||||
assert mock_discord_server_class.call_count == 2
|
assert result is False
|
||||||
mock_discord_server_class.assert_any_call("server1", "Server 1")
|
|
||||||
mock_discord_server_class.assert_any_call("server2", "Server 2")
|
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
@patch("requests.post")
|
||||||
def test_broadcast_message():
|
def test_refresh_discord_metadata_success(mock_post, mock_api_url):
|
||||||
mock_server1 = Mock()
|
"""Test successful metadata refresh"""
|
||||||
mock_server2 = Mock()
|
mock_response = Mock()
|
||||||
discord.servers = {"1": mock_server1, "2": mock_server2}
|
mock_response.json.return_value = {
|
||||||
|
"servers": 5,
|
||||||
|
"channels": 20,
|
||||||
|
"users": 100,
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
discord.broadcast_message("test-channel", "Hello")
|
result = discord.refresh_discord_metadata()
|
||||||
|
|
||||||
mock_server1.send_message.assert_called_once_with(
|
assert result == {"servers": 5, "channels": 20, "users": 100}
|
||||||
mock_server1.channel_id.return_value, "Hello"
|
mock_post.assert_called_once_with(
|
||||||
)
|
"http://localhost:8000/refresh_metadata", timeout=30
|
||||||
mock_server2.send_message.assert_called_once_with(
|
|
||||||
mock_server2.channel_id.return_value, "Hello"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
@patch("requests.post")
|
||||||
def test_broadcast_message_disabled():
|
def test_refresh_discord_metadata_failure(mock_post, mock_api_url):
|
||||||
mock_server = Mock()
|
"""Test metadata refresh failure"""
|
||||||
discord.servers = {"1": mock_server}
|
mock_post.side_effect = requests.RequestException("Failed to connect")
|
||||||
|
|
||||||
discord.broadcast_message("test-channel", "Hello")
|
result = discord.refresh_discord_metadata()
|
||||||
|
|
||||||
mock_server.send_message.assert_not_called()
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
|
||||||
|
def test_refresh_discord_metadata_http_error(mock_post, mock_api_url):
|
||||||
|
"""Test metadata refresh with HTTP error"""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error")
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
result = discord.refresh_discord_metadata()
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.broadcast_message")
|
@patch("memory.common.discord.broadcast_message")
|
||||||
|
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "errors")
|
||||||
def test_send_error_message(mock_broadcast):
|
def test_send_error_message(mock_broadcast):
|
||||||
discord.send_error_message("Error occurred")
|
"""Test sending error message to error channel"""
|
||||||
mock_broadcast.assert_called_once_with(discord.ERROR_CHANNEL, "Error occurred")
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
|
result = discord.send_error_message("Something broke")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_broadcast.assert_called_once_with("errors", "Something broke")
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.broadcast_message")
|
@patch("memory.common.discord.broadcast_message")
|
||||||
|
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "activity")
|
||||||
def test_send_activity_message(mock_broadcast):
|
def test_send_activity_message(mock_broadcast):
|
||||||
discord.send_activity_message("Activity update")
|
"""Test sending activity message to activity channel"""
|
||||||
mock_broadcast.assert_called_once_with(discord.ACTIVITY_CHANNEL, "Activity update")
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
|
result = discord.send_activity_message("User logged in")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_broadcast.assert_called_once_with("activity", "User logged in")
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.broadcast_message")
|
@patch("memory.common.discord.broadcast_message")
|
||||||
|
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "discoveries")
|
||||||
def test_send_discovery_message(mock_broadcast):
|
def test_send_discovery_message(mock_broadcast):
|
||||||
discord.send_discovery_message("Discovery made")
|
"""Test sending discovery message to discovery channel"""
|
||||||
mock_broadcast.assert_called_once_with(discord.DISCOVERY_CHANNEL, "Discovery made")
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
|
result = discord.send_discovery_message("Found interesting pattern")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.broadcast_message")
|
@patch("memory.common.discord.broadcast_message")
|
||||||
|
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "chat")
|
||||||
def test_send_chat_message(mock_broadcast):
|
def test_send_chat_message(mock_broadcast):
|
||||||
discord.send_chat_message("Chat message")
|
"""Test sending chat message to chat channel"""
|
||||||
mock_broadcast.assert_called_once_with(discord.CHAT_CHANNEL, "Chat message")
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
|
result = discord.send_chat_message("Hello from bot")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
|
||||||
@patch("memory.common.discord.send_error_message")
|
@patch("memory.common.discord.send_error_message")
|
||||||
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||||
def test_notify_task_failure_basic(mock_send_error):
|
def test_notify_task_failure_basic(mock_send_error):
|
||||||
|
"""Test basic task failure notification"""
|
||||||
discord.notify_task_failure("test_task", "Something went wrong")
|
discord.notify_task_failure("test_task", "Something went wrong")
|
||||||
|
|
||||||
mock_send_error.assert_called_once()
|
mock_send_error.assert_called_once()
|
||||||
@ -323,69 +255,181 @@ def test_notify_task_failure_basic(mock_send_error):
|
|||||||
assert "**Error:** Something went wrong" in message
|
assert "**Error:** Something went wrong" in message
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
|
||||||
@patch("memory.common.discord.send_error_message")
|
@patch("memory.common.discord.send_error_message")
|
||||||
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||||
def test_notify_task_failure_with_args(mock_send_error):
|
def test_notify_task_failure_with_args(mock_send_error):
|
||||||
|
"""Test task failure notification with arguments"""
|
||||||
discord.notify_task_failure(
|
discord.notify_task_failure(
|
||||||
"test_task",
|
"test_task",
|
||||||
"Error message",
|
"Error occurred",
|
||||||
task_args=("arg1", "arg2"),
|
task_args=("arg1", 42),
|
||||||
task_kwargs={"key": "value"},
|
task_kwargs={"key": "value", "number": 123},
|
||||||
)
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][0]
|
||||||
|
|
||||||
assert "**Args:** `('arg1', 'arg2')`" in message
|
assert "**Args:** `('arg1', 42)" in message
|
||||||
assert "**Kwargs:** `{'key': 'value'}`" in message
|
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
|
||||||
@patch("memory.common.discord.send_error_message")
|
@patch("memory.common.discord.send_error_message")
|
||||||
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||||
def test_notify_task_failure_with_traceback(mock_send_error):
|
def test_notify_task_failure_with_traceback(mock_send_error):
|
||||||
traceback = "Traceback (most recent call last):\n File ...\nError: Something"
|
"""Test task failure notification with traceback"""
|
||||||
|
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", "Error message", traceback_str=traceback)
|
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][0]
|
||||||
|
|
||||||
assert "**Traceback:**" in message
|
assert "**Traceback:**" in message
|
||||||
assert traceback in message
|
assert "Exception: test" in message
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
|
||||||
@patch("memory.common.discord.send_error_message")
|
@patch("memory.common.discord.send_error_message")
|
||||||
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||||
def test_notify_task_failure_truncates_long_error(mock_send_error):
|
def test_notify_task_failure_truncates_long_error(mock_send_error):
|
||||||
long_error = "x" * 600 # Longer than 500 char limit
|
"""Test that long error messages are truncated"""
|
||||||
|
long_error = "x" * 600
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", long_error)
|
discord.notify_task_failure("test_task", long_error)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][0]
|
||||||
assert long_error[:500] in message
|
|
||||||
|
# Error should be truncated to 500 chars - check that the full 600 char string is not there
|
||||||
|
assert "**Error:** " + long_error[:500] in message
|
||||||
|
# The full 600-char error should not be present
|
||||||
|
error_section = message.split("**Error:** ")[1].split("\n")[0]
|
||||||
|
assert len(error_section) == 500
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_traceback(mock_send_error):
    """Only the final 800 characters of a 1000-character traceback are sent."""
    oversized_tb = "x" * 1000

    discord.notify_task_failure("test_task", "Error", traceback_str=oversized_tb)

    positional, _ = mock_send_error.call_args
    sent_text = positional[0]

    # The tail of the traceback is what gets kept
    assert oversized_tb[-800:] in sent_text
    # The fenced traceback body must be exactly 800 characters long
    after_fence = sent_text.split("**Traceback:**\n```\n")[1]
    tb_body = after_fence.split("\n```")[0]
    assert len(tb_body) == 800
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
|
||||||
@patch("memory.common.discord.send_error_message")
|
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_args(mock_send_error):
    """An over-long positional-args tuple is shortened in the notification."""
    oversized_args = ("x" * 300,)

    discord.notify_task_failure("test_task", "Error", task_args=oversized_args)

    positional, _ = mock_send_error.call_args
    sent_text = positional[0]

    # Args are cut to 200 chars; allow a little slack for the surrounding formatting
    args_line = sent_text.split("**Args:**")[1].partition("\n")[0]
    assert len(args_line) <= 210
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
    """An over-long kwargs dict is shortened in the notification."""
    oversized_kwargs = {"key": "x" * 300}

    discord.notify_task_failure("test_task", "Error", task_kwargs=oversized_kwargs)

    positional, _ = mock_send_error.call_args
    sent_text = positional[0]

    # Kwargs are cut to 200 chars; the bound leaves slack for formatting
    kwargs_line = sent_text.split("**Kwargs:**")[1].partition("\n")[0]
    assert len(kwargs_line) <= 210
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_notify_task_failure_disabled(mock_send_error):
    """Nothing is sent while DISCORD_NOTIFICATIONS_ENABLED is False."""
    discord.notify_task_failure("test_task", "Error occurred")

    assert mock_send_error.call_count == 0
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_send_error_exception(mock_send_error):
    """An exception raised by send_error_message must not escape the notifier."""
    mock_send_error.side_effect = Exception("Failed to send")

    # The call below is expected to complete without raising
    discord.notify_task_failure("test_task", "Error occurred")

    assert mock_send_error.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "function,channel_setting,message",
    [
        (discord.send_error_message, "DISCORD_ERROR_CHANNEL", "Error!"),
        (discord.send_activity_message, "DISCORD_ACTIVITY_CHANNEL", "Activity!"),
        (discord.send_discovery_message, "DISCORD_DISCOVERY_CHANNEL", "Discovery!"),
        (discord.send_chat_message, "DISCORD_CHAT_CHANNEL", "Chat!"),
    ],
)
@patch("memory.common.discord.broadcast_message")
def test_convenience_functions_use_correct_channels(
    mock_broadcast, function, channel_setting, message
):
    """Each send_* helper broadcasts to the channel named by its own setting."""
    setting_path = "memory.common.settings." + channel_setting
    with patch(setting_path, "test-channel"):
        function(message)
        mock_broadcast.assert_called_once_with("test-channel", message)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
def test_send_dm_with_special_characters(mock_post, mock_api_url):
    """Emoji, mentions and channel tags reach the POST payload unchanged."""
    response = mock_post.return_value = Mock()
    response.json.return_value = {"success": True}
    response.raise_for_status.return_value = None

    tricky_text = "Hello! 🎉 <@123> #general"
    outcome = discord.send_dm("user123", tricky_text)

    assert outcome is True
    _, keyword = mock_post.call_args
    assert keyword["json"]["message"] == tricky_text
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
def test_broadcast_message_with_long_message(mock_post, mock_api_url):
    """A 2000-character message is posted in full by broadcast_message."""
    response = mock_post.return_value = Mock()
    response.json.return_value = {"success": True}
    response.raise_for_status.return_value = None

    big_text = "A" * 2000
    outcome = discord.broadcast_message("general", big_text)

    assert outcome is True
    _, keyword = mock_post.call_args
    assert keyword["json"]["message"] == big_text
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.get")
def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
    """A health response without a ``status`` key counts as unhealthy."""
    mock_get.return_value = Mock(
        **{"json.return_value": {}, "raise_for_status.return_value": None}
    )

    assert discord.is_collector_healthy() is False
|
||||||
|
|||||||
435
tests/memory/common/test_discord_integration.py
Normal file
435
tests/memory/common/test_discord_integration.py
Normal file
@ -0,0 +1,435 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from memory.common import discord
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_api_url():
    """Pin ``discord.get_api_url`` to a fixed local URL while a test runs."""
    fixed_url = patch(
        "memory.common.discord.get_api_url", return_value="http://localhost:8000"
    )
    with fixed_url:
        yield
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_COLLECTOR_SERVER_URL", "testhost")
@patch("memory.common.settings.DISCORD_COLLECTOR_PORT", 9999)
def test_get_api_url():
    """The API URL is built from the patched collector host and port."""
    built_url = discord.get_api_url()
    assert built_url == "http://testhost:9999"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
def test_send_dm_success(mock_post, mock_api_url):
    """send_dm returns True and posts the expected payload on API success."""
    response = mock_post.return_value = Mock()
    response.json.return_value = {"success": True}
    response.raise_for_status.return_value = None

    outcome = discord.send_dm("user123", "Hello!")

    assert outcome is True
    expected_payload = {"user_identifier": "user123", "message": "Hello!"}
    mock_post.assert_called_once_with(
        "http://localhost:8000/send_dm", json=expected_payload, timeout=10
    )
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
def test_send_dm_api_failure(mock_post, mock_api_url):
    """send_dm returns False when the response body reports ``success`` False."""
    mock_post.return_value = Mock(
        **{
            "json.return_value": {"success": False},
            "raise_for_status.return_value": None,
        }
    )

    assert discord.send_dm("user123", "Hello!") is False
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
def test_send_dm_request_exception(mock_post, mock_api_url):
    """send_dm returns False when the POST raises a RequestException."""
    mock_post.side_effect = requests.RequestException("Network error")

    assert discord.send_dm("user123", "Hello!") is False
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
def test_send_dm_http_error(mock_post, mock_api_url):
    """send_dm returns False when raise_for_status raises an HTTPError."""
    failing_response = Mock()
    failing_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
    mock_post.return_value = failing_response

    outcome = discord.send_dm("user123", "Hello!")

    assert outcome is False
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
def test_broadcast_message_success(mock_post, mock_api_url):
    """broadcast_message returns True and posts to /send_channel on success."""
    response = mock_post.return_value = Mock()
    response.json.return_value = {"success": True}
    response.raise_for_status.return_value = None

    outcome = discord.broadcast_message("general", "Announcement!")

    assert outcome is True
    expected_payload = {"channel_name": "general", "message": "Announcement!"}
    mock_post.assert_called_once_with(
        "http://localhost:8000/send_channel", json=expected_payload, timeout=10
    )
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
def test_broadcast_message_failure(mock_post, mock_api_url):
    """broadcast_message returns False when the API reports ``success`` False."""
    mock_post.return_value = Mock(
        **{
            "json.return_value": {"success": False},
            "raise_for_status.return_value": None,
        }
    )

    assert discord.broadcast_message("general", "Announcement!") is False
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
def test_broadcast_message_exception(mock_post, mock_api_url):
    """broadcast_message returns False when the POST times out."""
    mock_post.side_effect = requests.Timeout("Request timeout")

    assert discord.broadcast_message("general", "Announcement!") is False
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.get")
def test_is_collector_healthy_true(mock_get, mock_api_url):
    """A ``healthy`` status yields True, using one GET to /health with a 5s timeout."""
    response = mock_get.return_value = Mock()
    response.json.return_value = {"status": "healthy"}
    response.raise_for_status.return_value = None

    healthy = discord.is_collector_healthy()

    assert healthy is True
    mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.get")
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
    """An ``unhealthy`` status in the response yields False."""
    mock_get.return_value = Mock(
        **{
            "json.return_value": {"status": "unhealthy"},
            "raise_for_status.return_value": None,
        }
    )

    assert discord.is_collector_healthy() is False
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.get")
def test_is_collector_healthy_exception(mock_get, mock_api_url):
    """A ConnectionError from the GET is reported as unhealthy (False)."""
    mock_get.side_effect = requests.ConnectionError("Connection refused")

    assert discord.is_collector_healthy() is False
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
def test_refresh_discord_metadata_success(mock_post, mock_api_url):
    """refresh_discord_metadata returns the JSON body of a successful POST."""
    expected_counts = {"servers": 5, "channels": 20, "users": 100}
    response = Mock()
    # Hand back a separate copy so the comparison is by value, not identity
    response.json.return_value = dict(expected_counts)
    response.raise_for_status.return_value = None
    mock_post.return_value = response

    refreshed = discord.refresh_discord_metadata()

    assert refreshed == expected_counts
    mock_post.assert_called_once_with(
        "http://localhost:8000/refresh_metadata", timeout=30
    )
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
def test_refresh_discord_metadata_failure(mock_post, mock_api_url):
    """refresh_discord_metadata returns None when the POST raises."""
    mock_post.side_effect = requests.RequestException("Failed to connect")

    assert discord.refresh_discord_metadata() is None
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
def test_refresh_discord_metadata_http_error(mock_post, mock_api_url):
    """refresh_discord_metadata returns None when raise_for_status raises."""
    failing_response = Mock()
    failing_response.raise_for_status.side_effect = requests.HTTPError(
        "500 Server Error"
    )
    mock_post.return_value = failing_response

    refreshed = discord.refresh_discord_metadata()

    assert refreshed is None
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "errors")
def test_send_error_message(mock_broadcast):
    """send_error_message broadcasts to the configured error channel."""
    mock_broadcast.return_value = True

    delivered = discord.send_error_message("Something broke")

    assert delivered is True
    mock_broadcast.assert_called_once_with("errors", "Something broke")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "activity")
def test_send_activity_message(mock_broadcast):
    """send_activity_message broadcasts to the configured activity channel."""
    mock_broadcast.return_value = True

    delivered = discord.send_activity_message("User logged in")

    assert delivered is True
    mock_broadcast.assert_called_once_with("activity", "User logged in")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "discoveries")
def test_send_discovery_message(mock_broadcast):
    """send_discovery_message broadcasts to the configured discovery channel."""
    mock_broadcast.return_value = True

    delivered = discord.send_discovery_message("Found interesting pattern")

    assert delivered is True
    mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "chat")
def test_send_chat_message(mock_broadcast):
    """send_chat_message broadcasts to the configured chat channel."""
    mock_broadcast.return_value = True

    delivered = discord.send_chat_message("Hello from bot")

    assert delivered is True
    mock_broadcast.assert_called_once_with("chat", "Hello from bot")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_basic(mock_send_error):
    """A plain failure sends one message with the task header and error line."""
    discord.notify_task_failure("test_task", "Something went wrong")

    mock_send_error.assert_called_once()
    positional, _ = mock_send_error.call_args
    sent_text = positional[0]

    for expected_fragment in (
        "🚨 **Task Failed: test_task**",
        "**Error:** Something went wrong",
    ):
        assert expected_fragment in sent_text
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_with_args(mock_send_error):
    """Task args and kwargs are rendered into the notification text."""
    call_positional = ("arg1", 42)
    call_keywords = {"key": "value", "number": 123}

    discord.notify_task_failure(
        "test_task",
        "Error occurred",
        task_args=call_positional,
        task_kwargs=call_keywords,
    )

    positional, _ = mock_send_error.call_args
    sent_text = positional[0]

    assert "**Args:** `('arg1', 42)" in sent_text
    assert "**Kwargs:** `{'key': 'value', 'number': 123}" in sent_text
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_with_traceback(mock_send_error):
    """A supplied traceback shows up under a Traceback heading in the message."""
    tb_text = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"

    discord.notify_task_failure("test_task", "Error occurred", traceback_str=tb_text)

    positional, _ = mock_send_error.call_args
    sent_text = positional[0]

    assert "**Traceback:**" in sent_text
    assert "Exception: test" in sent_text
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_error(mock_send_error):
    """A 600-character error is cut down to 500 characters in the notification."""
    oversized_error = "x" * 600

    discord.notify_task_failure("test_task", oversized_error)

    positional, _ = mock_send_error.call_args
    sent_text = positional[0]

    # The first 500 characters must follow the error label directly
    assert f"**Error:** {oversized_error[:500]}" in sent_text
    # ...and nothing beyond them may remain on that line
    after_label = sent_text.split("**Error:** ")[1]
    error_line = after_label.partition("\n")[0]
    assert len(error_line) == 500
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_traceback(mock_send_error):
    """Only the final 800 characters of a 1000-character traceback are sent."""
    oversized_tb = "x" * 1000

    discord.notify_task_failure("test_task", "Error", traceback_str=oversized_tb)

    positional, _ = mock_send_error.call_args
    sent_text = positional[0]

    # The tail of the traceback is what gets kept
    assert oversized_tb[-800:] in sent_text
    # The fenced traceback body must be exactly 800 characters long
    after_fence = sent_text.split("**Traceback:**\n```\n")[1]
    tb_body = after_fence.split("\n```")[0]
    assert len(tb_body) == 800
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_args(mock_send_error):
    """An over-long positional-args tuple is shortened in the notification."""
    oversized_args = ("x" * 300,)

    discord.notify_task_failure("test_task", "Error", task_args=oversized_args)

    positional, _ = mock_send_error.call_args
    sent_text = positional[0]

    # Args are cut to 200 chars; allow a little slack for the surrounding formatting
    args_line = sent_text.split("**Args:**")[1].partition("\n")[0]
    assert len(args_line) <= 210
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
    """An over-long kwargs dict is shortened in the notification."""
    oversized_kwargs = {"key": "x" * 300}

    discord.notify_task_failure("test_task", "Error", task_kwargs=oversized_kwargs)

    positional, _ = mock_send_error.call_args
    sent_text = positional[0]

    # Kwargs are cut to 200 chars; the bound leaves slack for formatting
    kwargs_line = sent_text.split("**Kwargs:**")[1].partition("\n")[0]
    assert len(kwargs_line) <= 210
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_notify_task_failure_disabled(mock_send_error):
    """Nothing is sent while DISCORD_NOTIFICATIONS_ENABLED is False."""
    discord.notify_task_failure("test_task", "Error occurred")

    assert mock_send_error.call_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_send_error_exception(mock_send_error):
    """An exception raised by send_error_message must not escape the notifier."""
    mock_send_error.side_effect = Exception("Failed to send")

    # The call below is expected to complete without raising
    discord.notify_task_failure("test_task", "Error occurred")

    assert mock_send_error.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "function,channel_setting,message",
    [
        (discord.send_error_message, "DISCORD_ERROR_CHANNEL", "Error!"),
        (discord.send_activity_message, "DISCORD_ACTIVITY_CHANNEL", "Activity!"),
        (discord.send_discovery_message, "DISCORD_DISCOVERY_CHANNEL", "Discovery!"),
        (discord.send_chat_message, "DISCORD_CHAT_CHANNEL", "Chat!"),
    ],
)
@patch("memory.common.discord.broadcast_message")
def test_convenience_functions_use_correct_channels(
    mock_broadcast, function, channel_setting, message
):
    """Each send_* helper broadcasts to the channel named by its own setting."""
    setting_path = "memory.common.settings." + channel_setting
    with patch(setting_path, "test-channel"):
        function(message)
        mock_broadcast.assert_called_once_with("test-channel", message)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
def test_send_dm_with_special_characters(mock_post, mock_api_url):
    """Emoji, mentions and channel tags reach the POST payload unchanged."""
    response = mock_post.return_value = Mock()
    response.json.return_value = {"success": True}
    response.raise_for_status.return_value = None

    tricky_text = "Hello! 🎉 <@123> #general"
    outcome = discord.send_dm("user123", tricky_text)

    assert outcome is True
    _, keyword = mock_post.call_args
    assert keyword["json"]["message"] == tricky_text
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
def test_broadcast_message_with_long_message(mock_post, mock_api_url):
    """A 2000-character message is posted in full by broadcast_message."""
    response = mock_post.return_value = Mock()
    response.json.return_value = {"success": True}
    response.raise_for_status.return_value = None

    big_text = "A" * 2000
    outcome = discord.broadcast_message("general", big_text)

    assert outcome is True
    _, keyword = mock_post.call_args
    assert keyword["json"]["message"] == big_text
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.get")
def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
    """A health response without a ``status`` key counts as unhealthy."""
    mock_get.return_value = Mock(
        **{"json.return_value": {}, "raise_for_status.return_value": None}
    )

    assert discord.is_collector_healthy() is False
|
||||||
840
tests/memory/discord/test_collector.py
Normal file
840
tests/memory/discord/test_collector.py
Normal file
@ -0,0 +1,840 @@
|
|||||||
|
import pytest
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import discord
|
||||||
|
|
||||||
|
from memory.discord.collector import (
|
||||||
|
create_or_update_server,
|
||||||
|
determine_channel_metadata,
|
||||||
|
create_or_update_channel,
|
||||||
|
create_or_update_user,
|
||||||
|
determine_message_metadata,
|
||||||
|
should_track_message,
|
||||||
|
should_collect_bot_message,
|
||||||
|
sync_guild_metadata,
|
||||||
|
MessageCollector,
|
||||||
|
)
|
||||||
|
from memory.common.db.models.sources import (
|
||||||
|
DiscordServer,
|
||||||
|
DiscordChannel,
|
||||||
|
DiscordUser,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Fixtures for Discord objects
|
||||||
|
@pytest.fixture
def mock_guild():
    """Provide a ``discord.Guild``-spec'd mock with fixed id, name, description and size."""
    fake_guild = Mock(spec=discord.Guild)
    # configure_mock sets plain attributes, including ``name``
    fake_guild.configure_mock(
        id=123456789,
        name="Test Server",
        description="A test server",
        member_count=42,
    )
    return fake_guild
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_text_channel():
    """Provide a ``discord.TextChannel``-spec'd mock attached to a stub guild."""
    parent_guild = Mock()
    parent_guild.id = 123456789

    text_channel = Mock(spec=discord.TextChannel)
    text_channel.id, text_channel.name = 987654321, "general"
    text_channel.guild = parent_guild
    return text_channel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_dm_channel():
    """Provide a ``discord.DMChannel``-spec'd mock whose recipient is named TestUser."""
    other_party = Mock()
    other_party.name = "TestUser"

    dm_channel = Mock(spec=discord.DMChannel)
    dm_channel.id = 111222333
    dm_channel.recipient = other_party
    return dm_channel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_user():
    """Provide a ``discord.User``-spec'd mock for a non-bot user."""
    fake_user = Mock(spec=discord.User)
    # configure_mock sets plain attributes, including ``name``
    fake_user.configure_mock(
        id=444555666,
        name="testuser",
        display_name="Test User",
        bot=False,
    )
    return fake_user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_message(mock_text_channel, mock_user):
    """Provide a ``discord.Message``-spec'd mock wired to the channel and user fixtures."""
    sent_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)

    fake_message = Mock(spec=discord.Message)
    fake_message.id = 777888999
    fake_message.content = "Test message"
    fake_message.created_at = sent_at
    fake_message.reference = None
    # Link the message to its channel, that channel's guild, and its author
    fake_message.channel = mock_text_channel
    fake_message.guild = mock_text_channel.guild
    fake_message.author = mock_user
    return fake_message
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for create_or_update_server
|
||||||
|
def test_create_or_update_server_creates_new(db_session, mock_guild):
    """A guild with no stored row produces a record mirroring the guild's fields."""
    server_row = create_or_update_server(db_session, mock_guild)

    assert server_row is not None
    assert server_row.id == mock_guild.id
    assert server_row.name == mock_guild.name
    assert server_row.description == mock_guild.description
    assert server_row.member_count == mock_guild.member_count
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_or_update_server_updates_existing(db_session, mock_guild):
    """An already-stored server row is refreshed from the guild's current values."""
    # Seed the database with stale data for the same guild id
    stale_row = DiscordServer(
        id=mock_guild.id,
        name="Old Name",
        description="Old Description",
        member_count=10,
    )
    db_session.add(stale_row)
    db_session.commit()

    # Change what the guild reports before syncing again
    mock_guild.name = "New Name"
    mock_guild.description = "New Description"
    mock_guild.member_count = 50

    refreshed = create_or_update_server(db_session, mock_guild)

    assert refreshed.name == "New Name"
    assert refreshed.description == "New Description"
    assert refreshed.member_count == 50
    assert refreshed.last_sync_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_or_update_server_none_guild(db_session):
    """Passing None as the guild yields None."""
    assert create_or_update_server(db_session, None) is None
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for determine_channel_metadata
|
||||||
|
def test_determine_channel_metadata_dm():
    """A DM channel maps to type ``dm``, no server id, and a name citing the recipient."""
    dm_channel = Mock(spec=discord.DMChannel)
    dm_channel.recipient = Mock()
    dm_channel.recipient.name = "TestUser"

    kind, guild_id, label = determine_channel_metadata(dm_channel)

    assert kind == "dm"
    assert guild_id is None
    assert "DM with TestUser" in label
|
||||||
|
|
||||||
|
|
||||||
|
def test_determine_channel_metadata_dm_no_recipient():
    """A DM channel with no recipient falls back to the name ``Unknown DM``."""
    dm_channel = Mock(spec=discord.DMChannel)
    dm_channel.recipient = None

    kind, _guild_id, label = determine_channel_metadata(dm_channel)

    assert kind == "dm"
    assert label == "Unknown DM"
|
||||||
|
|
||||||
|
|
||||||
|
def test_determine_channel_metadata_group_dm():
    """A named group channel maps to type ``group_dm`` with its own name and no server."""
    group_channel = Mock(spec=discord.GroupChannel)
    group_channel.name = "Group Chat"

    kind, guild_id, label = determine_channel_metadata(group_channel)

    assert kind == "group_dm"
    assert guild_id is None
    assert label == "Group Chat"
|
||||||
|
|
||||||
|
|
||||||
|
def test_determine_channel_metadata_group_dm_no_name():
    """A group channel whose name is None falls back to ``Group DM``."""
    group_channel = Mock(spec=discord.GroupChannel)
    group_channel.name = None

    _kind, _guild_id, label = determine_channel_metadata(group_channel)

    assert label == "Group DM"
|
||||||
|
|
||||||
|
|
||||||
|
def test_determine_channel_metadata_text_channel():
    """A text channel maps to type ``text`` with its guild id and channel name."""
    text_channel = Mock(spec=discord.TextChannel)
    text_channel.name = "general"
    text_channel.guild = Mock()
    text_channel.guild.id = 123

    metadata = determine_channel_metadata(text_channel)
    kind, guild_id, label = metadata

    assert kind == "text"
    assert guild_id == 123
    assert label == "general"
|
||||||
|
|
||||||
|
|
||||||
|
def test_determine_channel_metadata_voice_channel():
    """A guild voice channel reports type "voice", the guild id, and its name."""
    guild = Mock()
    guild.id = 456

    voice_channel = Mock(spec=discord.VoiceChannel)
    voice_channel.name = "voice-chat"
    voice_channel.guild = guild

    kind, guild_id, label = determine_channel_metadata(voice_channel)

    assert kind == "voice"
    assert guild_id == 456
    assert label == "voice-chat"
|
||||||
|
|
||||||
|
|
||||||
|
def test_determine_channel_metadata_thread():
    """A thread reports type "thread", the owning guild id, and its name."""
    guild = Mock()
    guild.id = 789

    thread = Mock(spec=discord.Thread)
    thread.name = "thread-1"
    thread.guild = guild

    kind, guild_id, label = determine_channel_metadata(thread)

    assert kind == "thread"
    assert guild_id == 789
    assert label == "thread-1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_determine_channel_metadata_unknown():
    """An unrecognised channel object is typed "unknown" and named "Unknown-<id>"."""
    mystery_channel = Mock()
    mystery_channel.id = 999
    # Deleting the attribute makes the Mock raise AttributeError on `.name`
    # instead of auto-creating a child mock for it.
    del mystery_channel.name

    kind, _guild_id, label = determine_channel_metadata(mystery_channel)

    assert kind == "unknown"
    assert label == "Unknown-999"
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for create_or_update_channel
|
||||||
|
def test_create_or_update_channel_creates_new(
    db_session, mock_text_channel, mock_guild
):
    """A channel without an existing row is created with its id, name and type."""
    # The server row must exist first so the channel's foreign key is satisfied.
    create_or_update_server(db_session, mock_guild)

    stored = create_or_update_channel(db_session, mock_text_channel)

    assert stored is not None
    assert stored.id == mock_text_channel.id
    assert stored.name == mock_text_channel.name
    assert stored.channel_type == "text"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_or_update_channel_updates_existing(db_session, mock_text_channel):
    """An already-stored channel picks up the new name from the Discord object."""
    # Seed a row carrying the stale name.
    stale_row = DiscordChannel(
        id=mock_text_channel.id, name="old-name", channel_type="text"
    )
    db_session.add(stale_row)
    db_session.commit()

    # Rename on the Discord side, then sync.
    mock_text_channel.name = "new-name"
    refreshed = create_or_update_channel(db_session, mock_text_channel)

    assert refreshed.name == "new-name"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_or_update_channel_none_channel(db_session):
    """Passing None as the channel returns None."""
    outcome = create_or_update_channel(db_session, None)
    assert outcome is None
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for create_or_update_user
|
||||||
|
def test_create_or_update_user_creates_new(db_session, mock_user):
    """A user without an existing row is created with id, username and display name."""
    stored = create_or_update_user(db_session, mock_user)

    assert stored is not None
    assert stored.id == mock_user.id
    assert stored.username == mock_user.name
    assert stored.display_name == mock_user.display_name
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_or_update_user_updates_existing(db_session, mock_user):
    """An already-stored user picks up the new username and display name."""
    # Seed a row carrying the stale values.
    stale_row = DiscordUser(
        id=mock_user.id, username="oldname", display_name="Old Display Name"
    )
    db_session.add(stale_row)
    db_session.commit()

    # Change both fields on the Discord side, then sync.
    mock_user.name = "newname"
    mock_user.display_name = "New Display Name"
    refreshed = create_or_update_user(db_session, mock_user)

    assert refreshed.username == "newname"
    assert refreshed.display_name == "New Display Name"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_or_update_user_none_user(db_session):
    """Passing None as the user returns None."""
    outcome = create_or_update_user(db_session, None)
    assert outcome is None
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for determine_message_metadata
|
||||||
|
def test_determine_message_metadata_default():
    """A message with no reference and no parent on its channel is "default"."""
    plain_channel = Mock()
    # Remove `parent` so the auto-created Mock attribute does not exist.
    del plain_channel.parent

    message = Mock()
    message.reference = None
    message.channel = plain_channel

    kind, replied_to, thread = determine_message_metadata(message)

    assert kind == "default"
    assert replied_to is None
    assert thread is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_determine_message_metadata_reply():
    """A message carrying a reference is a "reply" to the referenced message id."""
    reference = Mock()
    reference.message_id = 123456

    message = Mock()
    message.reference = reference
    message.channel = Mock()

    kind, replied_to, _thread = determine_message_metadata(message)

    assert kind == "reply"
    assert replied_to == 123456
|
||||||
|
|
||||||
|
|
||||||
|
def test_determine_message_metadata_thread():
    """A message whose channel has a parent reports that channel's id as thread id."""
    thread_channel = Mock()
    thread_channel.id = 999
    thread_channel.parent = Mock()  # Having a parent marks the channel as a thread

    message = Mock()
    message.reference = None
    message.channel = thread_channel

    _kind, _replied_to, thread = determine_message_metadata(message)

    assert thread == 999
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for should_track_message
|
||||||
|
def test_should_track_message_server_disabled(db_session):
    """Tracking is refused when the server has track_messages=False."""
    muted_server = DiscordServer(id=1, name="Server", track_messages=False)
    text_channel = DiscordChannel(id=2, name="Channel", channel_type="text")
    author = DiscordUser(id=3, username="User")

    decision = should_track_message(muted_server, text_channel, author)

    assert decision is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_track_message_channel_disabled(db_session):
    """Tracking is refused when the channel has track_messages=False."""
    tracked_server = DiscordServer(id=1, name="Server", track_messages=True)
    muted_channel = DiscordChannel(
        id=2, name="Channel", channel_type="text", track_messages=False
    )
    author = DiscordUser(id=3, username="User")

    decision = should_track_message(tracked_server, muted_channel, author)

    assert decision is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_track_message_dm_allowed(db_session):
    """A DM (no server) is tracked when both channel and user allow tracking."""
    dm_channel = DiscordChannel(
        id=2, name="DM", channel_type="dm", track_messages=True
    )
    consenting_user = DiscordUser(id=3, username="User", track_messages=True)

    decision = should_track_message(None, dm_channel, consenting_user)

    assert decision is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_track_message_dm_not_allowed(db_session):
    """A DM (no server) is not tracked when the user has track_messages=False."""
    dm_channel = DiscordChannel(
        id=2, name="DM", channel_type="dm", track_messages=True
    )
    opted_out_user = DiscordUser(id=3, username="User", track_messages=False)

    decision = should_track_message(None, dm_channel, opted_out_user)

    assert decision is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_track_message_default_true(db_session):
    """With server and channel both tracking, a message is tracked."""
    tracked_server = DiscordServer(id=1, name="Server", track_messages=True)
    tracked_channel = DiscordChannel(
        id=2, name="Channel", channel_type="text", track_messages=True
    )
    author = DiscordUser(id=3, username="User")

    decision = should_track_message(tracked_server, tracked_channel, author)

    assert decision is True
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for should_collect_bot_message
|
||||||
|
@patch("memory.common.settings.DISCORD_COLLECT_BOTS", False)
def test_should_collect_bot_message_bot_not_allowed():
    """With DISCORD_COLLECT_BOTS off, a bot-authored message is not collected."""
    bot_author = Mock()
    bot_author.bot = True
    message = Mock()
    message.author = bot_author

    decision = should_collect_bot_message(message)

    assert decision is False
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_COLLECT_BOTS", True)
def test_should_collect_bot_message_bot_allowed():
    """With DISCORD_COLLECT_BOTS on, a bot-authored message is collected."""
    bot_author = Mock()
    bot_author.bot = True
    message = Mock()
    message.author = bot_author

    decision = should_collect_bot_message(message)

    assert decision is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_collect_bot_message_human():
    """A message from a non-bot author is collected."""
    human_author = Mock()
    human_author.bot = False
    message = Mock()
    message.author = human_author

    decision = should_collect_bot_message(message)

    assert decision is True
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for sync_guild_metadata
|
||||||
|
@patch("memory.discord.collector.make_session")
def test_sync_guild_metadata(mock_make_session, mock_guild):
    """Syncing a guild with two channels commits the session exactly once."""
    # Make `with make_session() as session:` hand back our stub session.
    mock_session = Mock()
    session_ctx = mock_make_session.return_value
    session_ctx.__enter__ = Mock(return_value=mock_session)
    session_ctx.__exit__ = Mock(return_value=None)

    # session.query(...).get(...) returns None, i.e. nothing is stored yet.
    mock_session.query.return_value.get.return_value = None

    # One text and one voice channel, both owned by the guild.
    general = Mock(spec=discord.TextChannel)
    general.id = 1
    general.name = "general"
    general.guild = mock_guild

    voice = Mock(spec=discord.VoiceChannel)
    voice.id = 2
    voice.name = "voice"
    voice.guild = mock_guild

    mock_guild.channels = [general, voice]

    sync_guild_metadata(mock_guild)

    mock_session.commit.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for MessageCollector class
|
||||||
|
def test_message_collector_init():
    """A fresh MessageCollector has the expected prefix, no help command, and intents."""
    bot = MessageCollector()

    assert bot.command_prefix == "!memory_"
    assert bot.help_command is None

    # Intents the collector is expected to enable.
    assert bot.intents.message_content is True
    assert bot.intents.guilds is True
    assert bot.intents.members is True
    assert bot.intents.dm_messages is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_on_ready():
    """on_ready triggers exactly one server/channel sync."""
    bot_user = Mock()
    bot_user.name = "TestBot"

    collector = MessageCollector()
    collector.user = bot_user
    collector.guilds = [Mock(), Mock()]
    collector.sync_servers_and_channels = AsyncMock()

    await collector.on_ready()

    collector.sync_servers_and_channels.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.discord.collector.make_session")
@patch("memory.discord.collector.add_discord_message")
async def test_on_message_success(mock_add_task, mock_make_session, mock_message):
    """A handled message queues one add_discord_message task with its key fields."""
    # Make `with make_session() as session:` hand back our stub session.
    mock_session = Mock()
    session_ctx = mock_make_session.return_value
    session_ctx.__enter__ = Mock(return_value=mock_session)
    session_ctx.__exit__ = Mock(return_value=None)
    mock_session.query.return_value.get.return_value = None  # New entities

    collector = MessageCollector()
    await collector.on_message(mock_message)

    # Exactly one task was queued; inspect its keyword arguments.
    mock_add_task.delay.assert_called_once()
    queued = mock_add_task.delay.call_args[1]
    assert queued["message_id"] == mock_message.id
    assert queued["channel_id"] == mock_message.channel.id
    assert queued["author_id"] == mock_message.author.id
    assert queued["content"] == mock_message.content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.discord.collector.make_session")
async def test_on_message_bot_message_filtered(mock_make_session, mock_message):
    """When should_collect_bot_message says no, no DB session is opened."""
    mock_message.author.bot = True

    reject_bots = patch(
        "memory.discord.collector.should_collect_bot_message", return_value=False
    )
    with reject_bots:
        collector = MessageCollector()
        await collector.on_message(mock_message)

        # The handler must bail out before touching the database.
        mock_make_session.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.discord.collector.make_session")
async def test_on_message_error_handling(mock_make_session, mock_message):
    """A failure while opening the session does not propagate out of on_message."""
    mock_make_session.side_effect = Exception("Database error")

    collector = MessageCollector()

    # The test passes as long as this await does not raise.
    await collector.on_message(mock_message)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.discord.collector.edit_discord_message")
async def test_on_message_edit(mock_edit_task):
    """An edit queues one edit_discord_message task with the new id and content."""
    original = Mock()
    edited = Mock()
    edited.id = 123
    edited.content = "Edited content"
    edited.edited_at = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc)

    collector = MessageCollector()
    await collector.on_message_edit(original, edited)

    mock_edit_task.delay.assert_called_once()
    queued = mock_edit_task.delay.call_args[1]
    assert queued["message_id"] == 123
    assert queued["content"] == "Edited content"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_on_message_edit_error_handling():
    """A failure while queueing the edit task does not propagate out of the handler."""
    original = Mock()
    edited = Mock()
    edited.id = 123
    edited.content = "Edited"
    # Per the original author's note this should make the handler fall back
    # to datetime.now — TODO confirm against the collector implementation.
    edited.edited_at = None

    with patch("memory.discord.collector.edit_discord_message") as mock_edit:
        mock_edit.delay.side_effect = Exception("Task error")

        collector = MessageCollector()

        # The test passes as long as this await does not raise.
        await collector.on_message_edit(original, edited)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_sync_servers_and_channels():
    """sync_servers_and_channels calls sync_guild_metadata once per guild."""
    first_guild, second_guild = Mock(), Mock()

    collector = MessageCollector()
    collector.guilds = [first_guild, second_guild]

    with patch("memory.discord.collector.sync_guild_metadata") as mock_sync:
        await collector.sync_servers_and_channels()

        assert mock_sync.call_count == 2
        mock_sync.assert_any_call(first_guild)
        mock_sync.assert_any_call(second_guild)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.discord.collector.make_session")
async def test_refresh_metadata(mock_make_session):
    """Refreshing one empty guild reports 1 server, 0 channels and 0 users updated."""
    # Make `with make_session() as session:` hand back our stub session.
    mock_session = Mock()
    session_ctx = mock_make_session.return_value
    session_ctx.__enter__ = Mock(return_value=mock_session)
    session_ctx.__exit__ = Mock(return_value=None)
    mock_session.query.return_value.get.return_value = None

    # A guild with no channels and no members.
    empty_guild = Mock()
    empty_guild.id = 123
    empty_guild.name = "Test"
    empty_guild.channels = []
    empty_guild.members = []

    collector = MessageCollector()
    collector.guilds = [empty_guild]
    collector.intents = Mock()
    collector.intents.members = False

    summary = await collector.refresh_metadata()

    assert summary["servers_updated"] == 1
    assert summary["channels_updated"] == 0
    assert summary["users_updated"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_user_by_id():
    """Test getting user by ID.

    ``collector.get_user`` is awaited below, so its stub must be an
    ``AsyncMock``. A plain ``Mock`` hands back a non-awaitable object and the
    ``await`` would raise ``TypeError`` before the assertion is reached.
    """
    user = Mock()
    user.id = 123

    collector = MessageCollector()
    # NOTE(review): this replaces the very method being called, so the test
    # only exercises the stub. Consider stubbing the underlying lookup
    # instead once the collector's implementation is confirmed.
    collector.get_user = AsyncMock(return_value=user)

    result = await collector.get_user(123)

    assert result == user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_user_by_username():
    """Looking a user up by username returns the matching guild member."""
    guild_member = Mock()
    guild_member.name = "testuser"
    guild_member.display_name = "Test User"
    guild_member.discriminator = "1234"

    guild = Mock()
    guild.members = [guild_member]

    collector = MessageCollector()
    collector.guilds = [guild]

    found = await collector.get_user("testuser")

    assert found == guild_member
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_user_not_found():
    """Looking up an unknown user id yields None."""
    collector = MessageCollector()
    collector.guilds = []

    # NOTE(review): `get_user` itself is patched here, so the call below hits
    # the patch rather than the collector's own lookup — confirm this is the
    # intended target.
    with patch.object(collector, "get_user", return_value=None), patch.object(
        collector, "fetch_user", side_effect=discord.NotFound(Mock(), Mock())
    ):
        found = await collector.get_user(999)
        assert found is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_channel_by_name():
    """Looking a channel up by name returns the matching guild channel."""
    general = Mock(spec=discord.TextChannel)
    general.name = "general"

    guild = Mock()
    guild.channels = [general]

    collector = MessageCollector()
    collector.guilds = [guild]

    found = await collector.get_channel_by_name("general")

    assert found == general
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_channel_by_name_not_found():
    """Looking up a channel name no guild has yields None."""
    channelless_guild = Mock()
    channelless_guild.channels = []

    collector = MessageCollector()
    collector.guilds = [channelless_guild]

    found = await collector.get_channel_by_name("nonexistent")

    assert found is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_create_channel():
    """create_channel asks the resolved guild for a text channel and returns it."""
    created = Mock()
    guild = Mock()
    guild.name = "Test Server"
    guild.create_text_channel = AsyncMock(return_value=created)

    collector = MessageCollector()
    collector.get_guild = Mock(return_value=guild)

    outcome = await collector.create_channel("new-channel", guild_id=123)

    assert outcome == created
    guild.create_text_channel.assert_called_once_with("new-channel")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_create_channel_no_guild():
    """With no guild resolvable and none joined, create_channel returns None."""
    collector = MessageCollector()
    collector.get_guild = Mock(return_value=None)
    collector.guilds = []

    outcome = await collector.create_channel("new-channel")

    assert outcome is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_dm_success():
    """send_dm returns True and forwards the text to the resolved user's send()."""
    recipient = Mock()
    recipient.send = AsyncMock()

    collector = MessageCollector()
    collector.get_user = AsyncMock(return_value=recipient)

    delivered = await collector.send_dm(123, "Hello!")

    assert delivered is True
    recipient.send.assert_called_once_with("Hello!")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_dm_user_not_found():
    """send_dm returns False when the user lookup yields None."""
    collector = MessageCollector()
    collector.get_user = AsyncMock(return_value=None)

    delivered = await collector.send_dm(123, "Hello!")

    assert delivered is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_dm_exception():
    """send_dm returns False when the user's send() raises."""
    recipient = Mock()
    recipient.send = AsyncMock(side_effect=Exception("Send failed"))

    collector = MessageCollector()
    collector.get_user = AsyncMock(return_value=recipient)

    delivered = await collector.send_dm(123, "Hello!")

    assert delivered is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
async def test_send_to_channel_success():
    """With notifications on, send_to_channel posts to the channel and returns True."""
    target = Mock()
    target.send = AsyncMock()

    collector = MessageCollector()
    collector.get_channel_by_name = AsyncMock(return_value=target)

    delivered = await collector.send_to_channel("general", "Announcement!")

    assert delivered is True
    target.send.assert_called_once_with("Announcement!")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
async def test_send_to_channel_notifications_disabled():
    """With notifications off, send_to_channel returns False."""
    collector = MessageCollector()

    delivered = await collector.send_to_channel("general", "Announcement!")

    assert delivered is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
async def test_send_to_channel_not_found():
    """send_to_channel returns False when the channel lookup yields None."""
    collector = MessageCollector()
    collector.get_channel_by_name = AsyncMock(return_value=None)

    delivered = await collector.send_to_channel("nonexistent", "Message")

    assert delivered is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
async def test_run_collector():
    """run_collector starts a MessageCollector with the configured bot token."""
    from memory.discord.collector import run_collector

    with patch("memory.discord.collector.MessageCollector") as mock_collector_class:
        # The class is replaced, so run_collector receives this stub instance.
        stub_collector = Mock()
        stub_collector.start = AsyncMock()
        mock_collector_class.return_value = stub_collector

        await run_collector()

        stub_collector.start.assert_called_once_with("test_token")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
async def test_run_collector_no_token():
    """With no bot token configured, run_collector completes without raising."""
    from memory.discord.collector import run_collector

    # The test passes as long as this await returns normally.
    await run_collector()
|
||||||
607
tests/memory/workers/tasks/test_discord_tasks.py
Normal file
607
tests/memory/workers/tasks/test_discord_tasks.py
Normal file
@ -0,0 +1,607 @@
|
|||||||
|
import pytest
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from memory.common.db.models import (
|
||||||
|
DiscordMessage,
|
||||||
|
DiscordUser,
|
||||||
|
DiscordServer,
|
||||||
|
DiscordChannel,
|
||||||
|
)
|
||||||
|
from memory.workers.tasks import discord
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_discord_user(db_session):
    """Persist and return a DiscordUser row ("testuser") that is not ignored."""
    user_row = DiscordUser(id=123456789, username="testuser", ignore_messages=False)
    db_session.add(user_row)
    db_session.commit()
    return user_row
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_discord_server(db_session):
    """Persist and return a DiscordServer row ("Test Server") that is not ignored."""
    server_row = DiscordServer(id=987654321, name="Test Server", ignore_messages=False)
    db_session.add(server_row)
    db_session.commit()
    return server_row
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_discord_channel(db_session, mock_discord_server):
    """Persist and return a text DiscordChannel row belonging to the test server."""
    channel_row = DiscordChannel(
        id=111222333,
        name="test-channel",
        channel_type="text",
        server_id=mock_discord_server.id,
        ignore_messages=False,
    )
    db_session.add(channel_row)
    db_session.commit()
    return channel_row
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sample_message_data(mock_discord_user, mock_discord_channel):
    """Return a message payload dict that references the fixture user and channel."""
    payload = {
        "message_id": 999888777,
        "channel_id": mock_discord_channel.id,
        "author_id": mock_discord_user.id,
        "content": "This is a test Discord message with enough content to be processed.",
        "sent_at": "2024-01-01T12:00:00Z",
        "server_id": None,
        "message_reference_id": None,
    }
    return payload
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_prev_returns_previous_messages(
    db_session, mock_discord_user, mock_discord_channel
):
    """get_prev returns earlier messages as "username: content", oldest first."""
    # (message_id, content, minute past 10:00 UTC, sha256 bytes)
    seed_rows = (
        (1, "First message", 0, b"hash1" + bytes(26)),
        (2, "Second message", 5, b"hash2" + bytes(26)),
        (3, "Third message", 10, b"hash3" + bytes(26)),
    )
    earlier_messages = [
        DiscordMessage(
            message_id=message_id,
            channel_id=mock_discord_channel.id,
            discord_user_id=mock_discord_user.id,
            content=content,
            sent_at=datetime(2024, 1, 1, 10, minute, 0, tzinfo=timezone.utc),
            modality="text",
            sha256=digest,
        )
        for message_id, content, minute, digest in seed_rows
    ]
    db_session.add_all(earlier_messages)
    db_session.commit()

    # Ask for everything sent before 10:15.
    cutoff = datetime(2024, 1, 1, 10, 15, 0, tzinfo=timezone.utc)
    history = discord.get_prev(db_session, mock_discord_channel.id, cutoff)

    assert len(history) == 3
    assert history[0] == "testuser: First message"
    assert history[1] == "testuser: Second message"
    assert history[2] == "testuser: Third message"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_prev_limits_context_window(
    db_session, mock_discord_user, mock_discord_channel
):
    """get_prev returns only the newest messages allowed by the context window."""
    # Seed 15 messages, one per minute — more than the default window of 10.
    for minute in range(15):
        db_session.add(
            DiscordMessage(
                message_id=minute,
                channel_id=mock_discord_channel.id,
                discord_user_id=mock_discord_user.id,
                content=f"Message {minute}",
                sent_at=datetime(2024, 1, 1, 10, minute, 0, tzinfo=timezone.utc),
                modality="text",
                sha256=f"hash{minute}".encode() + bytes(27),
            )
        )
    db_session.commit()

    cutoff = datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc)
    history = discord.get_prev(db_session, mock_discord_channel.id, cutoff)

    # Only the last 10 survive, still in chronological order.
    assert len(history) == 10
    assert history[0] == "testuser: Message 5"  # Oldest in window
    assert history[-1] == "testuser: Message 14"  # Most recent
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_prev_empty_channel(db_session, mock_discord_channel):
    """get_prev returns an empty list for a channel with no stored messages."""
    cutoff = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)

    history = discord.get_prev(db_session, mock_discord_channel.id, cutoff)

    assert history == []
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_should_process_normal_message(
    db_session, mock_discord_user, mock_discord_server, mock_discord_channel
):
    """With both settings on and nothing ignored, should_process returns True."""
    stored_message = DiscordMessage(
        message_id=1,
        channel_id=mock_discord_channel.id,
        discord_user_id=mock_discord_user.id,
        server_id=mock_discord_server.id,
        content="Test",
        sent_at=datetime.now(timezone.utc),
        modality="text",
        sha256=b"hash" + bytes(27),
    )
    db_session.add(stored_message)
    db_session.commit()
    # Reload the row from the database after the commit.
    db_session.refresh(stored_message)

    assert discord.should_process(stored_message) is True
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
def test_should_process_disabled():
    """With DISCORD_PROCESS_MESSAGES off, should_process returns False."""
    any_message = Mock()
    assert discord.should_process(any_message) is False
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_should_process_notifications_disabled():
    """With DISCORD_NOTIFICATIONS_ENABLED off, should_process returns False."""
    any_message = Mock()
    assert discord.should_process(any_message) is False
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_should_process_server_ignored(
    db_session, mock_discord_user, mock_discord_channel
):
    """should_process rejects a message whose server sets ignore_messages=True."""
    ignored_server = DiscordServer(id=123, name="Ignored Server", ignore_messages=True)
    db_session.add(ignored_server)
    db_session.commit()

    placeholder_digest = b"hash" + bytes(27)
    fields = dict(
        message_id=1,
        channel_id=mock_discord_channel.id,
        discord_user_id=mock_discord_user.id,
        server_id=ignored_server.id,
        content="Test",
        sent_at=datetime.now(timezone.utc),
        modality="text",
        sha256=placeholder_digest,
    )
    stored = DiscordMessage(**fields)
    db_session.add(stored)
    db_session.commit()
    db_session.refresh(stored)

    verdict = discord.should_process(stored)
    assert verdict is False
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_should_process_channel_ignored(
    db_session, mock_discord_user, mock_discord_server
):
    """should_process rejects a message whose channel sets ignore_messages=True."""
    ignored_channel = DiscordChannel(
        id=456,
        name="ignored-channel",
        channel_type="text",
        server_id=mock_discord_server.id,
        ignore_messages=True,
    )
    db_session.add(ignored_channel)
    db_session.commit()

    placeholder_digest = b"hash" + bytes(27)
    fields = dict(
        message_id=1,
        channel_id=ignored_channel.id,
        discord_user_id=mock_discord_user.id,
        server_id=mock_discord_server.id,
        content="Test",
        sent_at=datetime.now(timezone.utc),
        modality="text",
        sha256=placeholder_digest,
    )
    stored = DiscordMessage(**fields)
    db_session.add(stored)
    db_session.commit()
    db_session.refresh(stored)

    verdict = discord.should_process(stored)
    assert verdict is False
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_should_process_user_ignored(
    db_session, mock_discord_server, mock_discord_channel
):
    """should_process rejects a message whose author sets ignore_messages=True."""
    ignored_user = DiscordUser(id=789, username="ignoreduser", ignore_messages=True)
    db_session.add(ignored_user)
    db_session.commit()

    placeholder_digest = b"hash" + bytes(27)
    fields = dict(
        message_id=1,
        channel_id=mock_discord_channel.id,
        discord_user_id=ignored_user.id,
        server_id=mock_discord_server.id,
        content="Test",
        sent_at=datetime.now(timezone.utc),
        modality="text",
        sha256=placeholder_digest,
    )
    stored = DiscordMessage(**fields)
    db_session.add(stored)
    db_session.commit()
    db_session.refresh(stored)

    verdict = discord.should_process(stored)
    assert verdict is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_discord_message_success(db_session, sample_message_data, qdrant):
    """add_discord_message reports "processed" and stores a default-type row."""
    outcome = discord.add_discord_message(**sample_message_data)

    assert outcome["status"] == "processed"
    assert "discordmessage_id" in outcome

    # Look the row up by its Discord message id to confirm it was persisted.
    lookup = db_session.query(DiscordMessage).filter_by(
        message_id=sample_message_data["message_id"]
    )
    stored = lookup.first()

    assert stored is not None
    assert stored.content == sample_message_data["content"]
    assert stored.message_type == "default"
    assert stored.reply_to_message_id is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant):
    """Test adding a Discord message that is a reply.

    Setting ``message_reference_id`` on the input is expected to produce a
    stored row whose ``message_type`` is "reply" and whose
    ``reply_to_message_id`` equals the referenced id.
    """
    sample_message_data["message_reference_id"] = 111222333

    discord.add_discord_message(**sample_message_data)

    message = (
        db_session.query(DiscordMessage)
        .filter_by(message_id=sample_message_data["message_id"])
        .first()
    )
    # Guard first: .first() returns None when no row matched, and a plain
    # assertion failure is clearer than an AttributeError on the lines below.
    assert message is not None
    assert message.message_type == "reply"
    assert message.reply_to_message_id == 111222333
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_discord_message_already_exists(db_session, sample_message_data, qdrant):
    """A second add of the same message reports already_exists and adds no row."""
    # First call stores the message; the second is the one under test.
    discord.add_discord_message(**sample_message_data)
    second_outcome = discord.add_discord_message(**sample_message_data)

    assert second_outcome["status"] == "already_exists"
    assert second_outcome["message_id"] == sample_message_data["message_id"]

    # Exactly one row may exist for this Discord message id.
    lookup = db_session.query(DiscordMessage).filter_by(
        message_id=sample_message_data["message_id"]
    )
    rows = lookup.all()
    assert len(rows) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_discord_message_with_context(
    db_session, sample_message_data, mock_discord_user, qdrant
):
    """A message is stored and processed when an earlier channel message exists."""
    # Seed the channel with one earlier message.
    earlier_fields = dict(
        message_id=111111,
        channel_id=sample_message_data["channel_id"],
        discord_user_id=mock_discord_user.id,
        content="Previous message",
        sent_at=datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
        modality="text",
        sha256=b"prev" + bytes(28),
    )
    earlier = DiscordMessage(**earlier_fields)
    db_session.add(earlier)
    db_session.commit()

    outcome = discord.add_discord_message(**sample_message_data)

    lookup = db_session.query(DiscordMessage).filter_by(
        message_id=sample_message_data["message_id"]
    )
    stored = lookup.first()
    assert stored is not None
    assert outcome["status"] == "processed"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.discord.process_discord_message")
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_add_discord_message_triggers_processing(
    mock_process,
    db_session,
    sample_message_data,
    mock_discord_server,
    mock_discord_channel,
    qdrant,
):
    """add_discord_message calls process_discord_message.delay once when both settings are on."""
    delay_spy = Mock()
    mock_process.delay = delay_spy
    sample_message_data["server_id"] = mock_discord_server.id

    discord.add_discord_message(**sample_message_data)

    # The patched task's delay must have been invoked exactly once.
    delay_spy.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.discord.process_discord_message")
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
def test_add_discord_message_no_processing_when_disabled(
    mock_process, db_session, sample_message_data, qdrant
):
    """process_discord_message.delay is never called while DISCORD_PROCESS_MESSAGES is False."""
    delay_spy = Mock()
    mock_process.delay = delay_spy

    discord.add_discord_message(**sample_message_data)

    delay_spy.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_edit_discord_message_success(db_session, sample_message_data, qdrant):
    """Test successful Discord message edit.

    Adds a message, edits it through ``edit_discord_message`` and checks that
    the call reports "processed" and that the stored row carries the new
    content and a non-null ``edited_at``.
    """
    # First add the message
    discord.add_discord_message(**sample_message_data)

    # Edit it
    new_content = (
        "This is the edited content with enough text to be meaningful and processed."
    )
    edited_at = "2024-01-01T13:00:00Z"

    result = discord.edit_discord_message(
        sample_message_data["message_id"],
        new_content,
        edited_at,
    )

    assert result["status"] == "processed"

    # Verify the message was updated
    message = (
        db_session.query(DiscordMessage)
        .filter_by(message_id=sample_message_data["message_id"])
        .first()
    )
    # Guard first: .first() returns None when no row matched, and a plain
    # assertion failure is clearer than an AttributeError on the lines below.
    assert message is not None
    assert message.content == new_content
    assert message.edited_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_edit_discord_message_not_found(db_session):
    """Editing an unknown message id returns an error payload naming that id."""
    missing_id = 999999

    outcome = discord.edit_discord_message(
        missing_id, "New content", "2024-01-01T13:00:00Z"
    )

    assert outcome["status"] == "error"
    assert outcome["error"] == "Message not found"
    assert outcome["message_id"] == 999999
|
||||||
|
|
||||||
|
|
||||||
|
def test_edit_discord_message_updates_context(
    db_session, sample_message_data, mock_discord_user, qdrant
):
    """Test that editing a message works when an earlier channel message exists.

    Seeds the channel with one previous message, adds the sample message,
    edits it, and checks both the reported status and the stored content.
    """
    # Add previous message and the message to be edited
    prev_msg = DiscordMessage(
        message_id=111111,
        channel_id=sample_message_data["channel_id"],
        discord_user_id=mock_discord_user.id,
        content="Previous message",
        sent_at=datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
        modality="text",
        sha256=b"prev" + bytes(28),
    )
    db_session.add(prev_msg)
    db_session.commit()

    discord.add_discord_message(**sample_message_data)

    # Edit the message
    edited_content = "Edited content that should have context updated properly."
    result = discord.edit_discord_message(
        sample_message_data["message_id"],
        edited_content,
        "2024-01-01T13:00:00Z",
    )
    # Check the status before inspecting the row so a failed edit is reported
    # as such rather than as a content mismatch.
    assert result["status"] == "processed"

    # Verify message was updated
    message = (
        db_session.query(DiscordMessage)
        .filter_by(message_id=sample_message_data["message_id"])
        .first()
    )
    # Guard first: .first() returns None when no row matched, and a plain
    # assertion failure is clearer than an AttributeError on the next line.
    assert message is not None
    assert message.content == edited_content
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_discord_message_success(db_session, sample_message_data, qdrant):
    """process_discord_message reports "processed" for a message added beforehand."""
    # The add call returns the id that process_discord_message is then given.
    stored_id = discord.add_discord_message(**sample_message_data)["discordmessage_id"]

    outcome = discord.process_discord_message(stored_id)

    assert outcome["status"] == "processed"
    assert outcome["message_id"] == stored_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_discord_message_not_found(db_session):
    """Processing an unknown id returns an error payload naming that id."""
    missing_id = 999999

    outcome = discord.process_discord_message(missing_id)

    assert outcome["status"] == "error"
    assert outcome["error"] == "Message not found"
    assert outcome["message_id"] == 999999
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "sent_at_str,expected_hour",
    [
        ("2024-01-01T12:00:00Z", 12),
        ("2024-01-01T00:00:00+00:00", 0),
        ("2024-01-01T23:59:59Z", 23),
    ],
)
def test_add_discord_message_datetime_parsing(
    db_session, sample_message_data, sent_at_str, expected_hour, qdrant
):
    """Test that various datetime formats are parsed correctly.

    Each case passes an ISO-8601 string as ``sent_at`` and checks the hour of
    the ``sent_at`` value on the stored row.
    """
    sample_message_data["sent_at"] = sent_at_str

    discord.add_discord_message(**sample_message_data)

    message = (
        db_session.query(DiscordMessage)
        .filter_by(message_id=sample_message_data["message_id"])
        .first()
    )
    # Guard first: .first() returns None when no row matched, and a plain
    # assertion failure is clearer than an AttributeError on the next line.
    assert message is not None
    assert message.sent_at.hour == expected_hour
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_discord_message_unique_hash(db_session, sample_message_data, qdrant):
    """Two messages with identical content but different ids are both stored."""
    # Store the original message.
    discord.add_discord_message(**sample_message_data)

    # Same payload, new Discord message id.
    sample_message_data["message_id"] = 888777666
    second_outcome = discord.add_discord_message(**sample_message_data)

    # The second add is expected to be processed rather than treated as a duplicate.
    assert second_outcome["status"] == "processed"

    # Both rows share the content, so filtering by content must find two.
    lookup = db_session.query(DiscordMessage).filter_by(
        content=sample_message_data["content"]
    )
    rows = lookup.all()
    assert len(rows) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_prev_only_returns_messages_from_same_channel(
    db_session, mock_discord_user, mock_discord_server
):
    """get_prev returns only messages from the channel it was asked about."""
    # Two text channels on the same server.
    channel1 = DiscordChannel(
        id=111, name="channel-1", channel_type="text", server_id=mock_discord_server.id
    )
    channel2 = DiscordChannel(
        id=222, name="channel-2", channel_type="text", server_id=mock_discord_server.id
    )
    db_session.add_all([channel1, channel2])
    db_session.commit()

    # One message per channel, five minutes apart.
    first_fields = dict(
        message_id=1,
        channel_id=channel1.id,
        discord_user_id=mock_discord_user.id,
        content="Message in channel 1",
        sent_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
        modality="text",
        sha256=b"hash1" + bytes(26),
    )
    second_fields = dict(
        message_id=2,
        channel_id=channel2.id,
        discord_user_id=mock_discord_user.id,
        content="Message in channel 2",
        sent_at=datetime(2024, 1, 1, 10, 5, 0, tzinfo=timezone.utc),
        modality="text",
        sha256=b"hash2" + bytes(26),
    )
    db_session.add_all([DiscordMessage(**first_fields), DiscordMessage(**second_fields)])
    db_session.commit()

    # Ask for channel 1's history with a cutoff later than both messages.
    cutoff = datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc)
    target_channel_id = channel1.id  # type: ignore
    history = discord.get_prev(db_session, target_channel_id, cutoff)

    # Only the channel 1 message may come back.
    assert len(history) == 1
    only_entry = history[0]
    assert "Message in channel 1" in only_entry
    assert "Message in channel 2" not in only_entry
|
||||||
Loading…
x
Reference in New Issue
Block a user