From 1606348d8b176182ba0d23ac741b6aa191b90079 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Mon, 20 Oct 2025 03:47:13 +0200 Subject: [PATCH] discord integration --- ... => 20251013_142101_add_discord_models.py} | 57 +- .../20251020_014858_seperate_user__models.py | 145 +++++ docker-compose.yaml | 3 +- src/memory/api/MCP/base.py | 14 +- src/memory/api/MCP/oauth_provider.py | 46 +- src/memory/api/MCP/schedules.py | 26 +- src/memory/api/auth.py | 19 +- src/memory/common/db/models/__init__.py | 14 +- src/memory/common/db/models/discord.py | 121 ++++ .../common/db/models/scheduled_calls.py | 13 +- src/memory/common/db/models/source_items.py | 32 +- src/memory/common/db/models/sources.py | 71 --- src/memory/common/db/models/users.py | 87 ++- src/memory/common/discord.py | 20 +- src/memory/common/llms/__init__.py | 25 + src/memory/common/llms/anthropic_provider.py | 1 - src/memory/common/llms/base.py | 3 + src/memory/common/llms/openai_provider.py | 26 +- src/memory/common/llms/tools/discord.py | 231 ++++++++ src/memory/common/settings.py | 9 +- src/memory/discord/collector.py | 3 +- src/memory/discord/messages.py | 205 +++++++ src/memory/workers/tasks/discord.py | 118 +++- src/memory/workers/tasks/scheduled_calls.py | 18 +- tests/conftest.py | 44 +- tests/memory/common/llms/__init__.py | 0 .../llms/test_anthropic_event_parsing.py | 552 +++++++++++++++++ .../common/llms/test_anthropic_provider.py | 440 ++++++++++++++ tests/memory/common/llms/test_base.py | 270 +++++++++ .../common/llms/test_openai_event_parsing.py | 478 +++++++++++++++ .../common/llms/test_openai_provider.py | 561 ++++++++++++++++++ tools/add_user.py | 32 +- 32 files changed, 3472 insertions(+), 212 deletions(-) rename db/migrations/versions/{20251012_222827_add_discord_models.py => 20251013_142101_add_discord_models.py} (79%) create mode 100644 db/migrations/versions/20251020_014858_seperate_user__models.py create mode 100644 src/memory/common/db/models/discord.py create mode 100644 
src/memory/common/llms/tools/discord.py create mode 100644 src/memory/discord/messages.py create mode 100644 tests/memory/common/llms/__init__.py create mode 100644 tests/memory/common/llms/test_anthropic_event_parsing.py create mode 100644 tests/memory/common/llms/test_anthropic_provider.py create mode 100644 tests/memory/common/llms/test_base.py create mode 100644 tests/memory/common/llms/test_openai_event_parsing.py create mode 100644 tests/memory/common/llms/test_openai_provider.py diff --git a/db/migrations/versions/20251012_222827_add_discord_models.py b/db/migrations/versions/20251013_142101_add_discord_models.py similarity index 79% rename from db/migrations/versions/20251012_222827_add_discord_models.py rename to db/migrations/versions/20251013_142101_add_discord_models.py index deb73e1..cd054eb 100644 --- a/db/migrations/versions/20251012_222827_add_discord_models.py +++ b/db/migrations/versions/20251013_142101_add_discord_models.py @@ -1,8 +1,8 @@ """add_discord_models -Revision ID: a8c8e8b17179 +Revision ID: 7c6169fba146 Revises: c86079073c1d -Create Date: 2025-10-12 22:28:27.856164 +Create Date: 2025-10-13 14:21:01.080948 """ @@ -13,7 +13,7 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. 
-revision: str = "a8c8e8b17179" +revision: str = "7c6169fba146" down_revision: Union[str, None] = "c86079073c1d" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -26,12 +26,6 @@ def upgrade() -> None: sa.Column("name", sa.Text(), nullable=False), sa.Column("description", sa.Text(), nullable=True), sa.Column("member_count", sa.Integer(), nullable=True), - sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True), - sa.Column( - "ignore_messages", sa.Boolean(), server_default="false", nullable=True - ), - sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True), - sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True), sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True), sa.Column( "created_at", @@ -45,6 +39,17 @@ def upgrade() -> None: server_default=sa.text("now()"), nullable=True, ), + sa.Column( + "track_messages", sa.Boolean(), server_default="true", nullable=False + ), + sa.Column("ignore_messages", sa.Boolean(), nullable=True), + sa.Column( + "allowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False + ), + sa.Column( + "disallowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False + ), + sa.Column("summary", sa.Text(), nullable=True), sa.PrimaryKeyConstraint("id"), ) op.create_index( @@ -59,12 +64,6 @@ def upgrade() -> None: sa.Column("server_id", sa.BigInteger(), nullable=True), sa.Column("name", sa.Text(), nullable=False), sa.Column("channel_type", sa.Text(), nullable=False), - sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True), - sa.Column( - "ignore_messages", sa.Boolean(), server_default="false", nullable=True - ), - sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True), - sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True), sa.Column( "created_at", sa.DateTime(timezone=True), @@ -77,6 +76,17 @@ def upgrade() -> None: server_default=sa.text("now()"), 
nullable=True, ), + sa.Column( + "track_messages", sa.Boolean(), server_default="true", nullable=False + ), + sa.Column("ignore_messages", sa.Boolean(), nullable=True), + sa.Column( + "allowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False + ), + sa.Column( + "disallowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False + ), + sa.Column("summary", sa.Text(), nullable=True), sa.ForeignKeyConstraint( ["server_id"], ["discord_servers.id"], @@ -92,12 +102,6 @@ def upgrade() -> None: sa.Column("username", sa.Text(), nullable=False), sa.Column("display_name", sa.Text(), nullable=True), sa.Column("system_user_id", sa.Integer(), nullable=True), - sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True), - sa.Column( - "ignore_messages", sa.Boolean(), server_default="false", nullable=True - ), - sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True), - sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True), sa.Column( "created_at", sa.DateTime(timezone=True), @@ -110,6 +114,17 @@ def upgrade() -> None: server_default=sa.text("now()"), nullable=True, ), + sa.Column( + "track_messages", sa.Boolean(), server_default="true", nullable=False + ), + sa.Column("ignore_messages", sa.Boolean(), nullable=True), + sa.Column( + "allowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False + ), + sa.Column( + "disallowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False + ), + sa.Column("summary", sa.Text(), nullable=True), sa.ForeignKeyConstraint( ["system_user_id"], ["users.id"], diff --git a/db/migrations/versions/20251020_014858_seperate_user__models.py b/db/migrations/versions/20251020_014858_seperate_user__models.py new file mode 100644 index 0000000..274a43c --- /dev/null +++ b/db/migrations/versions/20251020_014858_seperate_user__models.py @@ -0,0 +1,145 @@ +"""seperate_user__models + +Revision ID: 35a2c1b610b6 +Revises: 7c6169fba146 +Create Date: 2025-10-20 01:48:58.537881 + 
+""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "35a2c1b610b6" +down_revision: Union[str, None] = "7c6169fba146" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "discord_message", sa.Column("from_id", sa.BigInteger(), nullable=False) + ) + op.add_column( + "discord_message", sa.Column("recipient_id", sa.BigInteger(), nullable=False) + ) + op.drop_index("discord_message_user_idx", table_name="discord_message") + op.create_index( + "discord_message_from_idx", "discord_message", ["from_id"], unique=False + ) + op.create_index( + "discord_message_recipient_idx", + "discord_message", + ["recipient_id"], + unique=False, + ) + op.drop_constraint( + "discord_message_discord_user_id_fkey", "discord_message", type_="foreignkey" + ) + op.create_foreign_key( + "discord_message_from_id_fkey", + "discord_message", + "discord_users", + ["from_id"], + ["id"], + ) + op.create_foreign_key( + "discord_message_recipient_id_fkey", + "discord_message", + "discord_users", + ["recipient_id"], + ["id"], + ) + op.drop_column("discord_message", "discord_user_id") + op.add_column( + "scheduled_llm_calls", + sa.Column("discord_channel_id", sa.BigInteger(), nullable=True), + ) + op.add_column( + "scheduled_llm_calls", + sa.Column("discord_user_id", sa.BigInteger(), nullable=True), + ) + op.create_foreign_key( + "scheduled_llm_calls_discord_user_id_fkey", + "scheduled_llm_calls", + "discord_users", + ["discord_user_id"], + ["id"], + ) + op.create_foreign_key( + "scheduled_llm_calls_discord_channel_id_fkey", + "scheduled_llm_calls", + "discord_channels", + ["discord_channel_id"], + ["id"], + ) + op.drop_column("scheduled_llm_calls", "discord_user") + op.drop_column("scheduled_llm_calls", "discord_channel") + op.add_column( + "users", + sa.Column("user_type", sa.String(), 
nullable=False, server_default="human"), + ) + op.add_column("users", sa.Column("api_key", sa.String(), nullable=True)) + op.alter_column("users", "password_hash", existing_type=sa.VARCHAR(), nullable=True) + op.create_unique_constraint("users_api_key_key", "users", ["api_key"]) + op.drop_column("users", "discord_user_id") + + +def downgrade() -> None: + op.add_column( + "users", + sa.Column("discord_user_id", sa.VARCHAR(), autoincrement=False, nullable=True), + ) + op.drop_constraint("users_api_key_key", "users", type_="unique") + op.alter_column( + "users", "password_hash", existing_type=sa.VARCHAR(), nullable=False + ) + op.drop_column("users", "api_key") + op.drop_column("users", "user_type") + op.add_column( + "scheduled_llm_calls", + sa.Column("discord_channel", sa.VARCHAR(), autoincrement=False, nullable=True), + ) + op.add_column( + "scheduled_llm_calls", + sa.Column("discord_user", sa.VARCHAR(), autoincrement=False, nullable=True), + ) + op.drop_constraint( + "scheduled_llm_calls_discord_user_id_fkey", + "scheduled_llm_calls", + type_="foreignkey", + ) + op.drop_constraint( + "scheduled_llm_calls_discord_channel_id_fkey", + "scheduled_llm_calls", + type_="foreignkey", + ) + op.drop_column("scheduled_llm_calls", "discord_user_id") + op.drop_column("scheduled_llm_calls", "discord_channel_id") + op.add_column( + "discord_message", + sa.Column("discord_user_id", sa.BIGINT(), autoincrement=False, nullable=False), + ) + op.drop_constraint( + "discord_message_from_id_fkey", "discord_message", type_="foreignkey" + ) + op.drop_constraint( + "discord_message_recipient_id_fkey", "discord_message", type_="foreignkey" + ) + op.create_foreign_key( + "discord_message_discord_user_id_fkey", + "discord_message", + "discord_users", + ["discord_user_id"], + ["id"], + ) + op.drop_index("discord_message_recipient_idx", table_name="discord_message") + op.drop_index("discord_message_from_idx", table_name="discord_message") + op.create_index( + "discord_message_user_idx", 
"discord_message", ["discord_user_id"], unique=False + ) + op.drop_column("discord_message", "recipient_id") + op.drop_column("discord_message", "from_id") diff --git a/docker-compose.yaml b/docker-compose.yaml index dee38e3..2bc0945 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -50,6 +50,8 @@ x-worker-base: &worker-base OPENAI_API_KEY_FILE: /run/secrets/openai_key ANTHROPIC_API_KEY_FILE: /run/secrets/anthropic_key VOYAGE_API_KEY: ${VOYAGE_API_KEY} + DISCORD_COLLECTOR_SERVER_URL: ingest-hub + DISCORD_COLLECTOR_PORT: 8003 secrets: [ postgres_password, openai_key, anthropic_key, ssh_private_key, ssh_public_key, ssh_known_hosts ] read_only: true tmpfs: @@ -183,7 +185,6 @@ services: dockerfile: docker/ingest_hub/Dockerfile environment: <<: *worker-env - DISCORD_API_PORT: 8000 DISCORD_BOT_TOKEN: ${DISCORD_BOT_TOKEN} DISCORD_NOTIFICATIONS_ENABLED: true DISCORD_COLLECTOR_ENABLED: true diff --git a/src/memory/api/MCP/base.py b/src/memory/api/MCP/base.py index f7cda67..aaa7a25 100644 --- a/src/memory/api/MCP/base.py +++ b/src/memory/api/MCP/base.py @@ -25,7 +25,7 @@ from memory.api.MCP.oauth_provider import ( from memory.common import settings from memory.common.db.connection import make_session from memory.common.db.models import OAuthState, UserSession -from memory.common.db.models.users import User +from memory.common.db.models.users import HumanUser logger = logging.getLogger(__name__) @@ -126,7 +126,11 @@ async def handle_login(request: Request): key: value for key, value in form.items() if key not in ["email", "password"] } with make_session() as session: - user = session.query(User).filter(User.email == form.get("email")).first() + user = ( + session.query(HumanUser) + .filter(HumanUser.email == form.get("email")) + .first() + ) if not user or not user.is_valid_password(str(form.get("password", ""))): logger.warning("Login failed - invalid credentials") return login_form(request, oauth_params, "Invalid email or password") @@ -144,11 +148,7 @@ def 
get_current_user() -> dict: return {"authenticated": False} with make_session() as session: - user_session = ( - session.query(UserSession) - .filter(UserSession.id == access_token.token) - .first() - ) + user_session = session.query(UserSession).get(access_token.token) if user_session and user_session.user: user_info = user_session.user.serialize() diff --git a/src/memory/api/MCP/oauth_provider.py b/src/memory/api/MCP/oauth_provider.py index ca0fd39..8ba2c63 100644 --- a/src/memory/api/MCP/oauth_provider.py +++ b/src/memory/api/MCP/oauth_provider.py @@ -21,6 +21,7 @@ from memory.common.db.models.users import ( OAuthRefreshToken, OAuthState, User, + BotUser, UserSession, ) from memory.common.db.models.users import ( @@ -92,7 +93,7 @@ def create_oauth_token( """Create an OAuth token response.""" return OAuthToken( access_token=access_token, - token_type="bearer", + token_type="Bearer", expires_in=ACCESS_TOKEN_LIFETIME, refresh_token=refresh_token, scope=" ".join(scopes), @@ -310,26 +311,37 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): return token async def load_access_token(self, token: str) -> Optional[AccessToken]: - """Load and validate an access token.""" + """Load and validate an access token (or bot API key).""" with make_session() as session: - now = datetime.now(timezone.utc).replace( - tzinfo=None - ) # Make naive for DB comparison - - # Query for active (non-expired) session + # Try as OAuth access token first user_session = session.query(UserSession).get(token) - if not user_session: - return None + if user_session: + now = datetime.now(timezone.utc).replace( + tzinfo=None + ) # Make naive for DB comparison - if user_session.expires_at < now: - return None + if user_session.expires_at < now: + return None - return AccessToken( - token=token, - client_id=user_session.oauth_state.client_id, - scopes=user_session.oauth_state.scopes, - expires_at=int(user_session.expires_at.timestamp()), - ) + return AccessToken( + token=token, + 
client_id=user_session.oauth_state.client_id, + scopes=user_session.oauth_state.scopes, + expires_at=int(user_session.expires_at.timestamp()), + ) + + # Try as bot API key + bot = session.query(User).filter(User.api_key == token).first() + if bot: + logger.info(f"Bot {bot.name} (id={bot.id}) authenticated via API key") + return AccessToken( + token=token, + client_id=cast(str, bot.name or bot.email), + scopes=["read", "write"], # Bots get full access + expires_at=2147483647, # Far future (2038) + ) + + return None async def load_refresh_token( self, client: OAuthClientInformationFull, refresh_token: str diff --git a/src/memory/api/MCP/schedules.py b/src/memory/api/MCP/schedules.py index 78740b6..169fe07 100644 --- a/src/memory/api/MCP/schedules.py +++ b/src/memory/api/MCP/schedules.py @@ -9,7 +9,9 @@ from typing import Any from memory.api.MCP.base import get_current_user from memory.common.db.connection import make_session from memory.common.db.models import ScheduledLLMCall +from memory.common.db.models.discord import DiscordChannel, DiscordUser from memory.api.MCP.base import mcp +from memory.discord.schedule import schedule_discord_message logger = logging.getLogger(__name__) @@ -17,7 +19,7 @@ logger = logging.getLogger(__name__) @mcp.tool() async def schedule_message( scheduled_time: str, - message: str | None = None, + message: str, model: str | None = None, topic: str | None = None, discord_channel: str | None = None, @@ -56,7 +58,8 @@ async def schedule_message( if not user_id: raise ValueError("User not found") - discord_user = current_user.get("user", {}).get("discord_user_id") + discord_users = current_user.get("user", {}).get("discord_users") + discord_user = discord_users and next(iter(discord_users.keys()), None) if not discord_user and not discord_channel: raise ValueError("Either discord_user or discord_channel must be provided") @@ -69,27 +72,20 @@ async def schedule_message( except ValueError: raise ValueError("Invalid datetime format") - # 
Validate that the scheduled time is in the future - # Compare with naive datetime since we store naive in the database - current_time_naive = datetime.now(timezone.utc).replace(tzinfo=None) - if scheduled_dt <= current_time_naive: - raise ValueError("Scheduled time must be in the future") - with make_session() as session: - # Create the scheduled call - scheduled_call = ScheduledLLMCall( - user_id=user_id, + scheduled_call = schedule_discord_message( + session=session, scheduled_time=scheduled_dt, message=message, - topic=topic, + user_id=current_user.get("user", {}).get("user_id"), model=model, - system_prompt=system_prompt, + topic=topic, discord_channel=discord_channel, discord_user=discord_user, - data=metadata or {}, + system_prompt=system_prompt, + metadata=metadata, ) - session.add(scheduled_call) session.commit() return { diff --git a/src/memory/api/auth.py b/src/memory/api/auth.py index a86edc0..14a79ca 100644 --- a/src/memory/api/auth.py +++ b/src/memory/api/auth.py @@ -7,7 +7,7 @@ from memory.common import settings from sqlalchemy.orm import Session as DBSession, scoped_session from memory.common.db.connection import get_session, make_session -from memory.common.db.models.users import User, UserSession +from memory.common.db.models.users import User, HumanUser, BotUser, UserSession logger = logging.getLogger(__name__) @@ -91,14 +91,14 @@ def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> return user -def create_user(email: str, password: str, name: str, db: DBSession) -> User: - """Create a new user""" +def create_user(email: str, password: str, name: str, db: DBSession) -> HumanUser: + """Create a new human user""" # Check if user already exists existing_user = db.query(User).filter(User.email == email).first() if existing_user: raise HTTPException(status_code=400, detail="User already exists") - user = User.create_with_password(email, name, password) + user = HumanUser.create_with_password(email, name, password) 
db.add(user) db.commit() db.refresh(user) @@ -106,14 +106,19 @@ def create_user(email: str, password: str, name: str, db: DBSession) -> User: return user -def authenticate_user(email: str, password: str, db: DBSession) -> User | None: - """Authenticate a user by email and password""" - user = db.query(User).filter(User.email == email).first() +def authenticate_user(email: str, password: str, db: DBSession) -> HumanUser | None: + """Authenticate a human user by email and password""" + user = db.query(HumanUser).filter(HumanUser.email == email).first() if user and user.is_valid_password(password): return user return None +def authenticate_bot(api_key: str, db: DBSession) -> BotUser | None: + """Authenticate a bot by API key""" + return db.query(BotUser).filter(BotUser.api_key == api_key).first() + + @router.api_route("/logout", methods=["GET", "POST"]) def logout(request: Request, db: DBSession = Depends(get_session)): """Logout and clear session""" diff --git a/src/memory/common/db/models/__init__.py b/src/memory/common/db/models/__init__.py index 7b1f778..140fa62 100644 --- a/src/memory/common/db/models/__init__.py +++ b/src/memory/common/db/models/__init__.py @@ -30,6 +30,11 @@ from memory.common.db.models.source_items import ( NotePayload, ForumPostPayload, ) +from memory.common.db.models.discord import ( + DiscordServer, + DiscordChannel, + DiscordUser, +) from memory.common.db.models.observations import ( ObservationContradiction, ReactionPattern, @@ -41,12 +46,12 @@ from memory.common.db.models.sources import ( Book, ArticleFeed, EmailAccount, - DiscordServer, - DiscordChannel, - DiscordUser, ) from memory.common.db.models.users import ( User, + HumanUser, + BotUser, + DiscordBotUser, UserSession, OAuthClientInformation, OAuthState, @@ -103,6 +108,9 @@ __all__ = [ "DiscordUser", # Users "User", + "HumanUser", + "BotUser", + "DiscordBotUser", "UserSession", "OAuthClientInformation", "OAuthState", diff --git a/src/memory/common/db/models/discord.py 
b/src/memory/common/db/models/discord.py new file mode 100644 index 0000000..544ffb0 --- /dev/null +++ b/src/memory/common/db/models/discord.py @@ -0,0 +1,121 @@ +""" +Database models for the Discord system. +""" + +import textwrap + +from sqlalchemy import ( + ARRAY, + BigInteger, + Boolean, + Column, + DateTime, + ForeignKey, + Index, + Integer, + Text, + func, +) +from sqlalchemy.orm import relationship + +from memory.common.db.models.base import Base + + +class MessageProcessor: + track_messages = Column(Boolean, nullable=False, server_default="true") + ignore_messages = Column(Boolean, nullable=True, default=False) + + allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}") + disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}") + + summary = Column( + Text, + nullable=True, + doc=textwrap.dedent( + """ + A summary of this processor, made by and for AI systems. + + The idea here is that AI systems can use this summary to keep notes on the given processor. + These should automatically be injected into the context of the messages that are processed by this processor. 
+ """ + ), + ) + + def as_xml(self) -> str: + return ( + textwrap.dedent(""" + <{type}> + {name} + {summary} + + """) + .format( + type=self.__class__.__tablename__[8:], # type: ignore + name=getattr(self, "name", None) or getattr(self, "username", None), + summary=self.summary, + ) + .strip() + ) + + +class DiscordServer(Base, MessageProcessor): + """Discord server configuration and metadata""" + + __tablename__ = "discord_servers" + + id = Column(BigInteger, primary_key=True) # Discord guild snowflake ID + name = Column(Text, nullable=False) + description = Column(Text) + member_count = Column(Integer) + + # Collection settings + last_sync_at = Column(DateTime(timezone=True)) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), server_default=func.now()) + + channels = relationship( + "DiscordChannel", back_populates="server", cascade="all, delete-orphan" + ) + + __table_args__ = ( + Index("discord_servers_active_idx", "track_messages", "last_sync_at"), + ) + + +class DiscordChannel(Base, MessageProcessor): + """Discord channel metadata and configuration""" + + __tablename__ = "discord_channels" + + id = Column(BigInteger, primary_key=True) # Discord channel snowflake ID + server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True) + name = Column(Text, nullable=False) + channel_type = Column(Text, nullable=False) # "text", "voice", "dm", "group_dm" + + # Collection settings (null = inherit from server) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), server_default=func.now()) + + server = relationship("DiscordServer", back_populates="channels") + __table_args__ = (Index("discord_channels_server_idx", "server_id"),) + + +class DiscordUser(Base, MessageProcessor): + """Discord user metadata and preferences""" + + __tablename__ = "discord_users" + + id = Column(BigInteger, primary_key=True) # Discord user 
snowflake ID + username = Column(Text, nullable=False) + display_name = Column(Text) + + # Link to system user if registered + system_user_id = Column(Integer, ForeignKey("users.id"), nullable=True) + + # Basic DM settings + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), server_default=func.now()) + + system_user = relationship("User", back_populates="discord_users") + + __table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),) diff --git a/src/memory/common/db/models/scheduled_calls.py b/src/memory/common/db/models/scheduled_calls.py index bbd23d1..75d2184 100644 --- a/src/memory/common/db/models/scheduled_calls.py +++ b/src/memory/common/db/models/scheduled_calls.py @@ -7,6 +7,7 @@ from sqlalchemy import ( String, DateTime, ForeignKey, + BigInteger, JSON, Text, ) @@ -37,8 +38,10 @@ class ScheduledLLMCall(Base): allowed_tools = Column(JSON, nullable=True) # List of allowed tool names # Discord configuration - discord_channel = Column(String, nullable=True) - discord_user = Column(String, nullable=True) + discord_channel_id = Column( + BigInteger, ForeignKey("discord_channels.id"), nullable=True + ) + discord_user_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=True) # Execution status and results status = Column( @@ -55,6 +58,8 @@ class ScheduledLLMCall(Base): # Relationships user = relationship("User") + discord_channel = relationship("DiscordChannel", foreign_keys=[discord_channel_id]) + discord_user = relationship("DiscordUser", foreign_keys=[discord_user_id]) def serialize(self) -> Dict[str, Any]: def print_datetime(dt: datetime | None) -> str | None: @@ -73,8 +78,8 @@ class ScheduledLLMCall(Base): "message": self.message, "system_prompt": self.system_prompt, "allowed_tools": self.allowed_tools, - "discord_channel": self.discord_channel, - "discord_user": self.discord_user, + "discord_channel": self.discord_channel and self.discord_channel.name, + 
"discord_user": self.discord_user and self.discord_user.username, "status": self.status, "response": self.response, "error_message": self.error_message, diff --git a/src/memory/common/db/models/source_items.py b/src/memory/common/db/models/source_items.py index 0b4ca96..fd0eac1 100644 --- a/src/memory/common/db/models/source_items.py +++ b/src/memory/common/db/models/source_items.py @@ -286,7 +286,8 @@ class DiscordMessage(SourceItem): sent_at = Column(DateTime(timezone=True), nullable=False) server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True) channel_id = Column(BigInteger, ForeignKey("discord_channels.id"), nullable=False) - discord_user_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False) + from_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False) + recipient_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False) message_id = Column(BigInteger, nullable=False) # Discord message snowflake ID # Discord-specific metadata @@ -303,11 +304,33 @@ class DiscordMessage(SourceItem): channel = relationship("DiscordChannel", foreign_keys=[channel_id]) server = relationship("DiscordServer", foreign_keys=[server_id]) - discord_user = relationship("DiscordUser", foreign_keys=[discord_user_id]) + from_user = relationship("DiscordUser", foreign_keys=[from_id]) + recipient_user = relationship("DiscordUser", foreign_keys=[recipient_id]) + + @property + def allowed_tools(self) -> list[str]: + return ( + (self.channel.allowed_tools if self.channel else []) + + (self.from_user.allowed_tools if self.from_user else []) + + (self.server.allowed_tools if self.server else []) + ) + + @property + def disallowed_tools(self) -> list[str]: + return ( + (self.channel.disallowed_tools if self.channel else []) + + (self.from_user.disallowed_tools if self.from_user else []) + + (self.server.disallowed_tools if self.server else []) + ) + + def tool_allowed(self, tool: str) -> bool: + return not 
(self.disallowed_tools and tool in self.disallowed_tools) and ( + not self.allowed_tools or tool in self.allowed_tools + ) @property def title(self) -> str: - return f"{self.discord_user.username}: {self.content}" + return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}" __mapper_args__ = { "polymorphic_identity": "discord_message", @@ -320,7 +343,8 @@ class DiscordMessage(SourceItem): "server_id", "channel_id", ), - Index("discord_message_user_idx", "discord_user_id"), + Index("discord_message_from_idx", "from_id"), + Index("discord_message_recipient_idx", "recipient_id"), ) def _chunk_contents(self) -> Sequence[extract.DataChunk]: diff --git a/src/memory/common/db/models/sources.py b/src/memory/common/db/models/sources.py index 1cbb9e4..77e0c3e 100644 --- a/src/memory/common/db/models/sources.py +++ b/src/memory/common/db/models/sources.py @@ -125,74 +125,3 @@ class EmailAccount(Base): Index("email_accounts_active_idx", "active", "last_sync_at"), Index("email_accounts_tags_idx", "tags", postgresql_using="gin"), ) - - -class MessageProcessor: - track_messages = Column(Boolean, nullable=False, server_default="true") - ignore_messages = Column(Boolean, nullable=True, default=False) - - allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}") - disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}") - - -class DiscordServer(Base, MessageProcessor): - """Discord server configuration and metadata""" - - __tablename__ = "discord_servers" - - id = Column(BigInteger, primary_key=True) # Discord guild snowflake ID - name = Column(Text, nullable=False) - description = Column(Text) - member_count = Column(Integer) - - # Collection settings - last_sync_at = Column(DateTime(timezone=True)) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), server_default=func.now()) - - channels = relationship( - "DiscordChannel", back_populates="server", 
cascade="all, delete-orphan" - ) - - __table_args__ = ( - Index("discord_servers_active_idx", "track_messages", "last_sync_at"), - ) - - -class DiscordChannel(Base, MessageProcessor): - """Discord channel metadata and configuration""" - - __tablename__ = "discord_channels" - - id = Column(BigInteger, primary_key=True) # Discord channel snowflake ID - server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True) - name = Column(Text, nullable=False) - channel_type = Column(Text, nullable=False) # "text", "voice", "dm", "group_dm" - - # Collection settings (null = inherit from server) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), server_default=func.now()) - - server = relationship("DiscordServer", back_populates="channels") - __table_args__ = (Index("discord_channels_server_idx", "server_id"),) - - -class DiscordUser(Base, MessageProcessor): - """Discord user metadata and preferences""" - - __tablename__ = "discord_users" - - id = Column(BigInteger, primary_key=True) # Discord user snowflake ID - username = Column(Text, nullable=False) - display_name = Column(Text) - - # Link to system user if registered - system_user_id = Column(Integer, ForeignKey("users.id"), nullable=True) - - # Basic DM settings - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), server_default=func.now()) - - system_user = relationship("User", back_populates="discord_users") - - __table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),) diff --git a/src/memory/common/db/models/users.py b/src/memory/common/db/models/users.py index 8228cee..09f8dd8 100644 --- a/src/memory/common/db/models/users.py +++ b/src/memory/common/db/models/users.py @@ -2,7 +2,6 @@ import hashlib import secrets from typing import cast import uuid -from datetime import datetime, timezone from sqlalchemy.orm import Session from 
memory.common.db.models.base import Base from sqlalchemy import ( @@ -14,6 +13,7 @@ from sqlalchemy import ( Boolean, ARRAY, Numeric, + CheckConstraint, ) from sqlalchemy.sql import func from sqlalchemy.orm import relationship @@ -36,12 +36,21 @@ def verify_password(password: str, password_hash: str) -> bool: class User(Base): __tablename__ = "users" + __table_args__ = ( + CheckConstraint( + "password_hash IS NOT NULL OR api_key IS NOT NULL", + name="user_has_auth_method", + ), + ) id = Column(Integer, primary_key=True) name = Column(String, nullable=False) email = Column(String, nullable=False, unique=True) - password_hash = Column(String, nullable=False) - discord_user_id = Column(String, nullable=True) + user_type = Column(String, nullable=False) # Discriminator column + + # Make these nullable since subclasses will use them selectively + password_hash = Column(String, nullable=True) + api_key = Column(String, nullable=True, unique=True) # Relationship to sessions sessions = relationship( @@ -52,22 +61,86 @@ class User(Base): ) discord_users = relationship("DiscordUser", back_populates="system_user") + __mapper_args__ = { + "polymorphic_on": user_type, + "polymorphic_identity": "user", + } + def serialize(self) -> dict: return { "user_id": self.id, "name": self.name, "email": self.email, - "discord_user_id": self.discord_user_id, + "user_type": self.user_type, + "discord_users": { + discord_user.id: discord_user.username + for discord_user in self.discord_users + }, } + +class HumanUser(User): + """Human user with password authentication""" + + __mapper_args__ = { + "polymorphic_identity": "human", + } + def is_valid_password(self, password: str) -> bool: """Check if the provided password is valid for this user""" return verify_password(password, cast(str, self.password_hash)) @classmethod - def create_with_password(cls, email: str, name: str, password: str) -> "User": - """Create a new user with a hashed password""" - return cls(email=email, name=name, 
password_hash=hash_password(password)) + def create_with_password(cls, email: str, name: str, password: str) -> "HumanUser": + """Create a new human user with a hashed password""" + return cls( + email=email, + name=name, + password_hash=hash_password(password), + user_type="human", + ) + + +class BotUser(User): + """Bot user with API key authentication""" + + __mapper_args__ = { + "polymorphic_identity": "bot", + } + + @classmethod + def create_with_api_key( + cls, name: str, email: str, api_key: str | None = None + ) -> "BotUser": + """Create a new bot user with an API key""" + if api_key is None: + api_key = f"bot_{secrets.token_hex(32)}" + return cls( + name=name, + email=email, + api_key=api_key, + user_type=cls.__mapper_args__["polymorphic_identity"], + ) + + +class DiscordBotUser(BotUser): + """Bot user with API key authentication""" + + __mapper_args__ = { + "polymorphic_identity": "discord_bot", + } + + @classmethod + def create_with_api_key( + cls, + discord_users: list, + name: str, + email: str, + api_key: str | None = None, + ) -> "DiscordBotUser": + bot = super().create_with_api_key(name, email, api_key) + bot.discord_users = discord_users + return bot class UserSession(Base): diff --git a/src/memory/common/discord.py b/src/memory/common/discord.py index 92482bc..2a1a026 100644 --- a/src/memory/common/discord.py +++ b/src/memory/common/discord.py @@ -25,7 +25,7 @@ def send_dm(user_identifier: str, message: str) -> bool: try: response = requests.post( f"{get_api_url()}/send_dm", - json={"user_identifier": user_identifier, "message": message}, + json={"user": user_identifier, "message": message}, timeout=10, ) response.raise_for_status() @@ -37,6 +37,24 @@ def send_dm(user_identifier: str, message: str) -> bool: return False +def send_to_channel(channel_name: str, message: str) -> bool: + """Send a DM via the Discord collector API""" + try: + response = requests.post( + f"{get_api_url()}/send_channel", + json={"channel_name": channel_name, "message": 
message}, + timeout=10, + ) + response.raise_for_status() + result = response.json() + print("Result", result) + return result.get("success", False) + + except requests.RequestException as e: + logger.error(f"Failed to send to channel {channel_name}: {e}") + return False + + def broadcast_message(channel_name: str, message: str) -> bool: """Send a message to a channel via the Discord collector API""" try: diff --git a/src/memory/common/llms/__init__.py b/src/memory/common/llms/__init__.py index b2d37cc..b84e337 100644 --- a/src/memory/common/llms/__init__.py +++ b/src/memory/common/llms/__init__.py @@ -81,3 +81,28 @@ def truncate(content: str, target_tokens: int) -> str: if len(content) > target_chars: return content[:target_chars].rsplit(" ", 1)[0] + "..." return content + + +# bla = 1 +# from memory.common.llms import * +# from memory.common.llms.tools.discord import make_discord_tools +# from memory.common.db.connection import make_session +# from memory.common.db.models import * + +# model = "anthropic/claude-sonnet-4-5" +# provider = create_provider(model=model) +# with make_session() as session: +# bot = session.query(DiscordBotUser).first() +# server = session.query(DiscordServer).first() +# channel = server.channels[0] +# tools = make_discord_tools(bot, None, channel, model) + +# def demo(msg: str): +# messages = [ +# Message( +# role=MessageRole.USER, +# content=msg, +# ) +# ] +# for m in provider.stream_with_tools(messages, tools): +# print(m) diff --git a/src/memory/common/llms/anthropic_provider.py b/src/memory/common/llms/anthropic_provider.py index 1969bfb..938d552 100644 --- a/src/memory/common/llms/anthropic_provider.py +++ b/src/memory/common/llms/anthropic_provider.py @@ -333,7 +333,6 @@ class AnthropicProvider(BaseLLMProvider): settings = settings or LLMSettings() kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings) - print(kwargs) try: with self.client.messages.stream(**kwargs) as stream: current_tool_use: dict[str, Any] 
| None = None diff --git a/src/memory/common/llms/base.py b/src/memory/common/llms/base.py index 113e71f..6238daa 100644 --- a/src/memory/common/llms/base.py +++ b/src/memory/common/llms/base.py @@ -599,6 +599,9 @@ class BaseLLMProvider(ABC): tool_calls=tool_calls or None, ) + def as_messages(self, messages) -> list[Message]: + return [Message.user(text=msg) for msg in messages] + def create_provider( model: str | None = None, diff --git a/src/memory/common/llms/openai_provider.py b/src/memory/common/llms/openai_provider.py index aa5beb3..594b459 100644 --- a/src/memory/common/llms/openai_provider.py +++ b/src/memory/common/llms/openai_provider.py @@ -150,7 +150,7 @@ class OpenAIProvider(BaseLLMProvider): def _convert_tools( self, tools: list[ToolDefinition] | None - ) -> Optional[list[dict[str, Any]]]: + ) -> list[dict[str, Any]] | None: """ Convert our tool definitions to OpenAI format. @@ -179,7 +179,7 @@ class OpenAIProvider(BaseLLMProvider): self, messages: list[Message], system_prompt: str | None, - tools: Optional[list[ToolDefinition]], + tools: list[ToolDefinition] | None, settings: LLMSettings, stream: bool = False, ) -> dict[str, Any]: @@ -270,7 +270,7 @@ class OpenAIProvider(BaseLLMProvider): self, chunk: Any, current_tool_call: dict[str, Any] | None, - ) -> tuple[list[StreamEvent], Optional[dict[str, Any]]]: + ) -> tuple[list[StreamEvent], dict[str, Any] | None]: """ Handle a single streaming chunk and return events and updated tool state. 
@@ -325,9 +325,9 @@ class OpenAIProvider(BaseLLMProvider): def generate( self, messages: list[Message], - system_prompt: Optional[str] = None, - tools: Optional[list[ToolDefinition]] = None, - settings: Optional[LLMSettings] = None, + system_prompt: str | None = None, + tools: list[ToolDefinition] | None = None, + settings: LLMSettings | None = None, ) -> str: """Generate a non-streaming response.""" settings = settings or LLMSettings() @@ -374,9 +374,9 @@ class OpenAIProvider(BaseLLMProvider): async def agenerate( self, messages: list[Message], - system_prompt: Optional[str] = None, - tools: Optional[list[ToolDefinition]] = None, - settings: Optional[LLMSettings] = None, + system_prompt: str | None = None, + tools: list[ToolDefinition] | None = None, + settings: LLMSettings | None = None, ) -> str: """Generate a non-streaming response asynchronously.""" settings = settings or LLMSettings() @@ -394,9 +394,9 @@ class OpenAIProvider(BaseLLMProvider): async def astream( self, messages: list[Message], - system_prompt: Optional[str] = None, - tools: Optional[list[ToolDefinition]] = None, - settings: Optional[LLMSettings] = None, + system_prompt: str | None = None, + tools: list[ToolDefinition] | None = None, + settings: LLMSettings | None = None, ) -> AsyncIterator[StreamEvent]: """Generate a streaming response asynchronously.""" settings = settings or LLMSettings() @@ -406,7 +406,7 @@ class OpenAIProvider(BaseLLMProvider): try: stream = await self.async_client.chat.completions.create(**kwargs) - current_tool_call: Optional[dict[str, Any]] = None + current_tool_call: dict[str, Any] | None = None async for chunk in stream: events, current_tool_call = self._handle_stream_chunk( diff --git a/src/memory/common/llms/tools/discord.py b/src/memory/common/llms/tools/discord.py new file mode 100644 index 0000000..8492be3 --- /dev/null +++ b/src/memory/common/llms/tools/discord.py @@ -0,0 +1,231 @@ +"""Discord tool for interacting with Discord.""" + +import textwrap +from 
datetime import datetime +from typing import Literal, cast +from memory.discord.messages import ( + upsert_scheduled_message, + comm_channel_prompt, + previous_messages, +) +from sqlalchemy import BigInteger +from memory.common.db.connection import make_session +from memory.common.db.models import ( + DiscordServer, + DiscordChannel, + DiscordUser, + BotUser, +) +from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler + + +UpdateSummaryType = Literal["server", "channel", "user"] + + +def handle_update_summary_call( + type: UpdateSummaryType, item_id: BigInteger +) -> ToolHandler: + models = { + "server": DiscordServer, + "channel": DiscordChannel, + "user": DiscordUser, + } + + def handler(input: ToolInput = None) -> str: + if isinstance(input, dict): + summary = input.get("summary") or str(input) + else: + summary = str(input) + + try: + with make_session() as session: + model = models[type] + model = session.get(model, item_id) + model.summary = summary # type: ignore + session.commit() + except Exception as e: + return f"Error updating summary: {e}" + return "Updated summary" + + handler.__doc__ = textwrap.dedent(""" + Handle a {type} summary update tool call. + + Args: + summary: The new summary of the Discord {type} + + Returns: + Response string + """).format(type=type) + return handler + + +def make_summary_tool(type: UpdateSummaryType, item_id: BigInteger) -> ToolDefinition: + return ToolDefinition( + name=f"update_{type}_summary", + description=textwrap.dedent(""" + Use this to update the summary of this Discord {type} that is added to your context. + + This will overwrite the previous summary. 
+ """).format(type=type), + input_schema={ + "type": "object", + "properties": { + "summary": { + "type": "string", + "description": f"The new summary of the Discord {type}", + } + }, + "required": [], + }, + function=handle_update_summary_call(type, item_id), + ) + + +def schedule_message( + user_id: int, + user: int | None, + channel: int | None, + model: str, + message: str, + date_time: datetime, +) -> str: + with make_session() as session: + call = upsert_scheduled_message( + session, + scheduled_time=date_time, + message=message, + user_id=user_id, + model=model, + discord_user=user, + discord_channel=channel, + system_prompt=comm_channel_prompt(session, user, channel), + ) + + session.commit() + return cast(str, call.id) + + +def make_message_scheduler( + bot: BotUser, user: int | None, channel: int | None, model: str +) -> ToolDefinition: + bot_id = cast(int, bot.id) + if user: + channel_type = "from your chat with this user" + elif channel: + channel_type = "in this channel" + else: + raise ValueError("Either user or channel must be provided") + + def handler(input: ToolInput) -> str: + if not isinstance(input, dict): + raise ValueError("Input must be a dictionary") + + try: + time = datetime.fromisoformat(input["date_time"]) + except ValueError: + raise ValueError("Invalid date time format") + except KeyError: + raise ValueError("Date time is required") + + return schedule_message(bot_id, user, channel, model, input["message"], time) + + return ToolDefinition( + name="schedule_message", + description=textwrap.dedent(""" + Use this to schedule a message to be sent to yourself. + + At the specified date and time, your message will be sent to you, along with the most + recent messages {channel_type}. + + Normally you will be called with any incoming messages. But sometimes you might want to be + able to trigger a call to yourself at a specific time, rather than waiting for the next call. + This tool allows you to do that. 
+ So for example, if you were chatting with a Discord user, and you ask a question which needs to + be answered right away, you can use this tool to schedule a check in 5 minutes time, to remind + the user to answer the question. + """).format(channel_type=channel_type), + input_schema={ + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "The message to send", + }, + "date_time": { + "type": "string", + "description": "The date and time to send the message in ISO format (e.g., 2025-01-01T00:00:00Z)", + }, + }, + }, + function=handler, + ) + + +def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefinition: + if user: + channel_type = "from your chat with this user" + elif channel: + channel_type = "in this channel" + else: + raise ValueError("Either user or channel must be provided") + + def handler(input: ToolInput) -> str: + if not isinstance(input, dict): + raise ValueError("Input must be a dictionary") + try: + max_messages = int(input.get("max_messages") or 10) + offset = int(input.get("offset") or 0) + except ValueError: + raise ValueError("Max messages and offset must be integers") + + if max_messages <= 0: + raise ValueError("Max messages must be greater than 0") + if offset < 0: + raise ValueError("Offset must be greater than or equal to 0") + + with make_session() as session: + messages = previous_messages(session, user, channel, max_messages, offset) + return "\n\n".join([msg.title for msg in messages]) + + return ToolDefinition( + name="previous_messages", + description=f"Get the previous N messages {channel_type}.", + input_schema={ + "type": "object", + "properties": { + "max_messages": { + "type": "number", + "description": "The maximum number of messages to return", + "default": 10, + }, + "offset": { + "type": "number", + "description": "The number of messages to offset the result by", + "default": 0, + }, + }, + }, + function=handler, + ) + + +def make_discord_tools( + bot: BotUser, + 
author: DiscordUser | None, + channel: DiscordChannel | None, + model: str, +) -> dict[str, ToolDefinition]: + author_id = author and author.id + channel_id = channel and channel.id + tools = [ + make_message_scheduler(bot, author_id, channel_id, model), + make_prev_messages_tool(author_id, channel_id), + make_summary_tool("channel", channel_id), + ] + if author: + tools += [make_summary_tool("user", author_id)] + if channel and channel.server: + tools += [ + make_summary_tool("server", cast(BigInteger, channel.server_id)), + ] + return {tool.name: tool for tool in tools} diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py index daf98bb..c5b61a9 100644 --- a/src/memory/common/settings.py +++ b/src/memory/common/settings.py @@ -133,7 +133,7 @@ if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"): ANTHROPIC_API_KEY = pathlib.Path(anthropic_key_file).read_text().strip() else: ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "") -SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307") +SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-haiku-4-5") RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307") MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000)) @@ -173,11 +173,14 @@ DISCORD_NOTIFICATIONS_ENABLED = bool( boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN ) DISCORD_PROCESS_MESSAGES = boolean_env("DISCORD_PROCESS_MESSAGES", True) +DISCORD_MODEL = os.getenv("DISCORD_MODEL", "anthropic/claude-sonnet-4-5") +DISCORD_MAX_TOOL_CALLS = int(os.getenv("DISCORD_MAX_TOOL_CALLS", 10)) + # Discord collector settings DISCORD_COLLECTOR_ENABLED = boolean_env("DISCORD_COLLECTOR_ENABLED", True) DISCORD_COLLECT_DMS = boolean_env("DISCORD_COLLECT_DMS", True) DISCORD_COLLECT_BOTS = boolean_env("DISCORD_COLLECT_BOTS", True) -DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8000)) -DISCORD_COLLECTOR_SERVER_URL = 
os.getenv("DISCORD_COLLECTOR_SERVER_URL", "127.0.0.1") +DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8003)) +DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "0.0.0.0") DISCORD_CONTEXT_WINDOW = int(os.getenv("DISCORD_CONTEXT_WINDOW", 10)) diff --git a/src/memory/discord/collector.py b/src/memory/discord/collector.py index f60a22f..c0bdc0c 100644 --- a/src/memory/discord/collector.py +++ b/src/memory/discord/collector.py @@ -13,7 +13,7 @@ from sqlalchemy.orm import Session, scoped_session from memory.common import settings from memory.common.db.connection import make_session -from memory.common.db.models.sources import ( +from memory.common.db.models import ( DiscordServer, DiscordChannel, DiscordUser, @@ -227,6 +227,7 @@ class MessageCollector(commands.Bot): message_id=message.id, channel_id=message.channel.id, author_id=message.author.id, + recipient_id=self.user and self.user.id, server_id=message.guild.id if message.guild else None, content=message.content or "", sent_at=message.created_at.isoformat(), diff --git a/src/memory/discord/messages.py b/src/memory/discord/messages.py new file mode 100644 index 0000000..d64d02e --- /dev/null +++ b/src/memory/discord/messages.py @@ -0,0 +1,205 @@ +import logging +import textwrap +from datetime import datetime, timezone +from typing import Any, cast + +from sqlalchemy.orm import Session, scoped_session + +from memory.common.db.models import ( + DiscordChannel, + DiscordUser, + ScheduledLLMCall, + DiscordMessage, +) + +logger = logging.getLogger(__name__) + +DiscordEntity = DiscordChannel | DiscordUser | str | int | None + + +def resolve_discord_user( + session: Session | scoped_session, entity: DiscordEntity +) -> DiscordUser | None: + if not entity: + return None + if isinstance(entity, DiscordUser): + return entity + if isinstance(entity, int): + return session.get(DiscordUser, entity) + + entity = session.query(DiscordUser).filter(DiscordUser.username == entity).first() + if 
not entity: + entity = DiscordUser(id=entity, username=entity) + session.add(entity) + return entity + + +def resolve_discord_channel( + session: Session | scoped_session, entity: DiscordEntity +) -> DiscordChannel | None: + if not entity: + return None + if isinstance(entity, DiscordChannel): + return entity + if isinstance(entity, int): + return session.get(DiscordChannel, entity) + + return session.query(DiscordChannel).filter(DiscordChannel.name == entity).first() + + +def schedule_discord_message( + session: Session | scoped_session, + scheduled_time: datetime, + message: str, + user_id: int, + model: str | None = None, + topic: str | None = None, + discord_user: DiscordEntity = None, + discord_channel: DiscordEntity = None, + system_prompt: str | None = None, + metadata: dict[str, Any] | None = None, +) -> ScheduledLLMCall: + discord_user = resolve_discord_user(session, discord_user) + discord_channel = resolve_discord_channel(session, discord_channel) + if not discord_user and not discord_channel: + raise ValueError("Either discord_user or discord_channel must be provided") + + # Validate that the scheduled time is in the future + # Compare with naive datetime since we store naive in the database + current_time_naive = datetime.now(timezone.utc).replace(tzinfo=None) + if scheduled_time.replace(tzinfo=None) <= current_time_naive: + raise ValueError("Scheduled time must be in the future") + + # Create the scheduled call + scheduled_call = ScheduledLLMCall( + user_id=user_id, + scheduled_time=scheduled_time, + message=message, + topic=topic, + model=model, + system_prompt=system_prompt, + discord_channel=resolve_discord_channel(session, discord_channel), + discord_user=resolve_discord_user(session, discord_user), + data=metadata or {}, + ) + + session.add(scheduled_call) + return scheduled_call + + +def upsert_scheduled_message( + session: Session | scoped_session, + scheduled_time: datetime, + message: str, + user_id: int, + model: str | None = None, + topic: 
str | None = None, + discord_user: DiscordEntity = None, + discord_channel: DiscordEntity = None, + system_prompt: str | None = None, + metadata: dict[str, Any] | None = None, +) -> ScheduledLLMCall: + discord_user = resolve_discord_user(session, discord_user) + discord_channel = resolve_discord_channel(session, discord_channel) + prev_call = ( + session.query(ScheduledLLMCall) + .filter( + ScheduledLLMCall.user_id == user_id, + ScheduledLLMCall.model == model, + ScheduledLLMCall.discord_user_id == (discord_user and discord_user.id), + ScheduledLLMCall.discord_channel_id + == (discord_channel and discord_channel.id), + ) + .first() + ) + naive_scheduled_time = scheduled_time.replace(tzinfo=None) + print(f"naive_scheduled_time: {naive_scheduled_time}") + print(f"prev_call.scheduled_time: {prev_call and prev_call.scheduled_time}") + if prev_call and cast(datetime, prev_call.scheduled_time) > naive_scheduled_time: + prev_call.status = "cancelled" # type: ignore + + return schedule_discord_message( + session, + scheduled_time, + message, + user_id=user_id, + model=model, + topic=topic, + discord_user=discord_user, + discord_channel=discord_channel, + system_prompt=system_prompt, + metadata=metadata, + ) + + +def previous_messages( + session: Session | scoped_session, + user_id: int | None, + channel_id: int | None, + max_messages: int = 10, + offset: int = 0, +) -> list[DiscordMessage]: + messages = session.query(DiscordMessage) + if user_id: + messages = messages.filter(DiscordMessage.recipient_id == user_id) + if channel_id: + messages = messages.filter(DiscordMessage.channel_id == channel_id) + return list( + reversed( + messages.order_by(DiscordMessage.sent_at.desc()) + .offset(offset) + .limit(max_messages) + .all() + ) + ) + + +def comm_channel_prompt( + session: Session | scoped_session, + user: DiscordEntity, + channel: DiscordEntity, + max_messages: int = 10, +) -> str: + user = resolve_discord_user(session, user) + channel = resolve_discord_channel(session, 
channel) + + messages = previous_messages( + session, user and user.id, channel and channel.id, max_messages + ) + + server_context = "" + if channel and channel.server: + server_context = textwrap.dedent(""" + Here are your previous notes on the server: + + {summary} + + """).format(summary=channel.server.summary) + if channel: + server_context += textwrap.dedent(""" + Here are your previous notes on the channel: + + {summary} + + """).format(summary=channel.summary) + if messages: + server_context += textwrap.dedent(""" + Here are your previous notes on the users: + + {users} + + """).format( + users="\n".join({msg.from_user.as_xml() for msg in messages}), + ) + + return textwrap.dedent(""" + You are a bot communicating on Discord. + + {server_context} + + Whenever something worth remembering is said, you should add a note to the appropriate context - use + this to track your understanding of the conversation and those taking part in it. + + You will be given the last {max_messages} messages in the conversation. + Please react to them appropriately. You can return an empty response if you don't have anything to say. + """).format(server_context=server_context, max_messages=max_messages) diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py index 093acd1..f146af0 100644 --- a/src/memory/workers/tasks/discord.py +++ b/src/memory/workers/tasks/discord.py @@ -4,25 +4,31 @@ Celery tasks for Discord message processing. 
import hashlib import logging +import textwrap from datetime import datetime from typing import Any -from memory.common.celery_app import app -from memory.common.db.connection import make_session -from memory.common.db.models import DiscordMessage, DiscordUser -from memory.workers.tasks.content_processing import ( - safe_task_execution, - check_content_exists, - create_task_result, - process_content_item, -) +from sqlalchemy import exc as sqlalchemy_exc +from sqlalchemy.orm import Session, scoped_session + +from memory.common import discord, settings from memory.common.celery_app import ( ADD_DISCORD_MESSAGE, EDIT_DISCORD_MESSAGE, PROCESS_DISCORD_MESSAGE, + app, +) +from memory.common.db.connection import make_session +from memory.common.db.models import DiscordMessage, DiscordUser +from memory.common.llms.base import create_provider +from memory.common.llms.tools.discord import make_discord_tools +from memory.discord.messages import comm_channel_prompt, previous_messages +from memory.workers.tasks.content_processing import ( + check_content_exists, + create_task_result, + process_content_item, + safe_task_execution, ) -from memory.common import settings -from sqlalchemy.orm import Session, scoped_session logger = logging.getLogger(__name__) @@ -32,7 +38,7 @@ def get_prev( ) -> list[str]: prev = ( session.query(DiscordUser.username, DiscordMessage.content) - .join(DiscordUser, DiscordMessage.discord_user_id == DiscordUser.id) + .join(DiscordUser, DiscordMessage.from_id == DiscordUser.id) .filter( DiscordMessage.channel_id == channel_id, DiscordMessage.sent_at < sent_at, @@ -45,20 +51,54 @@ def get_prev( def should_process(message: DiscordMessage) -> bool: - return ( + if not ( settings.DISCORD_PROCESS_MESSAGES and settings.DISCORD_NOTIFICATIONS_ENABLED and not ( (message.server and message.server.ignore_messages) or (message.channel and message.channel.ignore_messages) - or (message.discord_user and message.discord_user.ignore_messages) + or (message.from_user and 
message.from_user.ignore_messages) ) - ) + ): + return False + + provider = create_provider(model=settings.SUMMARIZER_MODEL) + with make_session() as session: + system_prompt = comm_channel_prompt( + session, message.recipient_user, message.channel + ) + messages = previous_messages( + session, + message.recipient_user and message.recipient_user.id, + message.channel and message.channel.id, + max_messages=10, + ) + msg = textwrap.dedent(""" + Should you continue the conversation with the user? + Please return "yes" or "no" as: + + yes + + or + + no + + """) + response = provider.generate( + messages=provider.as_messages([m.title for m in messages] + [msg]), + system_prompt=system_prompt, + ) + return "yes" in "".join(response.lower().split()) @app.task(name=PROCESS_DISCORD_MESSAGE) @safe_task_execution def process_discord_message(message_id: int) -> dict[str, Any]: + """ + Process a Discord message. + + This task is queued by the Discord collector when messages are received. + """ logger.info(f"Processing Discord message {message_id}") with make_session() as session: @@ -71,7 +111,39 @@ def process_discord_message(message_id: int) -> dict[str, Any]: "message_id": message_id, } - print("Processing message", discord_message.id, discord_message.content) + tools = make_discord_tools( + discord_message.recipient_user, + discord_message.from_user, + discord_message.channel, + model=settings.DISCORD_MODEL, + ) + tools = { + name: tool + for name, tool in tools.items() + if discord_message.tool_allowed(name) + } + system_prompt = comm_channel_prompt( + session, discord_message.recipient_user, discord_message.channel + ) + messages = previous_messages( + session, + discord_message.recipient_user and discord_message.recipient_user.id, + discord_message.channel and discord_message.channel.id, + max_messages=10, + ) + provider = create_provider(model=settings.DISCORD_MODEL) + turn = provider.run_with_tools( + messages=provider.as_messages([m.title for m in messages]), + 
tools=tools, + system_prompt=system_prompt, + max_iterations=settings.DISCORD_MAX_TOOL_CALLS, + ) + if not turn.response: + pass + elif discord_message.channel.server: + discord.send_to_channel(discord_message.channel.name, turn.response) + else: + discord.send_dm(discord_message.from_user.username, turn.response) return { "status": "processed", @@ -88,6 +160,7 @@ def add_discord_message( content: str, sent_at: str, server_id: int | None = None, + recipient_id: int | None = None, message_reference_id: int | None = None, ) -> dict[str, Any]: """ @@ -108,7 +181,8 @@ def add_discord_message( channel_id=channel_id, sent_at=sent_at_dt, server_id=server_id, - discord_user_id=author_id, + from_id=author_id, + recipient_id=recipient_id, message_id=message_id, message_type="reply" if message_reference_id else "default", reply_to_message_id=message_reference_id, @@ -125,7 +199,15 @@ def add_discord_message( if channel_id: discord_message.messages_before = get_prev(session, channel_id, sent_at_dt) - result = process_content_item(discord_message, session) + try: + result = process_content_item(discord_message, session) + except sqlalchemy_exc.IntegrityError as e: + logger.error(f"Integrity error adding Discord message {message_id}: {e}") + return { + "status": "error", + "error": "Integrity error", + "message_id": message_id, + } if should_process(discord_message): process_discord_message.delay(discord_message.id) diff --git a/src/memory/workers/tasks/scheduled_calls.py b/src/memory/workers/tasks/scheduled_calls.py index 3248285..3ca6cad 100644 --- a/src/memory/workers/tasks/scheduled_calls.py +++ b/src/memory/workers/tasks/scheduled_calls.py @@ -37,12 +37,12 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str): if len(message) > 1900: # Leave some buffer message = message[:1900] + "\n\n... 
(response truncated)" - if discord_user := cast(str, scheduled_call.discord_user): - logger.info(f"Sending DM to {discord_user}: {message}") - discord.send_dm(discord_user, message) - elif discord_channel := cast(str, scheduled_call.discord_channel): - logger.info(f"Broadcasting message to {discord_channel}: {message}") - discord.broadcast_message(discord_channel, message) + if discord_user := scheduled_call.discord_user: + logger.info(f"Sending DM to {discord_user.username}: {message}") + discord.send_dm(discord_user.username, message) + elif discord_channel := scheduled_call.discord_channel: + logger.info(f"Broadcasting message to {discord_channel.name}: {message}") + discord.broadcast_message(discord_channel.name, message) else: logger.warning( f"No Discord user or channel found for scheduled call {scheduled_call.id}" @@ -62,11 +62,7 @@ def execute_scheduled_call(self, scheduled_call_id: str): with make_session() as session: # Fetch the scheduled call - scheduled_call = ( - session.query(ScheduledLLMCall) - .filter(ScheduledLLMCall.id == scheduled_call_id) - .first() - ) + scheduled_call = session.query(ScheduledLLMCall).get(scheduled_call_id) if not scheduled_call: logger.error(f"Scheduled call {scheduled_call_id} not found") diff --git a/tests/conftest.py b/tests/conftest.py index 175eedf..683cb4d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -254,17 +254,59 @@ def mock_openai_client(): with patch.object(openai, "OpenAI", autospec=True) as mock_client: client = mock_client() client.chat = Mock() + + # Mock non-streaming response client.chat.completions.create = Mock( return_value=Mock( choices=[ Mock( message=Mock( content="test summarytag1tag2" - ) + ), + finish_reason=None, ) ] ) ) + + # Store original side_effect for potential override + def streaming_response(*args, **kwargs): + if kwargs.get("stream"): + # Return mock streaming chunks + return iter( + [ + Mock( + choices=[ + Mock( + delta=Mock(content="test", tool_calls=None), + 
finish_reason=None, + ) + ] + ), + Mock( + choices=[ + Mock( + delta=Mock(content=" response", tool_calls=None), + finish_reason="stop", + ) + ] + ), + ] + ) + else: + # Return non-streaming response + return Mock( + choices=[ + Mock( + message=Mock( + content="test summarytag1tag2" + ), + finish_reason=None, + ) + ] + ) + + client.chat.completions.create.side_effect = streaming_response yield client diff --git a/tests/memory/common/llms/__init__.py b/tests/memory/common/llms/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/memory/common/llms/test_anthropic_event_parsing.py b/tests/memory/common/llms/test_anthropic_event_parsing.py new file mode 100644 index 0000000..45cbdeb --- /dev/null +++ b/tests/memory/common/llms/test_anthropic_event_parsing.py @@ -0,0 +1,552 @@ +"""Comprehensive tests for Anthropic stream event parsing.""" + +import pytest +from unittest.mock import Mock + +from memory.common.llms.anthropic_provider import AnthropicProvider +from memory.common.llms.base import StreamEvent + + +@pytest.fixture +def provider(): + return AnthropicProvider(api_key="test-key", model="claude-3-opus-20240229") + + +# Content Block Start Tests + + +@pytest.mark.parametrize( + "block_type,block_attrs,expected_tool_use", + [ + ( + "tool_use", + {"id": "tool-1", "name": "search", "input": {}}, + { + "id": "tool-1", + "name": "search", + "input": {}, + "server_name": None, + "is_server_call": False, + }, + ), + ( + "mcp_tool_use", + { + "id": "mcp-1", + "name": "mcp_search", + "input": {}, + "server_name": "mcp-server", + }, + { + "id": "mcp-1", + "name": "mcp_search", + "input": {}, + "server_name": "mcp-server", + "is_server_call": True, + }, + ), + ( + "server_tool_use", + { + "id": "srv-1", + "name": "server_action", + "input": {}, + "server_name": "custom-server", + }, + { + "id": "srv-1", + "name": "server_action", + "input": {}, + "server_name": "custom-server", + "is_server_call": True, + }, + ), + ], +) +def 
test_content_block_start_tool_types( + provider, block_type, block_attrs, expected_tool_use +): + """Different tool types should be tracked correctly.""" + block = Mock(spec=["type"] + list(block_attrs.keys())) + block.type = block_type + for key, value in block_attrs.items(): + setattr(block, key, value) + + event = Mock(spec=["type", "content_block"]) + event.type = "content_block_start" + event.content_block = block + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event is None + assert tool_use == expected_tool_use + + +def test_content_block_start_tool_without_input(provider): + """Tool use without input field should initialize as empty string.""" + block = Mock(spec=["type", "id", "name"]) + block.type = "tool_use" + block.id = "tool-2" + block.name = "calculate" + + event = Mock(spec=["type", "content_block"]) + event.type = "content_block_start" + event.content_block = block + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert tool_use["input"] == "" + + +def test_content_block_start_tool_result(provider): + """Tool result blocks should emit tool_result event.""" + block = Mock(spec=["tool_use_id", "content"]) + block.tool_use_id = "tool-1" + block.content = "Result content" + + event = Mock(spec=["type", "content_block"]) + event.type = "content_block_start" + event.content_block = block + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event is not None + assert stream_event.type == "tool_result" + assert stream_event.data == {"id": "tool-1", "result": "Result content"} + + +@pytest.mark.parametrize( + "has_content_block,block_type", + [ + (False, None), + (True, "unknown_type"), + ], +) +def test_content_block_start_ignored_cases(provider, has_content_block, block_type): + """Events without content_block or with unknown types should be ignored.""" + event = Mock(spec=["type", "content_block"] if has_content_block else ["type"]) + event.type = 
"content_block_start" + + if has_content_block: + block = Mock(spec=["type"]) + block.type = block_type + event.content_block = block + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event is None + assert tool_use is None + + +# Content Block Delta Tests + + +@pytest.mark.parametrize( + "delta_type,delta_attr,attr_value,expected_type,expected_data", + [ + ("text_delta", "text", "Hello world", "text", "Hello world"), + ("text_delta", "text", "", "text", ""), + ( + "thinking_delta", + "thinking", + "Let me think...", + "thinking", + "Let me think...", + ), + ("signature_delta", "signature", "sig-12345", "thinking", None), + ], +) +def test_content_block_delta_types( + provider, delta_type, delta_attr, attr_value, expected_type, expected_data +): + """Different delta types should emit appropriate events.""" + delta = Mock(spec=["type", delta_attr]) + delta.type = delta_type + setattr(delta, delta_attr, attr_value) + + event = Mock(spec=["type", "delta"]) + event.type = "content_block_delta" + event.delta = delta + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event.type == expected_type + if expected_type == "thinking" and delta_type == "signature_delta": + assert stream_event.signature == attr_value + else: + assert stream_event.data == expected_data + + +@pytest.mark.parametrize( + "current_tool,partial_json,expected_input", + [ + ( + {"id": "t1", "name": "search", "input": '{"query": "'}, + 'test"}', + '{"query": "test"}', + ), + ( + {"id": "t1", "name": "search", "input": '{"'}, + 'key": "value"}', + '{"key": "value"}', + ), + ( + {"id": "t1", "name": "search", "input": ""}, + '{"query": "test"}', + '{"query": "test"}', + ), + ], +) +def test_content_block_delta_input_json_accumulation( + provider, current_tool, partial_json, expected_input +): + """JSON delta should accumulate to tool input.""" + delta = Mock(spec=["type", "partial_json"]) + delta.type = "input_json_delta" + 
delta.partial_json = partial_json + + event = Mock(spec=["type", "delta"]) + event.type = "content_block_delta" + event.delta = delta + + stream_event, tool_use = provider._handle_stream_event(event, current_tool) + + assert stream_event is None + assert tool_use["input"] == expected_input + + +def test_content_block_delta_input_json_without_tool(provider): + """JSON delta without tool context should return None.""" + delta = Mock(spec=["type", "partial_json"]) + delta.type = "input_json_delta" + delta.partial_json = '{"key": "value"}' + + event = Mock(spec=["type", "delta"]) + event.type = "content_block_delta" + event.delta = delta + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event is None + assert tool_use is None + + +def test_content_block_delta_input_json_with_dict_input(provider): + """JSON delta shouldn't modify if input is already a dict.""" + current_tool = {"id": "t1", "name": "search", "input": {"query": "test"}} + + delta = Mock(spec=["type", "partial_json"]) + delta.type = "input_json_delta" + delta.partial_json = ', "extra": "data"' + + event = Mock(spec=["type", "delta"]) + event.type = "content_block_delta" + event.delta = delta + + stream_event, tool_use = provider._handle_stream_event(event, current_tool) + + assert tool_use["input"] == {"query": "test"} + + +@pytest.mark.parametrize( + "has_delta,delta_type", + [ + (False, None), + (True, "unknown_delta"), + ], +) +def test_content_block_delta_ignored_cases(provider, has_delta, delta_type): + """Events without delta or with unknown types should be ignored.""" + event = Mock(spec=["type", "delta"] if has_delta else ["type"]) + event.type = "content_block_delta" + + if has_delta: + delta = Mock(spec=["type"]) + delta.type = delta_type + event.delta = delta + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event is None + + +# Content Block Stop Tests + + +@pytest.mark.parametrize( + 
"input_value,has_content_block,expected_input", + [ + ("", False, {}), + (" \n\t ", False, {}), + ('{"invalid": json}', False, {}), + ('{"query": "test", "limit": 10}', False, {"query": "test", "limit": 10}), + ( + '{"filters": {"type": "user", "status": ["active", "pending"]}, "limit": 100}', + False, + { + "filters": {"type": "user", "status": ["active", "pending"]}, + "limit": 100, + }, + ), + ("", True, {"query": "test"}), + ], +) +def test_content_block_stop_tool_finalization( + provider, input_value, has_content_block, expected_input +): + """Tool stop should parse or use provided input correctly.""" + current_tool = {"id": "t1", "name": "search", "input": input_value} + + event = Mock(spec=["type", "content_block"] if has_content_block else ["type"]) + event.type = "content_block_stop" + + if has_content_block: + block = Mock(spec=["input"]) + block.input = {"query": "test"} + event.content_block = block + + stream_event, tool_use = provider._handle_stream_event(event, current_tool) + + assert stream_event.type == "tool_use" + assert stream_event.data["input"] == expected_input + assert tool_use is None + + +def test_content_block_stop_with_server_info(provider): + """Server tool info should be included in final event.""" + current_tool = { + "id": "t1", + "name": "mcp_search", + "input": '{"q": "test"}', + "server_name": "mcp-server", + "is_server_call": True, + } + + event = Mock(spec=["type"]) + event.type = "content_block_stop" + + stream_event, tool_use = provider._handle_stream_event(event, current_tool) + + assert stream_event.data["server_name"] == "mcp-server" + assert stream_event.data["is_server_call"] is True + + +def test_content_block_stop_without_tool(provider): + """Stop without current tool should return None.""" + event = Mock(spec=["type"]) + event.type = "content_block_stop" + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event is None + assert tool_use is None + + +# Message Delta Tests + + +def 
test_message_delta_max_tokens(provider): + """Max tokens stop reason should emit error.""" + delta = Mock(spec=["stop_reason"]) + delta.stop_reason = "max_tokens" + + event = Mock(spec=["type", "delta"]) + event.type = "message_delta" + event.delta = delta + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event.type == "error" + assert "Max tokens" in stream_event.data + + +@pytest.mark.parametrize("stop_reason", ["end_turn", "stop_sequence", None]) +def test_message_delta_other_stop_reasons(provider, stop_reason): + """Other stop reasons should not emit error.""" + delta = Mock(spec=["stop_reason"]) + delta.stop_reason = stop_reason + + event = Mock(spec=["type", "delta"]) + event.type = "message_delta" + event.delta = delta + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event is None + + +def test_message_delta_token_usage(provider): + """Token usage should be logged but not emitted.""" + usage = Mock( + spec=[ + "input_tokens", + "output_tokens", + "cache_creation_input_tokens", + "cache_read_input_tokens", + ] + ) + usage.input_tokens = 100 + usage.output_tokens = 50 + usage.cache_creation_input_tokens = 10 + usage.cache_read_input_tokens = 20 + + event = Mock(spec=["type", "usage"]) + event.type = "message_delta" + event.usage = usage + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event is None + + +def test_message_delta_empty(provider): + """Message delta without delta or usage should return None.""" + event = Mock(spec=["type"]) + event.type = "message_delta" + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event is None + + +# Message Stop Tests + + +@pytest.mark.parametrize( + "current_tool", + [ + None, + {"id": "t1", "name": "search", "input": '{"incomplete'}, + ], +) +def test_message_stop(provider, current_tool): + """Message stop should emit done regardless of incomplete tools.""" + 
event = Mock(spec=["type"]) + event.type = "message_stop" + + stream_event, tool_use = provider._handle_stream_event(event, current_tool) + + assert stream_event.type == "done" + assert tool_use is None + + +# Error Handling Tests + + +@pytest.mark.parametrize( + "has_error,error_value,expected_message", + [ + (True, "API rate limit exceeded", "rate limit"), + (False, None, "Unknown error"), + ], +) +def test_error_events(provider, has_error, error_value, expected_message): + """Error events should emit error StreamEvent.""" + event = Mock(spec=["type", "error"] if has_error else ["type"]) + event.type = "error" + if has_error: + event.error = error_value + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event.type == "error" + assert expected_message in stream_event.data + + +# Unknown Event Tests + + +@pytest.mark.parametrize( + "event_type", + ["message_start", "future_event_type", None], +) +def test_unknown_or_ignored_events(provider, event_type): + """Unknown event types should be logged but not fail.""" + if event_type is None: + event = Mock(spec=[]) + else: + event = Mock(spec=["type"]) + event.type = event_type + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event is None + + +# State Transition Tests + + +def test_complete_tool_call_sequence(provider): + """Simulate a complete tool call from start to finish.""" + # Start + block = Mock(spec=["type", "id", "name", "input"]) + block.type = "tool_use" + block.id = "tool-1" + block.name = "search" + block.input = None + + event1 = Mock(spec=["type", "content_block"]) + event1.type = "content_block_start" + event1.content_block = block + + _, tool_use = provider._handle_stream_event(event1, None) + assert tool_use["input"] == "" + + # Delta 1 + delta1 = Mock(spec=["type", "partial_json"]) + delta1.type = "input_json_delta" + delta1.partial_json = '{"query":' + + event2 = Mock(spec=["type", "delta"]) + event2.type = 
"content_block_delta" + event2.delta = delta1 + + _, tool_use = provider._handle_stream_event(event2, tool_use) + assert tool_use["input"] == '{"query":' + + # Delta 2 + delta2 = Mock(spec=["type", "partial_json"]) + delta2.type = "input_json_delta" + delta2.partial_json = ' "test"}' + + event3 = Mock(spec=["type", "delta"]) + event3.type = "content_block_delta" + event3.delta = delta2 + + _, tool_use = provider._handle_stream_event(event3, tool_use) + assert tool_use["input"] == '{"query": "test"}' + + # Stop + event4 = Mock(spec=["type"]) + event4.type = "content_block_stop" + + stream_event, tool_use = provider._handle_stream_event(event4, tool_use) + + assert stream_event.type == "tool_use" + assert stream_event.data["input"] == {"query": "test"} + assert tool_use is None + + +def test_text_and_thinking_mixed(provider): + """Text and thinking deltas should be handled independently.""" + delta1 = Mock(spec=["type", "text"]) + delta1.type = "text_delta" + delta1.text = "Answer: " + + event1 = Mock(spec=["type", "delta"]) + event1.type = "content_block_delta" + event1.delta = delta1 + + event1_result, _ = provider._handle_stream_event(event1, None) + assert event1_result.type == "text" + + delta2 = Mock(spec=["type", "thinking"]) + delta2.type = "thinking_delta" + delta2.thinking = "reasoning..." 
+ + event2 = Mock(spec=["type", "delta"]) + event2.type = "content_block_delta" + event2.delta = delta2 + + event2_result, _ = provider._handle_stream_event(event2, None) + assert event2_result.type == "thinking" diff --git a/tests/memory/common/llms/test_anthropic_provider.py b/tests/memory/common/llms/test_anthropic_provider.py new file mode 100644 index 0000000..6c76863 --- /dev/null +++ b/tests/memory/common/llms/test_anthropic_provider.py @@ -0,0 +1,440 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +from PIL import Image + +from memory.common.llms.anthropic_provider import AnthropicProvider +from memory.common.llms.base import ( + Message, + MessageRole, + TextContent, + ImageContent, + ToolUseContent, + ToolResultContent, + ThinkingContent, + LLMSettings, + StreamEvent, +) +from memory.common.llms.tools import ToolDefinition + + +@pytest.fixture +def provider(): + return AnthropicProvider(api_key="test-key", model="claude-3-opus-20240229") + + +@pytest.fixture +def thinking_provider(): + return AnthropicProvider( + api_key="test-key", model="claude-opus-4", enable_thinking=True + ) + + +def test_initialization(provider): + assert provider.api_key == "test-key" + assert provider.model == "claude-3-opus-20240229" + assert provider.enable_thinking is False + + +def test_client_lazy_loading(provider): + assert provider._client is None + client = provider.client + assert client is not None + assert provider._client is not None + # Second call should return same instance + assert provider.client is client + + +def test_async_client_lazy_loading(provider): + assert provider._async_client is None + client = provider.async_client + assert client is not None + assert provider._async_client is not None + + +@pytest.mark.parametrize( + "model, expected", + [ + ("claude-opus-4", True), + ("claude-opus-4-1", True), + ("claude-sonnet-4-0", True), + ("claude-sonnet-3-7", True), + ("claude-sonnet-4-5", True), + ("claude-3-opus-20240229", False), + 
("claude-3-sonnet-20240229", False), + ("gpt-4", False), + ], +) +def test_supports_thinking(model, expected): + provider = AnthropicProvider(api_key="test-key", model=model) + assert provider._supports_thinking() == expected + + +def test_convert_text_content(provider): + content = TextContent(text="hello world") + result = provider._convert_text_content(content) + assert result == {"type": "text", "text": "hello world"} + + +def test_convert_image_content(provider): + image = Image.new("RGB", (100, 100), color="red") + content = ImageContent(image=image) + result = provider._convert_image_content(content) + + assert result["type"] == "image" + assert result["source"]["type"] == "base64" + assert result["source"]["media_type"] == "image/jpeg" + assert isinstance(result["source"]["data"], str) + + +def test_should_include_message_filters_system(provider): + system_msg = Message(role=MessageRole.SYSTEM, content="system prompt") + user_msg = Message(role=MessageRole.USER, content="user message") + + assert provider._should_include_message(system_msg) is False + assert provider._should_include_message(user_msg) is True + + +@pytest.mark.parametrize( + "messages, expected_count", + [ + ([Message(role=MessageRole.USER, content="test")], 1), + ([Message(role=MessageRole.SYSTEM, content="test")], 0), + ( + [ + Message(role=MessageRole.SYSTEM, content="system"), + Message(role=MessageRole.USER, content="user"), + ], + 1, + ), + ], +) +def test_convert_messages(provider, messages, expected_count): + result = provider._convert_messages(messages) + assert len(result) == expected_count + + +def test_convert_tool(provider): + tool = ToolDefinition( + name="test_tool", + description="A test tool", + input_schema={"type": "object", "properties": {}}, + function=lambda x: "result", + ) + result = provider._convert_tool(tool) + + assert result["name"] == "test_tool" + assert result["description"] == "A test tool" + assert result["input_schema"] == {"type": "object", "properties": 
{}} + + +def test_build_request_kwargs_basic(provider): + messages = [Message(role=MessageRole.USER, content="test")] + settings = LLMSettings(temperature=0.5, max_tokens=1000) + + kwargs = provider._build_request_kwargs(messages, None, None, settings) + + assert kwargs["model"] == "claude-3-opus-20240229" + assert kwargs["temperature"] == 0.5 + assert kwargs["max_tokens"] == 1000 + assert len(kwargs["messages"]) == 1 + + +def test_build_request_kwargs_with_system_prompt(provider): + messages = [Message(role=MessageRole.USER, content="test")] + settings = LLMSettings() + + kwargs = provider._build_request_kwargs(messages, "system prompt", None, settings) + + assert kwargs["system"] == "system prompt" + + +def test_build_request_kwargs_with_tools(provider): + messages = [Message(role=MessageRole.USER, content="test")] + tools = [ + ToolDefinition( + name="test", + description="test", + input_schema={}, + function=lambda x: "result", + ) + ] + settings = LLMSettings() + + kwargs = provider._build_request_kwargs(messages, None, tools, settings) + + assert "tools" in kwargs + assert len(kwargs["tools"]) == 1 + + +def test_build_request_kwargs_with_thinking(thinking_provider): + messages = [Message(role=MessageRole.USER, content="test")] + settings = LLMSettings(max_tokens=5000) + + kwargs = thinking_provider._build_request_kwargs(messages, None, None, settings) + + assert "thinking" in kwargs + assert kwargs["thinking"]["type"] == "enabled" + assert kwargs["thinking"]["budget_tokens"] == 3976 + assert kwargs["temperature"] == 1.0 + assert "top_p" not in kwargs + + +def test_build_request_kwargs_thinking_insufficient_tokens(thinking_provider): + messages = [Message(role=MessageRole.USER, content="test")] + settings = LLMSettings(max_tokens=1000) + + kwargs = thinking_provider._build_request_kwargs(messages, None, None, settings) + + # Shouldn't enable thinking if not enough tokens + assert "thinking" not in kwargs + + +def test_handle_stream_event_text_delta(provider): 
+ event = Mock( + type="content_block_delta", + delta=Mock(type="text_delta", text="hello"), + ) + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event is not None + assert stream_event.type == "text" + assert stream_event.data == "hello" + assert tool_use is None + + +def test_handle_stream_event_thinking_delta(provider): + event = Mock( + type="content_block_delta", + delta=Mock(type="thinking_delta", thinking="reasoning..."), + ) + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event is not None + assert stream_event.type == "thinking" + assert stream_event.data == "reasoning..." + + +def test_handle_stream_event_tool_use_start(provider): + block = Mock(spec=["type", "id", "name", "input"]) + block.type = "tool_use" + block.id = "tool-1" + block.name = "test_tool" + block.input = {} + + event = Mock(spec=["type", "content_block"]) + event.type = "content_block_start" + event.content_block = block + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event is None + assert tool_use is not None + assert tool_use["id"] == "tool-1" + assert tool_use["name"] == "test_tool" + assert tool_use["input"] == {} + + +def test_handle_stream_event_tool_input_delta(provider): + current_tool = {"id": "tool-1", "name": "test", "input": '{"ke'} + event = Mock( + type="content_block_delta", + delta=Mock(type="input_json_delta", partial_json='y": "val'), + ) + + stream_event, tool_use = provider._handle_stream_event(event, current_tool) + + assert stream_event is None + assert tool_use["input"] == '{"key": "val' + + +def test_handle_stream_event_tool_use_complete(provider): + current_tool = { + "id": "tool-1", + "name": "test_tool", + "input": '{"key": "value"}', + } + event = Mock( + type="content_block_stop", + content_block=Mock(input={"key": "value"}), + ) + + stream_event, tool_use = provider._handle_stream_event(event, current_tool) + + assert stream_event is 
not None + assert stream_event.type == "tool_use" + assert stream_event.data["id"] == "tool-1" + assert stream_event.data["name"] == "test_tool" + assert stream_event.data["input"] == {"key": "value"} + assert tool_use is None + + +def test_handle_stream_event_message_stop(provider): + event = Mock(type="message_stop") + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event is not None + assert stream_event.type == "done" + assert tool_use is None + + +def test_handle_stream_event_error(provider): + event = Mock(type="error", error="API error") + + stream_event, tool_use = provider._handle_stream_event(event, None) + + assert stream_event is not None + assert stream_event.type == "error" + assert "API error" in stream_event.data + + +def test_generate_basic(provider, mock_anthropic_client): + messages = [Message(role=MessageRole.USER, content="test")] + + # Mock the response properly + mock_block = Mock(spec=["type", "text"]) + mock_block.type = "text" + mock_block.text = "test summary" + + mock_response = Mock(spec=["content"]) + mock_response.content = [mock_block] + + provider.client.messages.create.return_value = mock_response + + result = provider.generate(messages) + + assert result == "test summary" + provider.client.messages.create.assert_called_once() + + +def test_stream_basic(provider, mock_anthropic_client): + messages = [Message(role=MessageRole.USER, content="test")] + + events = list(provider.stream(messages)) + + # Should get text event and done event + assert len(events) > 0 + assert any(e.type == "text" for e in events) + provider.client.messages.stream.assert_called_once() + + +@pytest.mark.asyncio +async def test_agenerate_basic(provider, mock_anthropic_client): + messages = [Message(role=MessageRole.USER, content="test")] + + result = await provider.agenerate(messages) + + assert result == "test summary" + provider.async_client.messages.create.assert_called_once() + + +@pytest.mark.asyncio +async def 
test_astream_basic(provider, mock_anthropic_client): + messages = [Message(role=MessageRole.USER, content="test")] + + events = [] + async for event in provider.astream(messages): + events.append(event) + + assert len(events) > 0 + assert any(e.type == "text" for e in events) + + +def test_convert_message_sorts_thinking_content(provider): + """Thinking content should be sorted so non-thinking comes before thinking.""" + message = Message.assistant( + ThinkingContent(thinking="reasoning", signature="sig"), + TextContent(text="response"), + ) + + result = provider._convert_message(message) + + assert result["role"] == "assistant" + # The sort key (x["type"] != "thinking") sorts thinking type to beginning + # because "thinking" != "thinking" is False, which sorts before True + content_types = [c["type"] for c in result["content"]] + assert "text" in content_types + assert "thinking" in content_types + # Verify thinking comes before non-thinking (sorted by key) + thinking_idx = content_types.index("thinking") + text_idx = content_types.index("text") + assert thinking_idx < text_idx + + +def test_execute_tool_success(provider): + tool_call = {"id": "t1", "name": "test", "input": {"arg": "value"}} + tools = { + "test": ToolDefinition( + name="test", + description="test", + input_schema={}, + function=lambda x: f"result: {x['arg']}", + ) + } + + result = provider.execute_tool(tool_call, tools) + + assert result.tool_use_id == "t1" + assert result.content == "result: value" + assert result.is_error is False + + +def test_execute_tool_missing_name(provider): + tool_call = {"id": "t1", "input": {}} + tools = {} + + result = provider.execute_tool(tool_call, tools) + + assert result.tool_use_id == "t1" + assert "missing" in result.content.lower() + assert result.is_error is True + + +def test_execute_tool_not_found(provider): + tool_call = {"id": "t1", "name": "nonexistent", "input": {}} + tools = {} + + result = provider.execute_tool(tool_call, tools) + + assert 
result.tool_use_id == "t1" + assert "not found" in result.content.lower() + assert result.is_error is True + + +def test_execute_tool_exception(provider): + tool_call = {"id": "t1", "name": "test", "input": {}} + tools = { + "test": ToolDefinition( + name="test", + description="test", + input_schema={}, + function=lambda x: 1 / 0, # Raises ZeroDivisionError + ) + } + + result = provider.execute_tool(tool_call, tools) + + assert result.tool_use_id == "t1" + assert result.is_error is True + assert "division" in result.content.lower() + + +def test_encode_image(provider): + image = Image.new("RGB", (10, 10), color="blue") + + encoded = provider.encode_image(image) + + assert isinstance(encoded, str) + assert len(encoded) > 0 + + +def test_encode_image_rgba(provider): + """RGBA images should be converted to RGB.""" + image = Image.new("RGBA", (10, 10), color=(255, 0, 0, 128)) + + encoded = provider.encode_image(image) + + assert isinstance(encoded, str) + assert len(encoded) > 0 diff --git a/tests/memory/common/llms/test_base.py b/tests/memory/common/llms/test_base.py new file mode 100644 index 0000000..d0663d2 --- /dev/null +++ b/tests/memory/common/llms/test_base.py @@ -0,0 +1,270 @@ +import pytest +from PIL import Image + +from memory.common.llms.base import ( + Message, + MessageRole, + TextContent, + ImageContent, + ToolUseContent, + ToolResultContent, + ThinkingContent, + LLMSettings, + StreamEvent, + create_provider, +) +from memory.common.llms.anthropic_provider import AnthropicProvider +from memory.common.llms.openai_provider import OpenAIProvider +from memory.common import settings + + +def test_message_role_enum(): + assert MessageRole.SYSTEM == "system" + assert MessageRole.USER == "user" + assert MessageRole.ASSISTANT == "assistant" + assert MessageRole.TOOL == "tool" + + +def test_text_content_creation(): + content = TextContent(text="hello") + assert content.type == "text" + assert content.text == "hello" + assert content.valid + + +def 
test_text_content_to_dict(): + content = TextContent(text="hello") + result = content.to_dict() + assert result == {"type": "text", "text": "hello"} + + +def test_text_content_empty_invalid(): + content = TextContent(text="") + assert not content.valid + + +def test_image_content_creation(): + image = Image.new("RGB", (10, 10)) + content = ImageContent(image=image) + assert content.type == "image" + assert content.image == image + assert content.valid + + +def test_image_content_with_detail(): + image = Image.new("RGB", (10, 10)) + content = ImageContent(image=image, detail="high") + assert content.detail == "high" + + +def test_tool_use_content_creation(): + content = ToolUseContent(id="t1", name="test_tool", input={"arg": "value"}) + assert content.type == "tool_use" + assert content.id == "t1" + assert content.name == "test_tool" + assert content.input == {"arg": "value"} + assert content.valid + + +def test_tool_use_content_to_dict(): + content = ToolUseContent(id="t1", name="test", input={"key": "val"}) + result = content.to_dict() + assert result == { + "type": "tool_use", + "id": "t1", + "name": "test", + "input": {"key": "val"}, + } + + +def test_tool_result_content_creation(): + content = ToolResultContent( + tool_use_id="t1", + content="result", + is_error=False, + ) + assert content.type == "tool_result" + assert content.tool_use_id == "t1" + assert content.content == "result" + assert not content.is_error + assert content.valid + + +def test_tool_result_content_with_error(): + content = ToolResultContent( + tool_use_id="t1", + content="error message", + is_error=True, + ) + assert content.is_error + + +def test_thinking_content_creation(): + content = ThinkingContent(thinking="reasoning...", signature="sig") + assert content.type == "thinking" + assert content.thinking == "reasoning..." 
+ assert content.signature == "sig" + assert content.valid + + +def test_thinking_content_invalid_without_signature(): + content = ThinkingContent(thinking="reasoning...") + assert not content.valid + + +def test_message_simple_string_content(): + msg = Message(role=MessageRole.USER, content="hello") + assert msg.role == MessageRole.USER + assert msg.content == "hello" + + +def test_message_list_content(): + content_list = [TextContent(text="hello"), TextContent(text="world")] + msg = Message(role=MessageRole.USER, content=content_list) + assert msg.role == MessageRole.USER + assert len(msg.content) == 2 + + +def test_message_to_dict_string(): + msg = Message(role=MessageRole.USER, content="hello") + result = msg.to_dict() + assert result == {"role": "user", "content": "hello"} + + +def test_message_to_dict_list(): + msg = Message( + role=MessageRole.USER, + content=[TextContent(text="hello"), TextContent(text="world")], + ) + result = msg.to_dict() + assert result["role"] == "user" + assert len(result["content"]) == 2 + assert result["content"][0] == {"type": "text", "text": "hello"} + + +def test_message_assistant_factory(): + msg = Message.assistant( + TextContent(text="response"), + ToolUseContent(id="t1", name="tool", input={}), + ) + assert msg.role == MessageRole.ASSISTANT + assert len(msg.content) == 2 + + +def test_message_assistant_filters_invalid_content(): + msg = Message.assistant( + TextContent(text="valid"), + TextContent(text=""), # Invalid - empty + ) + assert len(msg.content) == 1 + assert msg.content[0].text == "valid" + + +def test_message_user_factory(): + msg = Message.user(text="hello") + assert msg.role == MessageRole.USER + assert len(msg.content) == 1 + assert isinstance(msg.content[0], TextContent) + + +def test_message_user_with_tool_result(): + tool_result = ToolResultContent(tool_use_id="t1", content="result") + msg = Message.user(text="hello", tool_result=tool_result) + assert len(msg.content) == 2 + + +def 
test_stream_event_creation(): + event = StreamEvent(type="text", data="hello") + assert event.type == "text" + assert event.data == "hello" + + +def test_stream_event_with_signature(): + event = StreamEvent(type="thinking", signature="sig123") + assert event.signature == "sig123" + + +def test_llm_settings_defaults(): + settings = LLMSettings() + assert settings.temperature == 0.7 + assert settings.max_tokens == 2048 + assert settings.top_p is None + assert settings.stop_sequences is None + assert settings.stream is False + + +def test_llm_settings_custom(): + settings = LLMSettings( + temperature=0.5, + max_tokens=1000, + top_p=0.9, + stop_sequences=["STOP"], + stream=True, + ) + assert settings.temperature == 0.5 + assert settings.max_tokens == 1000 + assert settings.top_p == 0.9 + assert settings.stop_sequences == ["STOP"] + assert settings.stream is True + + +def test_create_provider_anthropic(): + provider = create_provider( + model="anthropic/claude-3-opus-20240229", + api_key="test-key", + ) + assert isinstance(provider, AnthropicProvider) + assert provider.model == "claude-3-opus-20240229" + + +def test_create_provider_openai(): + provider = create_provider( + model="openai/gpt-4o", + api_key="test-key", + ) + assert isinstance(provider, OpenAIProvider) + assert provider.model == "gpt-4o" + + +def test_create_provider_unknown_raises(): + with pytest.raises(ValueError, match="Unknown provider"): + create_provider(model="unknown/model", api_key="test-key") + + +def test_create_provider_uses_default_model(): + """If no model provided, should use SUMMARIZER_MODEL from settings.""" + provider = create_provider(api_key="test-key") + # Should create a provider (type depends on settings.SUMMARIZER_MODEL) + assert provider is not None + + +def test_create_provider_anthropic_with_thinking(): + provider = create_provider( + model="anthropic/claude-opus-4", + api_key="test-key", + enable_thinking=True, + ) + assert isinstance(provider, AnthropicProvider) + assert 
provider.enable_thinking is True + + +def test_create_provider_missing_anthropic_key(): + # Temporarily clear the API key from settings + original_key = settings.ANTHROPIC_API_KEY + try: + settings.ANTHROPIC_API_KEY = "" + with pytest.raises(ValueError, match="ANTHROPIC_API_KEY"): + create_provider(model="anthropic/claude-3-opus-20240229") + finally: + settings.ANTHROPIC_API_KEY = original_key + + +def test_create_provider_missing_openai_key(): + # Temporarily clear the API key from settings + original_key = settings.OPENAI_API_KEY + try: + settings.OPENAI_API_KEY = "" + with pytest.raises(ValueError, match="OPENAI_API_KEY"): + create_provider(model="openai/gpt-4o") + finally: + settings.OPENAI_API_KEY = original_key diff --git a/tests/memory/common/llms/test_openai_event_parsing.py b/tests/memory/common/llms/test_openai_event_parsing.py new file mode 100644 index 0000000..2574715 --- /dev/null +++ b/tests/memory/common/llms/test_openai_event_parsing.py @@ -0,0 +1,478 @@ +"""Comprehensive tests for OpenAI stream chunk parsing.""" + +import pytest +from unittest.mock import Mock + +from memory.common.llms.openai_provider import OpenAIProvider +from memory.common.llms.base import StreamEvent + + +@pytest.fixture +def provider(): + return OpenAIProvider(api_key="test-key", model="gpt-4o") + + +# Text Content Tests + + +@pytest.mark.parametrize( + "content,expected_events", + [ + ("Hello", 1), + ("", 0), # Empty string is falsy + (None, 0), + ("Line 1\nLine 2\nLine 3", 1), + ("Hello δΈ–η•Œ 🌍", 1), + ], +) +def test_text_content(provider, content, expected_events): + """Text content should emit text events appropriately.""" + delta = Mock(spec=["content", "tool_calls"]) + delta.content = content + delta.tool_calls = None + + choice = Mock(spec=["delta", "finish_reason"]) + choice.delta = delta + choice.finish_reason = None + + chunk = Mock(spec=["choices"]) + chunk.choices = [choice] + + events, tool_call = provider._handle_stream_chunk(chunk, None) + + assert 
len(events) == expected_events + if expected_events > 0: + assert events[0].type == "text" + assert events[0].data == content + assert tool_call is None + + +# Tool Call Start Tests + + +def test_new_tool_call_basic(provider): + """New tool call should initialize state.""" + function = Mock(spec=["name", "arguments"]) + function.name = "search" + function.arguments = "" + + tool = Mock(spec=["id", "function"]) + tool.id = "call_123" + tool.function = function + + delta = Mock(spec=["content", "tool_calls"]) + delta.content = None + delta.tool_calls = [tool] + + choice = Mock(spec=["delta", "finish_reason"]) + choice.delta = delta + choice.finish_reason = None + + chunk = Mock(spec=["choices"]) + chunk.choices = [choice] + + events, tool_call = provider._handle_stream_chunk(chunk, None) + + assert len(events) == 0 + assert tool_call == {"id": "call_123", "name": "search", "arguments": ""} + + +@pytest.mark.parametrize( + "name,arguments,expected_name,expected_args", + [ + ("calculate", '{"operation":', "calculate", '{"operation":'), + (None, "", "", ""), + ("test", None, "test", ""), + ], +) +def test_new_tool_call_variations( + provider, name, arguments, expected_name, expected_args +): + """Tool calls with various name/argument combinations.""" + function = Mock(spec=["name", "arguments"]) + function.name = name + function.arguments = arguments + + tool = Mock(spec=["id", "function"]) + tool.id = "call_123" + tool.function = function + + delta = Mock(spec=["content", "tool_calls"]) + delta.content = None + delta.tool_calls = [tool] + + choice = Mock(spec=["delta", "finish_reason"]) + choice.delta = delta + choice.finish_reason = None + + chunk = Mock(spec=["choices"]) + chunk.choices = [choice] + + events, tool_call = provider._handle_stream_chunk(chunk, None) + + assert tool_call["name"] == expected_name + assert tool_call["arguments"] == expected_args + + +def test_new_tool_call_replaces_previous(provider): + """New tool call should finalize and replace 
previous.""" + current = {"id": "call_old", "name": "old_tool", "arguments": '{"arg": "value"}'} + + function = Mock(spec=["name", "arguments"]) + function.name = "new_tool" + function.arguments = "" + + tool = Mock(spec=["id", "function"]) + tool.id = "call_new" + tool.function = function + + delta = Mock(spec=["content", "tool_calls"]) + delta.content = None + delta.tool_calls = [tool] + + choice = Mock(spec=["delta", "finish_reason"]) + choice.delta = delta + choice.finish_reason = None + + chunk = Mock(spec=["choices"]) + chunk.choices = [choice] + + events, tool_call = provider._handle_stream_chunk(chunk, current) + + assert len(events) == 1 + assert events[0].type == "tool_use" + assert events[0].data["id"] == "call_old" + assert events[0].data["input"] == {"arg": "value"} + assert tool_call["id"] == "call_new" + + +# Tool Call Continuation Tests + + +@pytest.mark.parametrize( + "initial_args,new_args,expected_args", + [ + ('{"query": "', 'test query"}', '{"query": "test query"}'), + ('{"query"', ': "value"}', '{"query": "value"}'), + ("", '{"full": "json"}', '{"full": "json"}'), + ('{"partial"', "", '{"partial"'), # Empty doesn't accumulate + ], +) +def test_tool_call_argument_accumulation( + provider, initial_args, new_args, expected_args +): + """Arguments should accumulate correctly.""" + current = {"id": "call_123", "name": "search", "arguments": initial_args} + + function = Mock(spec=["name", "arguments"]) + function.name = None + function.arguments = new_args + + tool = Mock(spec=["id", "function"]) + tool.id = None + tool.function = function + + delta = Mock(spec=["content", "tool_calls"]) + delta.content = None + delta.tool_calls = [tool] + + choice = Mock(spec=["delta", "finish_reason"]) + choice.delta = delta + choice.finish_reason = None + + chunk = Mock(spec=["choices"]) + chunk.choices = [choice] + + events, tool_call = provider._handle_stream_chunk(chunk, current) + + assert len(events) == 0 + assert tool_call["arguments"] == expected_args + + 
# --- tests/memory/common/llms/test_openai_event_parsing.py (continued) ---
# NOTE(review): reconstructed from a whitespace-mangled diff; "+" markers and
# hunk metadata dropped, logical content unchanged.


def test_tool_call_accumulation_without_current_tool(provider):
    """Arguments without current tool should be ignored."""
    function = Mock(spec=["name", "arguments"])
    function.name = None
    function.arguments = '{"arg": "value"}'

    tool = Mock(spec=["id", "function"])
    tool.id = None
    tool.function = function

    delta = Mock(spec=["content", "tool_calls"])
    delta.content = None
    delta.tool_calls = [tool]

    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta = delta
    choice.finish_reason = None

    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, tool_call = provider._handle_stream_chunk(chunk, None)

    assert len(events) == 0
    assert tool_call is None


def test_incremental_json_building(provider):
    """Test realistic incremental JSON building across multiple chunks."""
    current = {"id": "c1", "name": "search", "arguments": ""}

    increments = ['{"', 'query":', ' "test"}']
    expected_states = ['{"', '{"query":', '{"query": "test"}']

    for increment, expected in zip(increments, expected_states):
        function = Mock(spec=["name", "arguments"])
        function.name = None
        function.arguments = increment

        tool = Mock(spec=["id", "function"])
        tool.id = None
        tool.function = function

        delta = Mock(spec=["content", "tool_calls"])
        delta.content = None
        delta.tool_calls = [tool]

        choice = Mock(spec=["delta", "finish_reason"])
        choice.delta = delta
        choice.finish_reason = None

        chunk = Mock(spec=["choices"])
        chunk.choices = [choice]

        _, current = provider._handle_stream_chunk(chunk, current)
        assert current["arguments"] == expected


# Finish Reason Tests


def test_finish_reason_without_tool(provider):
    """Stop finish without tool should not emit events."""
    delta = Mock(spec=["content", "tool_calls"])
    delta.content = None
    delta.tool_calls = None

    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta = delta
    choice.finish_reason = "stop"

    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, tool_call = provider._handle_stream_chunk(chunk, None)

    assert len(events) == 0
    assert tool_call is None


@pytest.mark.parametrize(
    "arguments,expected_input",
    [
        ('{"query": "test"}', {"query": "test"}),
        ('{"invalid": json}', {}),  # malformed JSON falls back to {}
        ("", {}),
    ],
)
def test_finish_reason_with_tool(provider, arguments, expected_input):
    """Finish with tool call should finalize and emit."""
    current = {"id": "call_123", "name": "search", "arguments": arguments}

    delta = Mock(spec=["content", "tool_calls"])
    delta.content = None
    delta.tool_calls = None

    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta = delta
    choice.finish_reason = "tool_calls"

    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, tool_call = provider._handle_stream_chunk(chunk, current)

    assert len(events) == 1
    assert events[0].type == "tool_use"
    assert events[0].data["id"] == "call_123"
    assert events[0].data["input"] == expected_input
    assert tool_call is None


@pytest.mark.parametrize("reason", ["stop", "length", "content_filter", "tool_calls"])
def test_various_finish_reasons(provider, reason):
    """Various finish reasons with active tool should finalize."""
    current = {"id": "call_123", "name": "test", "arguments": '{"a": 1}'}

    delta = Mock(spec=["content", "tool_calls"])
    delta.content = None
    delta.tool_calls = None

    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta = delta
    choice.finish_reason = reason

    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, tool_call = provider._handle_stream_chunk(chunk, current)

    assert len(events) == 1
    assert tool_call is None


# Edge Cases Tests


def test_empty_choices(provider):
    """Empty choices list should return empty events."""
    chunk = Mock(spec=["choices"])
    chunk.choices = []

    events, tool_call = provider._handle_stream_chunk(chunk, None)

    assert len(events) == 0
    assert tool_call is None


def test_none_choices(provider):
    """None choices should be handled gracefully."""
    chunk = Mock(spec=["choices"])
    chunk.choices = None

    try:
        events, tool_call = provider._handle_stream_chunk(chunk, None)
        assert len(events) == 0
    except (TypeError, AttributeError):
        pass  # Also acceptable for malformed input


def test_multiple_chunks_in_sequence(provider):
    """Test processing multiple chunks sequentially."""
    # Chunk 1: Start
    function1 = Mock(spec=["name", "arguments"])
    function1.name = "search"
    function1.arguments = ""

    tool1 = Mock(spec=["id", "function"])
    tool1.id = "call_1"
    tool1.function = function1

    delta1 = Mock(spec=["content", "tool_calls"])
    delta1.content = None
    delta1.tool_calls = [tool1]

    choice1 = Mock(spec=["delta", "finish_reason"])
    choice1.delta = delta1
    choice1.finish_reason = None

    chunk1 = Mock(spec=["choices"])
    chunk1.choices = [choice1]

    events1, state = provider._handle_stream_chunk(chunk1, None)
    assert len(events1) == 0
    assert state is not None

    # Chunk 2: Args
    function2 = Mock(spec=["name", "arguments"])
    function2.name = None
    function2.arguments = '{"q": "test"}'

    tool2 = Mock(spec=["id", "function"])
    tool2.id = None
    tool2.function = function2

    delta2 = Mock(spec=["content", "tool_calls"])
    delta2.content = None
    delta2.tool_calls = [tool2]

    choice2 = Mock(spec=["delta", "finish_reason"])
    choice2.delta = delta2
    choice2.finish_reason = None

    chunk2 = Mock(spec=["choices"])
    chunk2.choices = [choice2]

    events2, state = provider._handle_stream_chunk(chunk2, state)
    assert len(events2) == 0
    assert state["arguments"] == '{"q": "test"}'

    # Chunk 3: Finish
    delta3 = Mock(spec=["content", "tool_calls"])
    delta3.content = None
    delta3.tool_calls = None

    choice3 = Mock(spec=["delta", "finish_reason"])
    choice3.delta = delta3
    choice3.finish_reason = "stop"

    chunk3 = Mock(spec=["choices"])
    chunk3.choices = [choice3]

    events3, state = provider._handle_stream_chunk(chunk3, state)
    assert len(events3) == 1
    assert events3[0].type == "tool_use"
    assert state is None


def test_text_and_tool_calls_mixed(provider):
    """Text content should be emitted before tool initialization."""
    function = Mock(spec=["name", "arguments"])
    function.name = "search"
    function.arguments = ""

    tool = Mock(spec=["id", "function"])
    tool.id = "call_1"
    tool.function = function

    delta = Mock(spec=["content", "tool_calls"])
    delta.content = "Let me search for that."
    delta.tool_calls = [tool]

    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta = delta
    choice.finish_reason = None

    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, tool_call = provider._handle_stream_chunk(chunk, None)

    assert len(events) == 1
    assert events[0].type == "text"
    assert events[0].data == "Let me search for that."
    assert tool_call is not None


# JSON Parsing Tests


@pytest.mark.parametrize(
    "arguments,expected_input",
    [
        ('{"key": "value", "num": 42}', {"key": "value", "num": 42}),
        ("{}", {}),
        (
            '{"user": {"name": "John", "tags": ["a", "b"]}, "count": 10}',
            {"user": {"name": "John", "tags": ["a", "b"]}, "count": 10},
        ),
        ('{"invalid": json}', {}),
        ('{"key": "val', {}),
        ("", {}),
        ('{"text": "Hello δΈ–η•Œ 🌍"}', {"text": "Hello δΈ–η•Œ 🌍"}),
        (
            '{"text": "Line 1\\nLine 2\\t\\tTabbed"}',
            {"text": "Line 1\nLine 2\t\tTabbed"},
        ),
    ],
)
def test_json_parsing(provider, arguments, expected_input):
    """Various JSON inputs should be parsed correctly."""
    tool_call = {"id": "c1", "name": "test", "arguments": arguments}

    result = provider._parse_and_finalize_tool_call(tool_call)

    assert result["input"] == expected_input
    assert "arguments" not in result


# --- diff: new file b/tests/memory/common/llms/test_openai_provider.py
# --- (index 0000000..896aa96, @@ -0,0 +1,561 @@) ---
import pytest
from unittest.mock import Mock
from PIL import Image

from memory.common.llms.openai_provider import OpenAIProvider
from memory.common.llms.base import (
    Message,
    MessageRole,
    TextContent,
    ImageContent,
    ToolUseContent,
    ToolResultContent,
    LLMSettings,
    StreamEvent,
)
from memory.common.llms.tools import ToolDefinition


@pytest.fixture
def provider():
    return OpenAIProvider(api_key="test-key", model="gpt-4o")


@pytest.fixture
def reasoning_provider():
    return OpenAIProvider(api_key="test-key", model="o1-preview")


def test_initialization(provider):
    assert provider.api_key == "test-key"
    assert provider.model == "gpt-4o"


def test_client_lazy_loading(provider):
    assert provider._client is None
    client = provider.client
    assert client is not None
    assert provider._client is not None


def test_async_client_lazy_loading(provider):
    assert provider._async_client is None
    client = provider.async_client
    assert client is not None
    assert provider._async_client is not None


@pytest.mark.parametrize(
    "model, expected",
    [
        ("gpt-4o", False),
        ("o1-preview", True),
        ("o1-mini", True),
        ("gpt-4-turbo", True),
        ("gpt-3.5-turbo", True),
    ],
)
def test_is_reasoning_model(model, expected):
    # NOTE(review): everything except "gpt-4o" is expected True here, including
    # gpt-4-turbo / gpt-3.5-turbo — confirm _is_reasoning_model is intentionally
    # "not a gpt-4o-family model" rather than "o1-family model".
    provider = OpenAIProvider(api_key="test-key", model=model)
    assert provider._is_reasoning_model() == expected


def test_convert_text_content(provider):
    content = TextContent(text="hello world")
    result = provider._convert_text_content(content)
    assert result == {"type": "text", "text": "hello world"}


def test_convert_image_content(provider):
    image = Image.new("RGB", (100, 100), color="red")
    content = ImageContent(image=image)
    result = provider._convert_image_content(content)

    assert result["type"] == "image_url"
    assert "image_url" in result
    assert result["image_url"]["url"].startswith("data:image/jpeg;base64,")
# --- tests/memory/common/llms/test_openai_provider.py (continued) ---
# NOTE(review): reconstructed from a whitespace-mangled diff; "+" markers and
# hunk metadata dropped, logical content unchanged.


def test_convert_image_content_with_detail(provider):
    image = Image.new("RGB", (100, 100), color="red")
    content = ImageContent(image=image, detail="high")
    result = provider._convert_image_content(content)

    assert result["image_url"]["detail"] == "high"


def test_convert_tool_use_content(provider):
    content = ToolUseContent(
        id="t1",
        name="test_tool",
        input={"arg": "value"},
    )
    result = provider._convert_tool_use_content(content)

    assert result["id"] == "t1"
    assert result["type"] == "function"
    assert result["function"]["name"] == "test_tool"
    assert '{"arg": "value"}' in result["function"]["arguments"]


def test_convert_tool_result_content(provider):
    content = ToolResultContent(
        tool_use_id="t1",
        content="result content",
        is_error=False,
    )
    result = provider._convert_tool_result_content(content)

    assert result["role"] == "tool"
    assert result["tool_call_id"] == "t1"
    assert result["content"] == "result content"


def test_convert_messages_simple(provider):
    messages = [Message(role=MessageRole.USER, content="test")]
    result = provider._convert_messages(messages)

    assert len(result) == 1
    assert result[0]["role"] == "user"
    assert result[0]["content"] == "test"


def test_convert_messages_with_tool_result(provider):
    """Tool results should become separate messages with 'tool' role."""
    messages = [
        Message(
            role=MessageRole.USER,
            content=[ToolResultContent(tool_use_id="t1", content="result")],
        )
    ]
    result = provider._convert_messages(messages)

    assert len(result) == 1
    assert result[0]["role"] == "tool"
    assert result[0]["tool_call_id"] == "t1"


def test_convert_messages_with_tool_use(provider):
    """Tool use content should become tool_calls field."""
    messages = [
        Message.assistant(
            TextContent(text="thinking..."),
            ToolUseContent(id="t1", name="test", input={}),
        )
    ]
    result = provider._convert_messages(messages)

    assert len(result) == 1
    assert result[0]["role"] == "assistant"
    assert "tool_calls" in result[0]
    assert len(result[0]["tool_calls"]) == 1


def test_convert_messages_mixed_content(provider):
    """Messages with both text and tool results should be split."""
    messages = [
        Message(
            role=MessageRole.USER,
            content=[
                TextContent(text="user text"),
                ToolResultContent(tool_use_id="t1", content="result"),
            ],
        )
    ]
    result = provider._convert_messages(messages)

    # Should create two messages: one user message and one tool message
    # NOTE(review): the tool message is asserted to come FIRST — confirm this
    # ordering is intentional in _convert_messages.
    assert len(result) == 2
    assert result[0]["role"] == "tool"
    assert result[1]["role"] == "user"


def test_convert_tools(provider):
    tools = [
        ToolDefinition(
            name="test_tool",
            description="A test tool",
            input_schema={"type": "object", "properties": {"arg": {"type": "string"}}},
            function=lambda x: "result",
        )
    ]
    result = provider._convert_tools(tools)

    assert len(result) == 1
    assert result[0]["type"] == "function"
    assert result[0]["function"]["name"] == "test_tool"
    assert result[0]["function"]["description"] == "A test tool"
    assert result[0]["function"]["parameters"] == tools[0].input_schema


def test_build_request_kwargs_basic(provider):
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings(temperature=0.5, max_tokens=1000)

    kwargs = provider._build_request_kwargs(messages, None, None, settings)

    assert kwargs["model"] == "gpt-4o"
    assert kwargs["temperature"] == 0.5
    assert kwargs["max_tokens"] == 1000
    assert len(kwargs["messages"]) == 1


def test_build_request_kwargs_with_system_prompt_standard_model(provider):
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings()

    kwargs = provider._build_request_kwargs(messages, "system prompt", None, settings)

    # For gpt-4o, system prompt becomes system message
    assert kwargs["messages"][0]["role"] == "system"
    assert kwargs["messages"][0]["content"] == "system prompt"


def test_build_request_kwargs_with_system_prompt_reasoning_model(
    reasoning_provider,
):
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings()

    kwargs = reasoning_provider._build_request_kwargs(
        messages, "system prompt", None, settings
    )

    # For o1 models, system prompt becomes developer message
    assert kwargs["messages"][0]["role"] == "developer"
    assert kwargs["messages"][0]["content"] == "system prompt"


def test_build_request_kwargs_reasoning_model_uses_max_completion_tokens(
    reasoning_provider,
):
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings(max_tokens=2000)

    kwargs = reasoning_provider._build_request_kwargs(messages, None, None, settings)

    # Reasoning models use max_completion_tokens
    assert "max_completion_tokens" in kwargs
    assert kwargs["max_completion_tokens"] == 2000
    assert "max_tokens" not in kwargs


def test_build_request_kwargs_reasoning_model_no_temperature(reasoning_provider):
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings(temperature=0.7)

    kwargs = reasoning_provider._build_request_kwargs(messages, None, None, settings)

    # Reasoning models don't support temperature
    assert "temperature" not in kwargs
    assert "top_p" not in kwargs


def test_build_request_kwargs_with_tools(provider):
    messages = [Message(role=MessageRole.USER, content="test")]
    tools = [
        ToolDefinition(
            name="test",
            description="test",
            input_schema={},
            function=lambda x: "result",
        )
    ]
    settings = LLMSettings()

    kwargs = provider._build_request_kwargs(messages, None, tools, settings)

    assert "tools" in kwargs
    assert len(kwargs["tools"]) == 1
    assert kwargs["tool_choice"] == "auto"


def test_build_request_kwargs_with_stream(provider):
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings()

    kwargs = provider._build_request_kwargs(messages, None, None, settings, stream=True)

    assert kwargs["stream"] is True


def test_parse_and_finalize_tool_call(provider):
    tool_call = {
        "id": "t1",
        "name": "test",
        "arguments": '{"key": "value"}',
    }

    result = provider._parse_and_finalize_tool_call(tool_call)

    assert result["id"] == "t1"
    assert result["name"] == "test"
    assert result["input"] == {"key": "value"}
    assert "arguments" not in result


def test_parse_and_finalize_tool_call_invalid_json(provider):
    tool_call = {
        "id": "t1",
        "name": "test",
        "arguments": '{"invalid json',
    }

    result = provider._parse_and_finalize_tool_call(tool_call)

    # Should default to empty dict on parse error
    assert result["input"] == {}


def test_handle_stream_chunk_text_content(provider):
    chunk = Mock(
        choices=[
            Mock(
                delta=Mock(content="hello", tool_calls=None),
                finish_reason=None,
            )
        ]
    )

    events, tool_call = provider._handle_stream_chunk(chunk, None)

    assert len(events) == 1
    assert events[0].type == "text"
    assert events[0].data == "hello"
    assert tool_call is None


def test_handle_stream_chunk_tool_call_start(provider):
    function = Mock(spec=["name", "arguments"])
    function.name = "test_tool"
    function.arguments = ""

    tool_call_mock = Mock(spec=["id", "function"])
    tool_call_mock.id = "t1"
    tool_call_mock.function = function

    delta = Mock(spec=["content", "tool_calls"])
    delta.content = None
    delta.tool_calls = [tool_call_mock]

    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta = delta
    choice.finish_reason = None

    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, tool_call = provider._handle_stream_chunk(chunk, None)

    assert len(events) == 0
    assert tool_call is not None
    assert tool_call["id"] == "t1"
    assert tool_call["name"] == "test_tool"


def test_handle_stream_chunk_tool_call_arguments(provider):
    # NOTE(review): `name` is a reserved Mock constructor kwarg, so
    # `Mock(name=None, ...)` does NOT set a `.name` attribute of None — access
    # returns a child Mock (truthy). Sibling tests use spec'd mocks with
    # explicit attribute assignment; consider doing the same here to avoid
    # depending on how the provider treats a non-None `function.name`.
    current_tool = {"id": "t1", "name": "test", "arguments": '{"ke'}
    chunk = Mock(
        choices=[
            Mock(
                delta=Mock(
                    content=None,
                    tool_calls=[
                        Mock(
                            id=None,
                            function=Mock(name=None, arguments='y": "val"}'),
                        )
                    ],
                ),
                finish_reason=None,
            )
        ]
    )

    events, tool_call = provider._handle_stream_chunk(chunk, current_tool)

    assert len(events) == 0
    assert tool_call["arguments"] == '{"key": "val"}'


def test_handle_stream_chunk_finish_with_tool_call(provider):
    current_tool = {"id": "t1", "name": "test", "arguments": '{"key": "value"}'}
    chunk = Mock(
        choices=[
            Mock(
                delta=Mock(content=None, tool_calls=None),
                finish_reason="tool_calls",
            )
        ]
    )

    events, tool_call = provider._handle_stream_chunk(chunk, current_tool)

    assert len(events) == 1
    assert events[0].type == "tool_use"
    assert events[0].data["id"] == "t1"
    assert events[0].data["input"] == {"key": "value"}
    assert tool_call is None


def test_handle_stream_chunk_empty_choices(provider):
    chunk = Mock(choices=[])

    events, tool_call = provider._handle_stream_chunk(chunk, None)

    assert len(events) == 0
    assert tool_call is None


def test_generate_basic(provider, mock_openai_client):
    messages = [Message(role=MessageRole.USER, content="test")]

    # The conftest fixture already sets up the mock response
    result = provider.generate(messages)

    assert isinstance(result, str)
    assert len(result) > 0
    provider.client.chat.completions.create.assert_called_once()


def test_stream_basic(provider, mock_openai_client):
    messages = [Message(role=MessageRole.USER, content="test")]

    events = list(provider.stream(messages))

    # Should get text events and done event
    assert len(events) > 0
    text_events = [e for e in events if e.type == "text"]
    assert len(text_events) > 0
    assert events[-1].type == "done"


@pytest.mark.asyncio
async def test_agenerate_basic(provider, mock_openai_client):
    messages = [Message(role=MessageRole.USER, content="test")]

    # Mock the async client
    # NOTE(review): a plain Mock (not AsyncMock) is installed here; this only
    # works if agenerate does not await the `create` call's result directly —
    # verify against the provider implementation / conftest patching.
    mock_response = Mock(choices=[Mock(message=Mock(content="async response"))])
    provider.async_client.chat.completions.create = Mock(return_value=mock_response)

    result = await provider.agenerate(messages)

    assert result == "async response"


@pytest.mark.asyncio
async def test_astream_basic(provider, mock_openai_client):
    messages = [Message(role=MessageRole.USER, content="test")]

    # Mock async streaming
    async def async_stream():
        yield Mock(
            choices=[
                Mock(delta=Mock(content="async", tool_calls=None), finish_reason=None)
            ]
        )
        yield Mock(
            choices=[
                Mock(delta=Mock(content=" test", tool_calls=None), finish_reason="stop")
            ]
        )

    provider.async_client.chat.completions.create = Mock(return_value=async_stream())

    events = []
    async for event in provider.astream(messages):
        events.append(event)

    assert len(events) > 0
    text_events = [e for e in events if e.type == "text"]
    assert len(text_events) > 0


def test_stream_with_tool_call(provider, mock_openai_client):
    """Test streaming with a complete tool call."""

    def stream_with_tool(*args, **kwargs):
        if kwargs.get("stream"):
            # First chunk - tool call start
            function1 = Mock(spec=["name", "arguments"])
            function1.name = "test_tool"
            function1.arguments = ""

            tool_call1 = Mock(spec=["id", "function"])
            tool_call1.id = "t1"
            tool_call1.function = function1

            delta1 = Mock(spec=["content", "tool_calls"])
            delta1.content = None
            delta1.tool_calls = [tool_call1]

            choice1 = Mock(spec=["delta", "finish_reason"])
            choice1.delta = delta1
            choice1.finish_reason = None

            chunk1 = Mock(spec=["choices"])
            chunk1.choices = [choice1]

            # Second chunk - tool arguments
            function2 = Mock(spec=["name", "arguments"])
            function2.name = None
            function2.arguments = '{"arg": "val"}'

            tool_call2 = Mock(spec=["id", "function"])
            tool_call2.id = None
            tool_call2.function = function2

            delta2 = Mock(spec=["content", "tool_calls"])
            delta2.content = None
            delta2.tool_calls = [tool_call2]

            choice2 = Mock(spec=["delta", "finish_reason"])
            choice2.delta = delta2
            choice2.finish_reason = None

            chunk2 = Mock(spec=["choices"])
            chunk2.choices = [choice2]

            # Third chunk - finish
            delta3 = Mock(spec=["content", "tool_calls"])
            delta3.content = None
            delta3.tool_calls = None

            choice3 = Mock(spec=["delta", "finish_reason"])
            choice3.delta = delta3
            choice3.finish_reason = "tool_calls"

            chunk3 = Mock(spec=["choices"])
            chunk3.choices = [choice3]

            return iter([chunk1, chunk2, chunk3])

    provider.client.chat.completions.create.side_effect = stream_with_tool

    messages = [Message(role=MessageRole.USER, content="test")]
    events = list(provider.stream(messages))

    tool_events = [e for e in events if e.type == "tool_use"]
    assert len(tool_events) == 1
    assert tool_events[0].data["id"] == "t1"
    assert tool_events[0].data["name"] == "test_tool"
    assert tool_events[0].data["input"] == {"arg": "val"}


def test_encode_image(provider):
    image = Image.new("RGB", (10, 10), color="blue")

    encoded = provider.encode_image(image)

    assert isinstance(encoded, str)
    assert len(encoded) > 0


def test_encode_image_rgba(provider):
    """RGBA images should be converted to RGB."""
    image = Image.new("RGBA", (10, 10), color=(255, 0, 0, 128))

    encoded = provider.encode_image(image)

    assert isinstance(encoded, str)
    assert len(encoded) > 0


# --- diff: tools/add_user.py (modified, index 4272f45..8991edc) ---
# Removed lines in the diff: the old `from ...users import User` import and the
# previously-required "--password" argument.
import argparse

from memory.common.db.connection import make_session
from memory.common.db.models.users import HumanUser, BotUser

if __name__ == "__main__":
    args = argparse.ArgumentParser()
    args.add_argument("--email", type=str, required=True)
    args.add_argument("--name", type=str, required=True)
    # Password is now optional: only human users need one.
    args.add_argument("--password", type=str, required=False)
args.add_argument("--bot", action="store_true", help="Create a bot user") + args.add_argument( + "--api-key", + type=str, + required=False, + help="API key for bot user (auto-generated if not provided)", + ) args = args.parse_args() with make_session() as session: - user = User.create_with_password( - email=args.email, password=args.password, name=args.name - ) + if args.bot: + user = BotUser.create_with_api_key( + name=args.name, email=args.email, api_key=args.api_key + ) + print(f"Bot user {args.email} created with API key: {user.api_key}") + else: + if not args.password: + raise ValueError("Password required for human users") + user = HumanUser.create_with_password( + email=args.email, password=args.password, name=args.name + ) + print(f"Human user {args.email} created") + session.add(user) session.commit() - print(f"User {args.email} created") + if args.bot: + print(f"Bot user {args.email} created with API key: {user.api_key}") + else: + print(f"Human user {args.email} created")