diff --git a/db/migrations/versions/20251012_222827_add_discord_models.py b/db/migrations/versions/20251013_142101_add_discord_models.py
similarity index 79%
rename from db/migrations/versions/20251012_222827_add_discord_models.py
rename to db/migrations/versions/20251013_142101_add_discord_models.py
index deb73e1..cd054eb 100644
--- a/db/migrations/versions/20251012_222827_add_discord_models.py
+++ b/db/migrations/versions/20251013_142101_add_discord_models.py
@@ -1,8 +1,8 @@
"""add_discord_models
-Revision ID: a8c8e8b17179
+Revision ID: 7c6169fba146
Revises: c86079073c1d
-Create Date: 2025-10-12 22:28:27.856164
+Create Date: 2025-10-13 14:21:01.080948
"""
@@ -13,7 +13,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
-revision: str = "a8c8e8b17179"
+revision: str = "7c6169fba146"
down_revision: Union[str, None] = "c86079073c1d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -26,12 +26,6 @@ def upgrade() -> None:
sa.Column("name", sa.Text(), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("member_count", sa.Integer(), nullable=True),
- sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
- sa.Column(
- "ignore_messages", sa.Boolean(), server_default="false", nullable=True
- ),
- sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
- sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
@@ -45,6 +39,17 @@ def upgrade() -> None:
server_default=sa.text("now()"),
nullable=True,
),
+ sa.Column(
+ "track_messages", sa.Boolean(), server_default="true", nullable=False
+ ),
+ sa.Column("ignore_messages", sa.Boolean(), nullable=True),
+ sa.Column(
+ "allowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
+ ),
+ sa.Column(
+ "disallowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
+ ),
+ sa.Column("summary", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
@@ -59,12 +64,6 @@ def upgrade() -> None:
sa.Column("server_id", sa.BigInteger(), nullable=True),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("channel_type", sa.Text(), nullable=False),
- sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
- sa.Column(
- "ignore_messages", sa.Boolean(), server_default="false", nullable=True
- ),
- sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
- sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
@@ -77,6 +76,17 @@ def upgrade() -> None:
server_default=sa.text("now()"),
nullable=True,
),
+ sa.Column(
+ "track_messages", sa.Boolean(), server_default="true", nullable=False
+ ),
+ sa.Column("ignore_messages", sa.Boolean(), nullable=True),
+ sa.Column(
+ "allowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
+ ),
+ sa.Column(
+ "disallowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
+ ),
+ sa.Column("summary", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(
["server_id"],
["discord_servers.id"],
@@ -92,12 +102,6 @@ def upgrade() -> None:
sa.Column("username", sa.Text(), nullable=False),
sa.Column("display_name", sa.Text(), nullable=True),
sa.Column("system_user_id", sa.Integer(), nullable=True),
- sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
- sa.Column(
- "ignore_messages", sa.Boolean(), server_default="false", nullable=True
- ),
- sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
- sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
@@ -110,6 +114,17 @@ def upgrade() -> None:
server_default=sa.text("now()"),
nullable=True,
),
+ sa.Column(
+ "track_messages", sa.Boolean(), server_default="true", nullable=False
+ ),
+ sa.Column("ignore_messages", sa.Boolean(), nullable=True),
+ sa.Column(
+ "allowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
+ ),
+ sa.Column(
+ "disallowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
+ ),
+ sa.Column("summary", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(
["system_user_id"],
["users.id"],
diff --git a/db/migrations/versions/20251020_014858_seperate_user__models.py b/db/migrations/versions/20251020_014858_seperate_user__models.py
new file mode 100644
index 0000000..274a43c
--- /dev/null
+++ b/db/migrations/versions/20251020_014858_seperate_user__models.py
@@ -0,0 +1,145 @@
+"""seperate_user__models
+
+Revision ID: 35a2c1b610b6
+Revises: 7c6169fba146
+Create Date: 2025-10-20 01:48:58.537881
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "35a2c1b610b6"
+down_revision: Union[str, None] = "7c6169fba146"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ op.add_column(
+ "discord_message", sa.Column("from_id", sa.BigInteger(), nullable=False)
+ )
+ op.add_column(
+ "discord_message", sa.Column("recipient_id", sa.BigInteger(), nullable=False)
+ )
+ op.drop_index("discord_message_user_idx", table_name="discord_message")
+ op.create_index(
+ "discord_message_from_idx", "discord_message", ["from_id"], unique=False
+ )
+ op.create_index(
+ "discord_message_recipient_idx",
+ "discord_message",
+ ["recipient_id"],
+ unique=False,
+ )
+ op.drop_constraint(
+ "discord_message_discord_user_id_fkey", "discord_message", type_="foreignkey"
+ )
+ op.create_foreign_key(
+ "discord_message_from_id_fkey",
+ "discord_message",
+ "discord_users",
+ ["from_id"],
+ ["id"],
+ )
+ op.create_foreign_key(
+ "discord_message_recipient_id_fkey",
+ "discord_message",
+ "discord_users",
+ ["recipient_id"],
+ ["id"],
+ )
+ op.drop_column("discord_message", "discord_user_id")
+ op.add_column(
+ "scheduled_llm_calls",
+ sa.Column("discord_channel_id", sa.BigInteger(), nullable=True),
+ )
+ op.add_column(
+ "scheduled_llm_calls",
+ sa.Column("discord_user_id", sa.BigInteger(), nullable=True),
+ )
+ op.create_foreign_key(
+ "scheduled_llm_calls_discord_user_id_fkey",
+ "scheduled_llm_calls",
+ "discord_users",
+ ["discord_user_id"],
+ ["id"],
+ )
+ op.create_foreign_key(
+ "scheduled_llm_calls_discord_channel_id_fkey",
+ "scheduled_llm_calls",
+ "discord_channels",
+ ["discord_channel_id"],
+ ["id"],
+ )
+ op.drop_column("scheduled_llm_calls", "discord_user")
+ op.drop_column("scheduled_llm_calls", "discord_channel")
+ op.add_column(
+ "users",
+ sa.Column("user_type", sa.String(), nullable=False, server_default="human"),
+ )
+ op.add_column("users", sa.Column("api_key", sa.String(), nullable=True))
+ op.alter_column("users", "password_hash", existing_type=sa.VARCHAR(), nullable=True)
+ op.create_unique_constraint("users_api_key_key", "users", ["api_key"])
+ op.drop_column("users", "discord_user_id")
+
+
+def downgrade() -> None:
+ op.add_column(
+ "users",
+ sa.Column("discord_user_id", sa.VARCHAR(), autoincrement=False, nullable=True),
+ )
+ op.drop_constraint("users_api_key_key", "users", type_="unique")
+ op.alter_column(
+ "users", "password_hash", existing_type=sa.VARCHAR(), nullable=False
+ )
+ op.drop_column("users", "api_key")
+ op.drop_column("users", "user_type")
+ op.add_column(
+ "scheduled_llm_calls",
+ sa.Column("discord_channel", sa.VARCHAR(), autoincrement=False, nullable=True),
+ )
+ op.add_column(
+ "scheduled_llm_calls",
+ sa.Column("discord_user", sa.VARCHAR(), autoincrement=False, nullable=True),
+ )
+ op.drop_constraint(
+ "scheduled_llm_calls_discord_user_id_fkey",
+ "scheduled_llm_calls",
+ type_="foreignkey",
+ )
+ op.drop_constraint(
+ "scheduled_llm_calls_discord_channel_id_fkey",
+ "scheduled_llm_calls",
+ type_="foreignkey",
+ )
+ op.drop_column("scheduled_llm_calls", "discord_user_id")
+ op.drop_column("scheduled_llm_calls", "discord_channel_id")
+ op.add_column(
+ "discord_message",
+ sa.Column("discord_user_id", sa.BIGINT(), autoincrement=False, nullable=False),
+ )
+ op.drop_constraint(
+ "discord_message_from_id_fkey", "discord_message", type_="foreignkey"
+ )
+ op.drop_constraint(
+ "discord_message_recipient_id_fkey", "discord_message", type_="foreignkey"
+ )
+ op.create_foreign_key(
+ "discord_message_discord_user_id_fkey",
+ "discord_message",
+ "discord_users",
+ ["discord_user_id"],
+ ["id"],
+ )
+ op.drop_index("discord_message_recipient_idx", table_name="discord_message")
+ op.drop_index("discord_message_from_idx", table_name="discord_message")
+ op.create_index(
+ "discord_message_user_idx", "discord_message", ["discord_user_id"], unique=False
+ )
+ op.drop_column("discord_message", "recipient_id")
+ op.drop_column("discord_message", "from_id")
diff --git a/docker-compose.yaml b/docker-compose.yaml
index dee38e3..2bc0945 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -50,6 +50,8 @@ x-worker-base: &worker-base
OPENAI_API_KEY_FILE: /run/secrets/openai_key
ANTHROPIC_API_KEY_FILE: /run/secrets/anthropic_key
VOYAGE_API_KEY: ${VOYAGE_API_KEY}
+ DISCORD_COLLECTOR_SERVER_URL: ingest-hub
+ DISCORD_COLLECTOR_PORT: 8003
secrets: [ postgres_password, openai_key, anthropic_key, ssh_private_key, ssh_public_key, ssh_known_hosts ]
read_only: true
tmpfs:
@@ -183,7 +185,6 @@ services:
dockerfile: docker/ingest_hub/Dockerfile
environment:
<<: *worker-env
- DISCORD_API_PORT: 8000
DISCORD_BOT_TOKEN: ${DISCORD_BOT_TOKEN}
DISCORD_NOTIFICATIONS_ENABLED: true
DISCORD_COLLECTOR_ENABLED: true
diff --git a/src/memory/api/MCP/base.py b/src/memory/api/MCP/base.py
index f7cda67..aaa7a25 100644
--- a/src/memory/api/MCP/base.py
+++ b/src/memory/api/MCP/base.py
@@ -25,7 +25,7 @@ from memory.api.MCP.oauth_provider import (
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models import OAuthState, UserSession
-from memory.common.db.models.users import User
+from memory.common.db.models.users import HumanUser
logger = logging.getLogger(__name__)
@@ -126,7 +126,11 @@ async def handle_login(request: Request):
key: value for key, value in form.items() if key not in ["email", "password"]
}
with make_session() as session:
- user = session.query(User).filter(User.email == form.get("email")).first()
+ user = (
+ session.query(HumanUser)
+ .filter(HumanUser.email == form.get("email"))
+ .first()
+ )
if not user or not user.is_valid_password(str(form.get("password", ""))):
logger.warning("Login failed - invalid credentials")
return login_form(request, oauth_params, "Invalid email or password")
@@ -144,11 +148,7 @@ def get_current_user() -> dict:
return {"authenticated": False}
with make_session() as session:
- user_session = (
- session.query(UserSession)
- .filter(UserSession.id == access_token.token)
- .first()
- )
+ user_session = session.query(UserSession).get(access_token.token)
if user_session and user_session.user:
user_info = user_session.user.serialize()
diff --git a/src/memory/api/MCP/oauth_provider.py b/src/memory/api/MCP/oauth_provider.py
index ca0fd39..8ba2c63 100644
--- a/src/memory/api/MCP/oauth_provider.py
+++ b/src/memory/api/MCP/oauth_provider.py
@@ -21,6 +21,7 @@ from memory.common.db.models.users import (
OAuthRefreshToken,
OAuthState,
User,
+ BotUser,
UserSession,
)
from memory.common.db.models.users import (
@@ -92,7 +93,7 @@ def create_oauth_token(
"""Create an OAuth token response."""
return OAuthToken(
access_token=access_token,
- token_type="bearer",
+ token_type="Bearer",
expires_in=ACCESS_TOKEN_LIFETIME,
refresh_token=refresh_token,
scope=" ".join(scopes),
@@ -310,26 +311,37 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
return token
async def load_access_token(self, token: str) -> Optional[AccessToken]:
- """Load and validate an access token."""
+ """Load and validate an access token (or bot API key)."""
with make_session() as session:
- now = datetime.now(timezone.utc).replace(
- tzinfo=None
- ) # Make naive for DB comparison
-
- # Query for active (non-expired) session
+ # Try as OAuth access token first
user_session = session.query(UserSession).get(token)
- if not user_session:
- return None
+ if user_session:
+ now = datetime.now(timezone.utc).replace(
+ tzinfo=None
+ ) # Make naive for DB comparison
- if user_session.expires_at < now:
- return None
+ if user_session.expires_at < now:
+ return None
- return AccessToken(
- token=token,
- client_id=user_session.oauth_state.client_id,
- scopes=user_session.oauth_state.scopes,
- expires_at=int(user_session.expires_at.timestamp()),
- )
+ return AccessToken(
+ token=token,
+ client_id=user_session.oauth_state.client_id,
+ scopes=user_session.oauth_state.scopes,
+ expires_at=int(user_session.expires_at.timestamp()),
+ )
+
+ # Try as bot API key
+ bot = session.query(User).filter(User.api_key == token).first()
+ if bot:
+ logger.info(f"Bot {bot.name} (id={bot.id}) authenticated via API key")
+ return AccessToken(
+ token=token,
+ client_id=cast(str, bot.name or bot.email),
+ scopes=["read", "write"], # Bots get full access
+ expires_at=2147483647, # Far future (2038)
+ )
+
+ return None
async def load_refresh_token(
self, client: OAuthClientInformationFull, refresh_token: str
diff --git a/src/memory/api/MCP/schedules.py b/src/memory/api/MCP/schedules.py
index 78740b6..169fe07 100644
--- a/src/memory/api/MCP/schedules.py
+++ b/src/memory/api/MCP/schedules.py
@@ -9,7 +9,9 @@ from typing import Any
from memory.api.MCP.base import get_current_user
from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall
+from memory.common.db.models.discord import DiscordChannel, DiscordUser
from memory.api.MCP.base import mcp
+from memory.discord.schedule import schedule_discord_message
logger = logging.getLogger(__name__)
@@ -17,7 +19,7 @@ logger = logging.getLogger(__name__)
@mcp.tool()
async def schedule_message(
scheduled_time: str,
- message: str | None = None,
+ message: str,
model: str | None = None,
topic: str | None = None,
discord_channel: str | None = None,
@@ -56,7 +58,8 @@ async def schedule_message(
if not user_id:
raise ValueError("User not found")
- discord_user = current_user.get("user", {}).get("discord_user_id")
+ discord_users = current_user.get("user", {}).get("discord_users")
+ discord_user = discord_users and next(iter(discord_users.keys()), None)
if not discord_user and not discord_channel:
raise ValueError("Either discord_user or discord_channel must be provided")
@@ -69,27 +72,20 @@ async def schedule_message(
except ValueError:
raise ValueError("Invalid datetime format")
- # Validate that the scheduled time is in the future
- # Compare with naive datetime since we store naive in the database
- current_time_naive = datetime.now(timezone.utc).replace(tzinfo=None)
- if scheduled_dt <= current_time_naive:
- raise ValueError("Scheduled time must be in the future")
-
with make_session() as session:
- # Create the scheduled call
- scheduled_call = ScheduledLLMCall(
- user_id=user_id,
+ scheduled_call = schedule_discord_message(
+ session=session,
scheduled_time=scheduled_dt,
message=message,
- topic=topic,
+ user_id=current_user.get("user", {}).get("user_id"),
model=model,
- system_prompt=system_prompt,
+ topic=topic,
discord_channel=discord_channel,
discord_user=discord_user,
- data=metadata or {},
+ system_prompt=system_prompt,
+ metadata=metadata,
)
- session.add(scheduled_call)
session.commit()
return {
diff --git a/src/memory/api/auth.py b/src/memory/api/auth.py
index a86edc0..14a79ca 100644
--- a/src/memory/api/auth.py
+++ b/src/memory/api/auth.py
@@ -7,7 +7,7 @@ from memory.common import settings
from sqlalchemy.orm import Session as DBSession, scoped_session
from memory.common.db.connection import get_session, make_session
-from memory.common.db.models.users import User, UserSession
+from memory.common.db.models.users import User, HumanUser, BotUser, UserSession
logger = logging.getLogger(__name__)
@@ -91,14 +91,14 @@ def get_current_user(request: Request, db: DBSession = Depends(get_session)) ->
return user
-def create_user(email: str, password: str, name: str, db: DBSession) -> User:
- """Create a new user"""
+def create_user(email: str, password: str, name: str, db: DBSession) -> HumanUser:
+ """Create a new human user"""
# Check if user already exists
existing_user = db.query(User).filter(User.email == email).first()
if existing_user:
raise HTTPException(status_code=400, detail="User already exists")
- user = User.create_with_password(email, name, password)
+ user = HumanUser.create_with_password(email, name, password)
db.add(user)
db.commit()
db.refresh(user)
@@ -106,14 +106,19 @@ def create_user(email: str, password: str, name: str, db: DBSession) -> User:
return user
-def authenticate_user(email: str, password: str, db: DBSession) -> User | None:
- """Authenticate a user by email and password"""
- user = db.query(User).filter(User.email == email).first()
+def authenticate_user(email: str, password: str, db: DBSession) -> HumanUser | None:
+ """Authenticate a human user by email and password"""
+ user = db.query(HumanUser).filter(HumanUser.email == email).first()
if user and user.is_valid_password(password):
return user
return None
+def authenticate_bot(api_key: str, db: DBSession) -> BotUser | None:
+ """Authenticate a bot by API key"""
+ return db.query(BotUser).filter(BotUser.api_key == api_key).first()
+
+
@router.api_route("/logout", methods=["GET", "POST"])
def logout(request: Request, db: DBSession = Depends(get_session)):
"""Logout and clear session"""
diff --git a/src/memory/common/db/models/__init__.py b/src/memory/common/db/models/__init__.py
index 7b1f778..140fa62 100644
--- a/src/memory/common/db/models/__init__.py
+++ b/src/memory/common/db/models/__init__.py
@@ -30,6 +30,11 @@ from memory.common.db.models.source_items import (
NotePayload,
ForumPostPayload,
)
+from memory.common.db.models.discord import (
+ DiscordServer,
+ DiscordChannel,
+ DiscordUser,
+)
from memory.common.db.models.observations import (
ObservationContradiction,
ReactionPattern,
@@ -41,12 +46,12 @@ from memory.common.db.models.sources import (
Book,
ArticleFeed,
EmailAccount,
- DiscordServer,
- DiscordChannel,
- DiscordUser,
)
from memory.common.db.models.users import (
User,
+ HumanUser,
+ BotUser,
+ DiscordBotUser,
UserSession,
OAuthClientInformation,
OAuthState,
@@ -103,6 +108,9 @@ __all__ = [
"DiscordUser",
# Users
"User",
+ "HumanUser",
+ "BotUser",
+ "DiscordBotUser",
"UserSession",
"OAuthClientInformation",
"OAuthState",
diff --git a/src/memory/common/db/models/discord.py b/src/memory/common/db/models/discord.py
new file mode 100644
index 0000000..544ffb0
--- /dev/null
+++ b/src/memory/common/db/models/discord.py
@@ -0,0 +1,121 @@
+"""
+Database models for the Discord system.
+"""
+
+import textwrap
+
+from sqlalchemy import (
+ ARRAY,
+ BigInteger,
+ Boolean,
+ Column,
+ DateTime,
+ ForeignKey,
+ Index,
+ Integer,
+ Text,
+ func,
+)
+from sqlalchemy.orm import relationship
+
+from memory.common.db.models.base import Base
+
+
+class MessageProcessor:
+ track_messages = Column(Boolean, nullable=False, server_default="true")
+ ignore_messages = Column(Boolean, nullable=True, default=False)
+
+ allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
+ disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
+
+ summary = Column(
+ Text,
+ nullable=True,
+ doc=textwrap.dedent(
+ """
+ A summary of this processor, made by and for AI systems.
+
+ The idea here is that AI systems can use this summary to keep notes on the given processor.
+ These should automatically be injected into the context of the messages that are processed by this processor.
+ """
+ ),
+ )
+
+ def as_xml(self) -> str:
+ return (
+ textwrap.dedent("""
+ <{type}>
+ {name}
+ {summary}
+ </{type}>
+ """)
+ .format(
+ type=self.__class__.__tablename__[8:], # type: ignore
+ name=getattr(self, "name", None) or getattr(self, "username", None),
+ summary=self.summary,
+ )
+ .strip()
+ )
+
+
+class DiscordServer(Base, MessageProcessor):
+ """Discord server configuration and metadata"""
+
+ __tablename__ = "discord_servers"
+
+ id = Column(BigInteger, primary_key=True) # Discord guild snowflake ID
+ name = Column(Text, nullable=False)
+ description = Column(Text)
+ member_count = Column(Integer)
+
+ # Collection settings
+ last_sync_at = Column(DateTime(timezone=True))
+ created_at = Column(DateTime(timezone=True), server_default=func.now())
+ updated_at = Column(DateTime(timezone=True), server_default=func.now())
+
+ channels = relationship(
+ "DiscordChannel", back_populates="server", cascade="all, delete-orphan"
+ )
+
+ __table_args__ = (
+ Index("discord_servers_active_idx", "track_messages", "last_sync_at"),
+ )
+
+
+class DiscordChannel(Base, MessageProcessor):
+ """Discord channel metadata and configuration"""
+
+ __tablename__ = "discord_channels"
+
+ id = Column(BigInteger, primary_key=True) # Discord channel snowflake ID
+ server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
+ name = Column(Text, nullable=False)
+ channel_type = Column(Text, nullable=False) # "text", "voice", "dm", "group_dm"
+
+ # Collection settings (null = inherit from server)
+ created_at = Column(DateTime(timezone=True), server_default=func.now())
+ updated_at = Column(DateTime(timezone=True), server_default=func.now())
+
+ server = relationship("DiscordServer", back_populates="channels")
+ __table_args__ = (Index("discord_channels_server_idx", "server_id"),)
+
+
+class DiscordUser(Base, MessageProcessor):
+ """Discord user metadata and preferences"""
+
+ __tablename__ = "discord_users"
+
+ id = Column(BigInteger, primary_key=True) # Discord user snowflake ID
+ username = Column(Text, nullable=False)
+ display_name = Column(Text)
+
+ # Link to system user if registered
+ system_user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
+
+ # Basic DM settings
+ created_at = Column(DateTime(timezone=True), server_default=func.now())
+ updated_at = Column(DateTime(timezone=True), server_default=func.now())
+
+ system_user = relationship("User", back_populates="discord_users")
+
+ __table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
diff --git a/src/memory/common/db/models/scheduled_calls.py b/src/memory/common/db/models/scheduled_calls.py
index bbd23d1..75d2184 100644
--- a/src/memory/common/db/models/scheduled_calls.py
+++ b/src/memory/common/db/models/scheduled_calls.py
@@ -7,6 +7,7 @@ from sqlalchemy import (
String,
DateTime,
ForeignKey,
+ BigInteger,
JSON,
Text,
)
@@ -37,8 +38,10 @@ class ScheduledLLMCall(Base):
allowed_tools = Column(JSON, nullable=True) # List of allowed tool names
# Discord configuration
- discord_channel = Column(String, nullable=True)
- discord_user = Column(String, nullable=True)
+ discord_channel_id = Column(
+ BigInteger, ForeignKey("discord_channels.id"), nullable=True
+ )
+ discord_user_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=True)
# Execution status and results
status = Column(
@@ -55,6 +58,8 @@ class ScheduledLLMCall(Base):
# Relationships
user = relationship("User")
+ discord_channel = relationship("DiscordChannel", foreign_keys=[discord_channel_id])
+ discord_user = relationship("DiscordUser", foreign_keys=[discord_user_id])
def serialize(self) -> Dict[str, Any]:
def print_datetime(dt: datetime | None) -> str | None:
@@ -73,8 +78,8 @@ class ScheduledLLMCall(Base):
"message": self.message,
"system_prompt": self.system_prompt,
"allowed_tools": self.allowed_tools,
- "discord_channel": self.discord_channel,
- "discord_user": self.discord_user,
+ "discord_channel": self.discord_channel and self.discord_channel.name,
+ "discord_user": self.discord_user and self.discord_user.username,
"status": self.status,
"response": self.response,
"error_message": self.error_message,
diff --git a/src/memory/common/db/models/source_items.py b/src/memory/common/db/models/source_items.py
index 0b4ca96..fd0eac1 100644
--- a/src/memory/common/db/models/source_items.py
+++ b/src/memory/common/db/models/source_items.py
@@ -286,7 +286,8 @@ class DiscordMessage(SourceItem):
sent_at = Column(DateTime(timezone=True), nullable=False)
server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
channel_id = Column(BigInteger, ForeignKey("discord_channels.id"), nullable=False)
- discord_user_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False)
+ from_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False)
+ recipient_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False)
message_id = Column(BigInteger, nullable=False) # Discord message snowflake ID
# Discord-specific metadata
@@ -303,11 +304,33 @@ class DiscordMessage(SourceItem):
channel = relationship("DiscordChannel", foreign_keys=[channel_id])
server = relationship("DiscordServer", foreign_keys=[server_id])
- discord_user = relationship("DiscordUser", foreign_keys=[discord_user_id])
+ from_user = relationship("DiscordUser", foreign_keys=[from_id])
+ recipient_user = relationship("DiscordUser", foreign_keys=[recipient_id])
+
+ @property
+ def allowed_tools(self) -> list[str]:
+ return (
+ (self.channel.allowed_tools if self.channel else [])
+ + (self.from_user.allowed_tools if self.from_user else [])
+ + (self.server.allowed_tools if self.server else [])
+ )
+
+ @property
+ def disallowed_tools(self) -> list[str]:
+ return (
+ (self.channel.disallowed_tools if self.channel else [])
+ + (self.from_user.disallowed_tools if self.from_user else [])
+ + (self.server.disallowed_tools if self.server else [])
+ )
+
+ def tool_allowed(self, tool: str) -> bool:
+ return not (self.disallowed_tools and tool in self.disallowed_tools) and (
+ not self.allowed_tools or tool in self.allowed_tools
+ )
@property
def title(self) -> str:
- return f"{self.discord_user.username}: {self.content}"
+ return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
__mapper_args__ = {
"polymorphic_identity": "discord_message",
@@ -320,7 +343,8 @@ class DiscordMessage(SourceItem):
"server_id",
"channel_id",
),
- Index("discord_message_user_idx", "discord_user_id"),
+ Index("discord_message_from_idx", "from_id"),
+ Index("discord_message_recipient_idx", "recipient_id"),
)
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
diff --git a/src/memory/common/db/models/sources.py b/src/memory/common/db/models/sources.py
index 1cbb9e4..77e0c3e 100644
--- a/src/memory/common/db/models/sources.py
+++ b/src/memory/common/db/models/sources.py
@@ -125,74 +125,3 @@ class EmailAccount(Base):
Index("email_accounts_active_idx", "active", "last_sync_at"),
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
)
-
-
-class MessageProcessor:
- track_messages = Column(Boolean, nullable=False, server_default="true")
- ignore_messages = Column(Boolean, nullable=True, default=False)
-
- allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
- disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
-
-
-class DiscordServer(Base, MessageProcessor):
- """Discord server configuration and metadata"""
-
- __tablename__ = "discord_servers"
-
- id = Column(BigInteger, primary_key=True) # Discord guild snowflake ID
- name = Column(Text, nullable=False)
- description = Column(Text)
- member_count = Column(Integer)
-
- # Collection settings
- last_sync_at = Column(DateTime(timezone=True))
- created_at = Column(DateTime(timezone=True), server_default=func.now())
- updated_at = Column(DateTime(timezone=True), server_default=func.now())
-
- channels = relationship(
- "DiscordChannel", back_populates="server", cascade="all, delete-orphan"
- )
-
- __table_args__ = (
- Index("discord_servers_active_idx", "track_messages", "last_sync_at"),
- )
-
-
-class DiscordChannel(Base, MessageProcessor):
- """Discord channel metadata and configuration"""
-
- __tablename__ = "discord_channels"
-
- id = Column(BigInteger, primary_key=True) # Discord channel snowflake ID
- server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
- name = Column(Text, nullable=False)
- channel_type = Column(Text, nullable=False) # "text", "voice", "dm", "group_dm"
-
- # Collection settings (null = inherit from server)
- created_at = Column(DateTime(timezone=True), server_default=func.now())
- updated_at = Column(DateTime(timezone=True), server_default=func.now())
-
- server = relationship("DiscordServer", back_populates="channels")
- __table_args__ = (Index("discord_channels_server_idx", "server_id"),)
-
-
-class DiscordUser(Base, MessageProcessor):
- """Discord user metadata and preferences"""
-
- __tablename__ = "discord_users"
-
- id = Column(BigInteger, primary_key=True) # Discord user snowflake ID
- username = Column(Text, nullable=False)
- display_name = Column(Text)
-
- # Link to system user if registered
- system_user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
-
- # Basic DM settings
- created_at = Column(DateTime(timezone=True), server_default=func.now())
- updated_at = Column(DateTime(timezone=True), server_default=func.now())
-
- system_user = relationship("User", back_populates="discord_users")
-
- __table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
diff --git a/src/memory/common/db/models/users.py b/src/memory/common/db/models/users.py
index 8228cee..09f8dd8 100644
--- a/src/memory/common/db/models/users.py
+++ b/src/memory/common/db/models/users.py
@@ -2,7 +2,6 @@ import hashlib
import secrets
from typing import cast
import uuid
-from datetime import datetime, timezone
from sqlalchemy.orm import Session
from memory.common.db.models.base import Base
from sqlalchemy import (
@@ -14,6 +13,7 @@ from sqlalchemy import (
Boolean,
ARRAY,
Numeric,
+ CheckConstraint,
)
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
@@ -36,12 +36,21 @@ def verify_password(password: str, password_hash: str) -> bool:
class User(Base):
__tablename__ = "users"
+ __table_args__ = (
+ CheckConstraint(
+ "password_hash IS NOT NULL OR api_key IS NOT NULL",
+ name="user_has_auth_method",
+ ),
+ )
id = Column(Integer, primary_key=True)
name = Column(String, nullable=False)
email = Column(String, nullable=False, unique=True)
- password_hash = Column(String, nullable=False)
- discord_user_id = Column(String, nullable=True)
+ user_type = Column(String, nullable=False) # Discriminator column
+
+ # Make these nullable since subclasses will use them selectively
+ password_hash = Column(String, nullable=True)
+ api_key = Column(String, nullable=True, unique=True)
# Relationship to sessions
sessions = relationship(
@@ -52,22 +61,86 @@ class User(Base):
)
discord_users = relationship("DiscordUser", back_populates="system_user")
+ __mapper_args__ = {
+ "polymorphic_on": user_type,
+ "polymorphic_identity": "user",
+ }
+
def serialize(self) -> dict:
return {
"user_id": self.id,
"name": self.name,
"email": self.email,
- "discord_user_id": self.discord_user_id,
+ "user_type": self.user_type,
+ "discord_users": {
+ discord_user.id: discord_user.username
+ for discord_user in self.discord_users
+ },
}
+
+class HumanUser(User):
+ """Human user with password authentication"""
+
+ __mapper_args__ = {
+ "polymorphic_identity": "human",
+ }
+
def is_valid_password(self, password: str) -> bool:
"""Check if the provided password is valid for this user"""
return verify_password(password, cast(str, self.password_hash))
@classmethod
- def create_with_password(cls, email: str, name: str, password: str) -> "User":
- """Create a new user with a hashed password"""
- return cls(email=email, name=name, password_hash=hash_password(password))
+ def create_with_password(cls, email: str, name: str, password: str) -> "HumanUser":
+ """Create a new human user with a hashed password"""
+ return cls(
+ email=email,
+ name=name,
+ password_hash=hash_password(password),
+ user_type="human",
+ )
+
+
+class BotUser(User):
+ """Bot user with API key authentication"""
+
+ __mapper_args__ = {
+ "polymorphic_identity": "bot",
+ }
+
+ @classmethod
+ def create_with_api_key(
+ cls, name: str, email: str, api_key: str | None = None
+ ) -> "BotUser":
+ """Create a new bot user with an API key"""
+ if api_key is None:
+ api_key = f"bot_{secrets.token_hex(32)}"
+ return cls(
+ name=name,
+ email=email,
+ api_key=api_key,
+ user_type=cls.__mapper_args__["polymorphic_identity"],
+ )
+
+
+class DiscordBotUser(BotUser):
+ """Bot user with API key authentication"""
+
+ __mapper_args__ = {
+ "polymorphic_identity": "discord_bot",
+ }
+
+ @classmethod
+ def create_with_api_key(
+ cls,
+ discord_users: list,
+ name: str,
+ email: str,
+ api_key: str | None = None,
+ ) -> "DiscordBotUser":
+ bot = super().create_with_api_key(name, email, api_key)
+ bot.discord_users = discord_users
+ return bot
class UserSession(Base):
diff --git a/src/memory/common/discord.py b/src/memory/common/discord.py
index 92482bc..2a1a026 100644
--- a/src/memory/common/discord.py
+++ b/src/memory/common/discord.py
@@ -25,7 +25,7 @@ def send_dm(user_identifier: str, message: str) -> bool:
try:
response = requests.post(
f"{get_api_url()}/send_dm",
- json={"user_identifier": user_identifier, "message": message},
+ json={"user": user_identifier, "message": message},
timeout=10,
)
response.raise_for_status()
@@ -37,6 +37,24 @@ def send_dm(user_identifier: str, message: str) -> bool:
return False
+def send_to_channel(channel_name: str, message: str) -> bool:
+    """Send a message to a named channel via the Discord collector API"""
+    try:
+        response = requests.post(
+            f"{get_api_url()}/send_channel",
+            json={"channel_name": channel_name, "message": message},
+            timeout=10,
+        )
+        response.raise_for_status()
+        result = response.json()
+        logger.debug("send_to_channel result: %s", result)
+        return result.get("success", False)
+
+    except requests.RequestException as e:
+        logger.error(f"Failed to send to channel {channel_name}: {e}")
+        return False
+
+
def broadcast_message(channel_name: str, message: str) -> bool:
"""Send a message to a channel via the Discord collector API"""
try:
diff --git a/src/memory/common/llms/__init__.py b/src/memory/common/llms/__init__.py
index b2d37cc..b84e337 100644
--- a/src/memory/common/llms/__init__.py
+++ b/src/memory/common/llms/__init__.py
@@ -81,3 +81,28 @@ def truncate(content: str, target_tokens: int) -> str:
if len(content) > target_chars:
return content[:target_chars].rsplit(" ", 1)[0] + "..."
return content
+
+
+# bla = 1
+# from memory.common.llms import *
+# from memory.common.llms.tools.discord import make_discord_tools
+# from memory.common.db.connection import make_session
+# from memory.common.db.models import *
+
+# model = "anthropic/claude-sonnet-4-5"
+# provider = create_provider(model=model)
+# with make_session() as session:
+# bot = session.query(DiscordBotUser).first()
+# server = session.query(DiscordServer).first()
+# channel = server.channels[0]
+# tools = make_discord_tools(bot, None, channel, model)
+
+# def demo(msg: str):
+# messages = [
+# Message(
+# role=MessageRole.USER,
+# content=msg,
+# )
+# ]
+# for m in provider.stream_with_tools(messages, tools):
+# print(m)
diff --git a/src/memory/common/llms/anthropic_provider.py b/src/memory/common/llms/anthropic_provider.py
index 1969bfb..938d552 100644
--- a/src/memory/common/llms/anthropic_provider.py
+++ b/src/memory/common/llms/anthropic_provider.py
@@ -333,7 +333,6 @@ class AnthropicProvider(BaseLLMProvider):
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
- print(kwargs)
try:
with self.client.messages.stream(**kwargs) as stream:
current_tool_use: dict[str, Any] | None = None
diff --git a/src/memory/common/llms/base.py b/src/memory/common/llms/base.py
index 113e71f..6238daa 100644
--- a/src/memory/common/llms/base.py
+++ b/src/memory/common/llms/base.py
@@ -599,6 +599,9 @@ class BaseLLMProvider(ABC):
tool_calls=tool_calls or None,
)
+    def as_messages(self, messages: list[str]) -> list[Message]:
+        return [Message.user(text=msg) for msg in messages]
+
def create_provider(
model: str | None = None,
diff --git a/src/memory/common/llms/openai_provider.py b/src/memory/common/llms/openai_provider.py
index aa5beb3..594b459 100644
--- a/src/memory/common/llms/openai_provider.py
+++ b/src/memory/common/llms/openai_provider.py
@@ -150,7 +150,7 @@ class OpenAIProvider(BaseLLMProvider):
def _convert_tools(
self, tools: list[ToolDefinition] | None
- ) -> Optional[list[dict[str, Any]]]:
+ ) -> list[dict[str, Any]] | None:
"""
Convert our tool definitions to OpenAI format.
@@ -179,7 +179,7 @@ class OpenAIProvider(BaseLLMProvider):
self,
messages: list[Message],
system_prompt: str | None,
- tools: Optional[list[ToolDefinition]],
+ tools: list[ToolDefinition] | None,
settings: LLMSettings,
stream: bool = False,
) -> dict[str, Any]:
@@ -270,7 +270,7 @@ class OpenAIProvider(BaseLLMProvider):
self,
chunk: Any,
current_tool_call: dict[str, Any] | None,
- ) -> tuple[list[StreamEvent], Optional[dict[str, Any]]]:
+ ) -> tuple[list[StreamEvent], dict[str, Any] | None]:
"""
Handle a single streaming chunk and return events and updated tool state.
@@ -325,9 +325,9 @@ class OpenAIProvider(BaseLLMProvider):
def generate(
self,
messages: list[Message],
- system_prompt: Optional[str] = None,
- tools: Optional[list[ToolDefinition]] = None,
- settings: Optional[LLMSettings] = None,
+ system_prompt: str | None = None,
+ tools: list[ToolDefinition] | None = None,
+ settings: LLMSettings | None = None,
) -> str:
"""Generate a non-streaming response."""
settings = settings or LLMSettings()
@@ -374,9 +374,9 @@ class OpenAIProvider(BaseLLMProvider):
async def agenerate(
self,
messages: list[Message],
- system_prompt: Optional[str] = None,
- tools: Optional[list[ToolDefinition]] = None,
- settings: Optional[LLMSettings] = None,
+ system_prompt: str | None = None,
+ tools: list[ToolDefinition] | None = None,
+ settings: LLMSettings | None = None,
) -> str:
"""Generate a non-streaming response asynchronously."""
settings = settings or LLMSettings()
@@ -394,9 +394,9 @@ class OpenAIProvider(BaseLLMProvider):
async def astream(
self,
messages: list[Message],
- system_prompt: Optional[str] = None,
- tools: Optional[list[ToolDefinition]] = None,
- settings: Optional[LLMSettings] = None,
+ system_prompt: str | None = None,
+ tools: list[ToolDefinition] | None = None,
+ settings: LLMSettings | None = None,
) -> AsyncIterator[StreamEvent]:
"""Generate a streaming response asynchronously."""
settings = settings or LLMSettings()
@@ -406,7 +406,7 @@ class OpenAIProvider(BaseLLMProvider):
try:
stream = await self.async_client.chat.completions.create(**kwargs)
- current_tool_call: Optional[dict[str, Any]] = None
+ current_tool_call: dict[str, Any] | None = None
async for chunk in stream:
events, current_tool_call = self._handle_stream_chunk(
diff --git a/src/memory/common/llms/tools/discord.py b/src/memory/common/llms/tools/discord.py
new file mode 100644
index 0000000..8492be3
--- /dev/null
+++ b/src/memory/common/llms/tools/discord.py
@@ -0,0 +1,231 @@
+"""Discord tool for interacting with Discord."""
+
+import textwrap
+from datetime import datetime
+from typing import Literal, cast
+from memory.discord.messages import (
+ upsert_scheduled_message,
+ comm_channel_prompt,
+ previous_messages,
+)
+from sqlalchemy import BigInteger
+from memory.common.db.connection import make_session
+from memory.common.db.models import (
+ DiscordServer,
+ DiscordChannel,
+ DiscordUser,
+ BotUser,
+)
+from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler
+
+
+UpdateSummaryType = Literal["server", "channel", "user"]
+
+
+def handle_update_summary_call(
+    type: UpdateSummaryType, item_id: int
+) -> ToolHandler:
+    models = {
+        "server": DiscordServer,
+        "channel": DiscordChannel,
+        "user": DiscordUser,
+    }
+
+    def handler(input: ToolInput = None) -> str:
+        if isinstance(input, dict):
+            summary = input.get("summary") or str(input)
+        else:
+            summary = str(input)
+        try:
+            with make_session() as session:
+                record = session.get(models[type], item_id)
+                if record is None:  # stale/unknown id — avoid AttributeError on None
+                    return f"Error updating summary: {type} {item_id} not found"
+                record.summary = summary  # type: ignore
+                session.commit()
+        except Exception as e:
+            return f"Error updating summary: {e}"
+        return "Updated summary"
+
+    handler.__doc__ = textwrap.dedent("""
+        Handle a {type} summary update tool call.
+
+        Args:
+            summary: The new summary of the Discord {type}
+
+        Returns:
+            Response string
+        """).format(type=type)
+    return handler
+
+
+def make_summary_tool(type: UpdateSummaryType, item_id: int) -> ToolDefinition:
+    return ToolDefinition(
+        name=f"update_{type}_summary",
+        description=textwrap.dedent("""
+            Use this to update the summary of this Discord {type} that is added to your context.
+
+            This will overwrite the previous summary.
+        """).format(type=type),
+        input_schema={
+            "type": "object",
+            "properties": {
+                "summary": {
+                    "type": "string",
+                    "description": f"The new summary of the Discord {type}",
+                }
+            },
+            "required": [],
+        },
+        function=handle_update_summary_call(type, item_id),
+    )
+
+
+def schedule_message(
+ user_id: int,
+ user: int | None,
+ channel: int | None,
+ model: str,
+ message: str,
+ date_time: datetime,
+) -> str:
+ with make_session() as session:
+ call = upsert_scheduled_message(
+ session,
+ scheduled_time=date_time,
+ message=message,
+ user_id=user_id,
+ model=model,
+ discord_user=user,
+ discord_channel=channel,
+ system_prompt=comm_channel_prompt(session, user, channel),
+ )
+
+ session.commit()
+ return cast(str, call.id)
+
+
+def make_message_scheduler(
+ bot: BotUser, user: int | None, channel: int | None, model: str
+) -> ToolDefinition:
+ bot_id = cast(int, bot.id)
+ if user:
+ channel_type = "from your chat with this user"
+ elif channel:
+ channel_type = "in this channel"
+ else:
+ raise ValueError("Either user or channel must be provided")
+
+ def handler(input: ToolInput) -> str:
+ if not isinstance(input, dict):
+ raise ValueError("Input must be a dictionary")
+
+ try:
+ time = datetime.fromisoformat(input["date_time"])
+ except ValueError:
+ raise ValueError("Invalid date time format")
+ except KeyError:
+ raise ValueError("Date time is required")
+
+ return schedule_message(bot_id, user, channel, model, input["message"], time)
+
+ return ToolDefinition(
+ name="schedule_message",
+ description=textwrap.dedent("""
+ Use this to schedule a message to be sent to yourself.
+
+ At the specified date and time, your message will be sent to you, along with the most
+ recent messages {channel_type}.
+
+ Normally you will be called with any incoming messages. But sometimes you might want to be
+ able to trigger a call to yourself at a specific time, rather than waiting for the next call.
+ This tool allows you to do that.
+ So for example, if you were chatting with a Discord user, and you ask a question which needs to
+ be answered right away, you can use this tool to schedule a check in 5 minutes time, to remind
+ the user to answer the question.
+ """).format(channel_type=channel_type),
+ input_schema={
+ "type": "object",
+ "properties": {
+ "message": {
+ "type": "string",
+ "description": "The message to send",
+ },
+ "date_time": {
+ "type": "string",
+ "description": "The date and time to send the message in ISO format (e.g., 2025-01-01T00:00:00Z)",
+ },
+ },
+ },
+ function=handler,
+ )
+
+
+def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefinition:
+ if user:
+ channel_type = "from your chat with this user"
+ elif channel:
+ channel_type = "in this channel"
+ else:
+ raise ValueError("Either user or channel must be provided")
+
+ def handler(input: ToolInput) -> str:
+ if not isinstance(input, dict):
+ raise ValueError("Input must be a dictionary")
+ try:
+ max_messages = int(input.get("max_messages") or 10)
+ offset = int(input.get("offset") or 0)
+ except ValueError:
+ raise ValueError("Max messages and offset must be integers")
+
+ if max_messages <= 0:
+ raise ValueError("Max messages must be greater than 0")
+ if offset < 0:
+ raise ValueError("Offset must be greater than or equal to 0")
+
+ with make_session() as session:
+ messages = previous_messages(session, user, channel, max_messages, offset)
+ return "\n\n".join([msg.title for msg in messages])
+
+ return ToolDefinition(
+ name="previous_messages",
+ description=f"Get the previous N messages {channel_type}.",
+ input_schema={
+ "type": "object",
+ "properties": {
+ "max_messages": {
+ "type": "number",
+ "description": "The maximum number of messages to return",
+ "default": 10,
+ },
+ "offset": {
+ "type": "number",
+ "description": "The number of messages to offset the result by",
+ "default": 0,
+ },
+ },
+ },
+ function=handler,
+ )
+
+
+def make_discord_tools(
+    bot: BotUser,
+    author: DiscordUser | None,
+    channel: DiscordChannel | None,
+    model: str,
+) -> dict[str, ToolDefinition]:
+    """Build the Discord tool set for this bot/author/channel context."""
+    author_id = author and author.id
+    channel_id = channel and channel.id
+    tools = [
+        make_message_scheduler(bot, author_id, channel_id, model),
+        make_prev_messages_tool(author_id, channel_id),
+    ]
+    if channel:  # fix: was unconditional, capturing item_id=None for DMs
+        tools += [make_summary_tool("channel", channel_id)]
+    if author:
+        tools += [make_summary_tool("user", author_id)]
+    if channel and channel.server:
+        tools += [make_summary_tool("server", cast(int, channel.server_id))]
+    return {tool.name: tool for tool in tools}
diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py
index daf98bb..c5b61a9 100644
--- a/src/memory/common/settings.py
+++ b/src/memory/common/settings.py
@@ -133,7 +133,7 @@ if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):
ANTHROPIC_API_KEY = pathlib.Path(anthropic_key_file).read_text().strip()
else:
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
-SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
+SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-haiku-4-5")
RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
@@ -173,11 +173,14 @@ DISCORD_NOTIFICATIONS_ENABLED = bool(
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
)
DISCORD_PROCESS_MESSAGES = boolean_env("DISCORD_PROCESS_MESSAGES", True)
+DISCORD_MODEL = os.getenv("DISCORD_MODEL", "anthropic/claude-sonnet-4-5")
+DISCORD_MAX_TOOL_CALLS = int(os.getenv("DISCORD_MAX_TOOL_CALLS", 10))
+
# Discord collector settings
DISCORD_COLLECTOR_ENABLED = boolean_env("DISCORD_COLLECTOR_ENABLED", True)
DISCORD_COLLECT_DMS = boolean_env("DISCORD_COLLECT_DMS", True)
DISCORD_COLLECT_BOTS = boolean_env("DISCORD_COLLECT_BOTS", True)
-DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8000))
-DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "127.0.0.1")
+DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8003))
+DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "0.0.0.0")
DISCORD_CONTEXT_WINDOW = int(os.getenv("DISCORD_CONTEXT_WINDOW", 10))
diff --git a/src/memory/discord/collector.py b/src/memory/discord/collector.py
index f60a22f..c0bdc0c 100644
--- a/src/memory/discord/collector.py
+++ b/src/memory/discord/collector.py
@@ -13,7 +13,7 @@ from sqlalchemy.orm import Session, scoped_session
from memory.common import settings
from memory.common.db.connection import make_session
-from memory.common.db.models.sources import (
+from memory.common.db.models import (
DiscordServer,
DiscordChannel,
DiscordUser,
@@ -227,6 +227,7 @@ class MessageCollector(commands.Bot):
message_id=message.id,
channel_id=message.channel.id,
author_id=message.author.id,
+ recipient_id=self.user and self.user.id,
server_id=message.guild.id if message.guild else None,
content=message.content or "",
sent_at=message.created_at.isoformat(),
diff --git a/src/memory/discord/messages.py b/src/memory/discord/messages.py
new file mode 100644
index 0000000..d64d02e
--- /dev/null
+++ b/src/memory/discord/messages.py
@@ -0,0 +1,205 @@
+import logging
+import textwrap
+from datetime import datetime, timezone
+from typing import Any, cast
+
+from sqlalchemy.orm import Session, scoped_session
+
+from memory.common.db.models import (
+ DiscordChannel,
+ DiscordUser,
+ ScheduledLLMCall,
+ DiscordMessage,
+)
+
+logger = logging.getLogger(__name__)
+
+DiscordEntity = DiscordChannel | DiscordUser | str | int | None
+
+
+def resolve_discord_user(
+    session: Session | scoped_session, entity: DiscordEntity
+) -> DiscordUser | None:
+    """Resolve a model instance, numeric id, or username to a DiscordUser."""
+    if not entity:
+        return None
+    if isinstance(entity, DiscordUser):
+        return entity
+    if isinstance(entity, int):
+        return session.get(DiscordUser, entity)
+    user = session.query(DiscordUser).filter(DiscordUser.username == entity).first()
+    if not user:
+        user = DiscordUser(username=entity)  # fix: was DiscordUser(id=None, username=None)
+        session.add(user)
+    return user
+
+
+def resolve_discord_channel(
+    session: Session | scoped_session, entity: DiscordEntity
+) -> DiscordChannel | None:
+    """Resolve a model instance, numeric id, or channel name to a DiscordChannel."""
+    if not entity:
+        return None
+    if isinstance(entity, DiscordChannel):
+        return entity
+    if isinstance(entity, int):
+        return session.get(DiscordChannel, entity)
+    return session.query(DiscordChannel).filter(DiscordChannel.name == entity).first()
+
+
+def schedule_discord_message(
+    session: Session | scoped_session,
+    scheduled_time: datetime,
+    message: str,
+    user_id: int,
+    model: str | None = None,
+    topic: str | None = None,
+    discord_user: DiscordEntity = None,
+    discord_channel: DiscordEntity = None,
+    system_prompt: str | None = None,
+    metadata: dict[str, Any] | None = None,
+) -> ScheduledLLMCall:
+    """Create (and add to the session) a future ScheduledLLMCall for Discord."""
+    discord_user = resolve_discord_user(session, discord_user)
+    discord_channel = resolve_discord_channel(session, discord_channel)
+    if not discord_user and not discord_channel:
+        raise ValueError("Either discord_user or discord_channel must be provided")
+
+    # Validate that the scheduled time is in the future
+    # Compare with naive datetime since we store naive in the database
+    current_time_naive = datetime.now(timezone.utc).replace(tzinfo=None)
+    if scheduled_time.replace(tzinfo=None) <= current_time_naive:
+        raise ValueError("Scheduled time must be in the future")
+
+    # Create the scheduled call
+    scheduled_call = ScheduledLLMCall(
+        user_id=user_id,
+        scheduled_time=scheduled_time,
+        message=message,
+        topic=topic,
+        model=model,
+        system_prompt=system_prompt,
+        discord_channel=discord_channel,  # fix: was re-resolved, issuing duplicate queries
+        discord_user=discord_user,
+        data=metadata or {},
+    )
+
+    session.add(scheduled_call)
+    return scheduled_call
+
+
+def upsert_scheduled_message(
+    session: Session | scoped_session,
+    scheduled_time: datetime,
+    message: str,
+    user_id: int,
+    model: str | None = None,
+    topic: str | None = None,
+    discord_user: DiscordEntity = None,
+    discord_channel: DiscordEntity = None,
+    system_prompt: str | None = None,
+    metadata: dict[str, Any] | None = None,
+) -> ScheduledLLMCall:
+    """Schedule a call, first cancelling any later-scheduled one for the same target."""
+    discord_user = resolve_discord_user(session, discord_user)
+    discord_channel = resolve_discord_channel(session, discord_channel)
+    prev_call = (
+        session.query(ScheduledLLMCall)
+        .filter(
+            ScheduledLLMCall.user_id == user_id,
+            ScheduledLLMCall.model == model,
+            ScheduledLLMCall.discord_user_id == (discord_user and discord_user.id),
+            ScheduledLLMCall.discord_channel_id
+            == (discord_channel and discord_channel.id),
+        )
+        .first()
+    )
+    naive_scheduled_time = scheduled_time.replace(tzinfo=None)
+    logger.debug("upsert: new=%s prev=%s", naive_scheduled_time, prev_call and prev_call.scheduled_time)
+    if prev_call and cast(datetime, prev_call.scheduled_time) > naive_scheduled_time:
+        prev_call.status = "cancelled"  # type: ignore
+
+    return schedule_discord_message(
+        session,
+        scheduled_time,
+        message,
+        user_id=user_id,
+        model=model,
+        topic=topic,
+        discord_user=discord_user,
+        discord_channel=discord_channel,
+        system_prompt=system_prompt,
+        metadata=metadata,
+    )
+
+
+def previous_messages(
+    session: Session | scoped_session,
+    user_id: int | None,
+    channel_id: int | None,
+    max_messages: int = 10,
+    offset: int = 0,
+) -> list[DiscordMessage]:
+    """Return up to ``max_messages`` recent messages, oldest first."""
+    query = session.query(DiscordMessage)
+    if user_id:
+        query = query.filter(DiscordMessage.recipient_id == user_id)
+    if channel_id:
+        query = query.filter(DiscordMessage.channel_id == channel_id)
+    newest_first = (
+        query.order_by(DiscordMessage.sent_at.desc())
+        .offset(offset)
+        .limit(max_messages)
+        .all()
+    )
+    return list(reversed(newest_first))
+
+
+def comm_channel_prompt(
+ session: Session | scoped_session,
+ user: DiscordEntity,
+ channel: DiscordEntity,
+ max_messages: int = 10,
+) -> str:
+ user = resolve_discord_user(session, user)
+ channel = resolve_discord_channel(session, channel)
+
+ messages = previous_messages(
+ session, user and user.id, channel and channel.id, max_messages
+ )
+
+ server_context = ""
+ if channel and channel.server:
+ server_context = textwrap.dedent("""
+ Here are your previous notes on the server:
+
+ {summary}
+
+ """).format(summary=channel.server.summary)
+ if channel:
+ server_context += textwrap.dedent("""
+ Here are your previous notes on the channel:
+
+ {summary}
+
+ """).format(summary=channel.summary)
+ if messages:
+ server_context += textwrap.dedent("""
+ Here are your previous notes on the users:
+
+ {users}
+
+ """).format(
+ users="\n".join({msg.from_user.as_xml() for msg in messages}),
+ )
+
+ return textwrap.dedent("""
+ You are a bot communicating on Discord.
+
+ {server_context}
+
+ Whenever something worth remembering is said, you should add a note to the appropriate context - use
+ this to track your understanding of the conversation and those taking part in it.
+
+ You will be given the last {max_messages} messages in the conversation.
+ Please react to them appropriately. You can return an empty response if you don't have anything to say.
+ """).format(server_context=server_context, max_messages=max_messages)
diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py
index 093acd1..f146af0 100644
--- a/src/memory/workers/tasks/discord.py
+++ b/src/memory/workers/tasks/discord.py
@@ -4,25 +4,31 @@ Celery tasks for Discord message processing.
import hashlib
import logging
+import textwrap
from datetime import datetime
from typing import Any
-from memory.common.celery_app import app
-from memory.common.db.connection import make_session
-from memory.common.db.models import DiscordMessage, DiscordUser
-from memory.workers.tasks.content_processing import (
- safe_task_execution,
- check_content_exists,
- create_task_result,
- process_content_item,
-)
+from sqlalchemy import exc as sqlalchemy_exc
+from sqlalchemy.orm import Session, scoped_session
+
+from memory.common import discord, settings
from memory.common.celery_app import (
ADD_DISCORD_MESSAGE,
EDIT_DISCORD_MESSAGE,
PROCESS_DISCORD_MESSAGE,
+ app,
+)
+from memory.common.db.connection import make_session
+from memory.common.db.models import DiscordMessage, DiscordUser
+from memory.common.llms.base import create_provider
+from memory.common.llms.tools.discord import make_discord_tools
+from memory.discord.messages import comm_channel_prompt, previous_messages
+from memory.workers.tasks.content_processing import (
+ check_content_exists,
+ create_task_result,
+ process_content_item,
+ safe_task_execution,
)
-from memory.common import settings
-from sqlalchemy.orm import Session, scoped_session
logger = logging.getLogger(__name__)
@@ -32,7 +38,7 @@ def get_prev(
) -> list[str]:
prev = (
session.query(DiscordUser.username, DiscordMessage.content)
- .join(DiscordUser, DiscordMessage.discord_user_id == DiscordUser.id)
+ .join(DiscordUser, DiscordMessage.from_id == DiscordUser.id)
.filter(
DiscordMessage.channel_id == channel_id,
DiscordMessage.sent_at < sent_at,
@@ -45,20 +51,54 @@ def get_prev(
def should_process(message: DiscordMessage) -> bool:
- return (
+ if not (
settings.DISCORD_PROCESS_MESSAGES
and settings.DISCORD_NOTIFICATIONS_ENABLED
and not (
(message.server and message.server.ignore_messages)
or (message.channel and message.channel.ignore_messages)
- or (message.discord_user and message.discord_user.ignore_messages)
+ or (message.from_user and message.from_user.ignore_messages)
)
- )
+ ):
+ return False
+
+ provider = create_provider(model=settings.SUMMARIZER_MODEL)
+ with make_session() as session:
+ system_prompt = comm_channel_prompt(
+ session, message.recipient_user, message.channel
+ )
+ messages = previous_messages(
+ session,
+ message.recipient_user and message.recipient_user.id,
+ message.channel and message.channel.id,
+ max_messages=10,
+ )
+ msg = textwrap.dedent("""
+ Should you continue the conversation with the user?
+ Please return "yes" or "no" as:
+
+ yes
+
+ or
+
+ no
+
+ """)
+ response = provider.generate(
+ messages=provider.as_messages([m.title for m in messages] + [msg]),
+ system_prompt=system_prompt,
+ )
+ return "yes" in "".join(response.lower().split())
@app.task(name=PROCESS_DISCORD_MESSAGE)
@safe_task_execution
def process_discord_message(message_id: int) -> dict[str, Any]:
+ """
+ Process a Discord message.
+
+ This task is queued by the Discord collector when messages are received.
+ """
logger.info(f"Processing Discord message {message_id}")
with make_session() as session:
@@ -71,7 +111,39 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
"message_id": message_id,
}
- print("Processing message", discord_message.id, discord_message.content)
+ tools = make_discord_tools(
+ discord_message.recipient_user,
+ discord_message.from_user,
+ discord_message.channel,
+ model=settings.DISCORD_MODEL,
+ )
+ tools = {
+ name: tool
+ for name, tool in tools.items()
+ if discord_message.tool_allowed(name)
+ }
+ system_prompt = comm_channel_prompt(
+ session, discord_message.recipient_user, discord_message.channel
+ )
+ messages = previous_messages(
+ session,
+ discord_message.recipient_user and discord_message.recipient_user.id,
+ discord_message.channel and discord_message.channel.id,
+ max_messages=10,
+ )
+ provider = create_provider(model=settings.DISCORD_MODEL)
+ turn = provider.run_with_tools(
+ messages=provider.as_messages([m.title for m in messages]),
+ tools=tools,
+ system_prompt=system_prompt,
+ max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
+ )
+ if not turn.response:
+ pass
+ elif discord_message.channel.server:
+ discord.send_to_channel(discord_message.channel.name, turn.response)
+ else:
+ discord.send_dm(discord_message.from_user.username, turn.response)
return {
"status": "processed",
@@ -88,6 +160,7 @@ def add_discord_message(
content: str,
sent_at: str,
server_id: int | None = None,
+ recipient_id: int | None = None,
message_reference_id: int | None = None,
) -> dict[str, Any]:
"""
@@ -108,7 +181,8 @@ def add_discord_message(
channel_id=channel_id,
sent_at=sent_at_dt,
server_id=server_id,
- discord_user_id=author_id,
+ from_id=author_id,
+ recipient_id=recipient_id,
message_id=message_id,
message_type="reply" if message_reference_id else "default",
reply_to_message_id=message_reference_id,
@@ -125,7 +199,15 @@ def add_discord_message(
if channel_id:
discord_message.messages_before = get_prev(session, channel_id, sent_at_dt)
- result = process_content_item(discord_message, session)
+ try:
+ result = process_content_item(discord_message, session)
+ except sqlalchemy_exc.IntegrityError as e:
+ logger.error(f"Integrity error adding Discord message {message_id}: {e}")
+ return {
+ "status": "error",
+ "error": "Integrity error",
+ "message_id": message_id,
+ }
if should_process(discord_message):
process_discord_message.delay(discord_message.id)
diff --git a/src/memory/workers/tasks/scheduled_calls.py b/src/memory/workers/tasks/scheduled_calls.py
index 3248285..3ca6cad 100644
--- a/src/memory/workers/tasks/scheduled_calls.py
+++ b/src/memory/workers/tasks/scheduled_calls.py
@@ -37,12 +37,12 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
if len(message) > 1900: # Leave some buffer
message = message[:1900] + "\n\n... (response truncated)"
- if discord_user := cast(str, scheduled_call.discord_user):
- logger.info(f"Sending DM to {discord_user}: {message}")
- discord.send_dm(discord_user, message)
- elif discord_channel := cast(str, scheduled_call.discord_channel):
- logger.info(f"Broadcasting message to {discord_channel}: {message}")
- discord.broadcast_message(discord_channel, message)
+ if discord_user := scheduled_call.discord_user:
+ logger.info(f"Sending DM to {discord_user.username}: {message}")
+ discord.send_dm(discord_user.username, message)
+ elif discord_channel := scheduled_call.discord_channel:
+ logger.info(f"Broadcasting message to {discord_channel.name}: {message}")
+ discord.broadcast_message(discord_channel.name, message)
else:
logger.warning(
f"No Discord user or channel found for scheduled call {scheduled_call.id}"
@@ -62,11 +62,7 @@ def execute_scheduled_call(self, scheduled_call_id: str):
with make_session() as session:
# Fetch the scheduled call
- scheduled_call = (
- session.query(ScheduledLLMCall)
- .filter(ScheduledLLMCall.id == scheduled_call_id)
- .first()
- )
+ scheduled_call = session.query(ScheduledLLMCall).get(scheduled_call_id)
if not scheduled_call:
logger.error(f"Scheduled call {scheduled_call_id} not found")
diff --git a/tests/conftest.py b/tests/conftest.py
index 175eedf..683cb4d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -254,17 +254,59 @@ def mock_openai_client():
with patch.object(openai, "OpenAI", autospec=True) as mock_client:
client = mock_client()
client.chat = Mock()
+
+ # Mock non-streaming response
client.chat.completions.create = Mock(
return_value=Mock(
choices=[
Mock(
message=Mock(
content="test summarytag1tag2"
- )
+ ),
+ finish_reason=None,
)
]
)
)
+
+ # Store original side_effect for potential override
+ def streaming_response(*args, **kwargs):
+ if kwargs.get("stream"):
+ # Return mock streaming chunks
+ return iter(
+ [
+ Mock(
+ choices=[
+ Mock(
+ delta=Mock(content="test", tool_calls=None),
+ finish_reason=None,
+ )
+ ]
+ ),
+ Mock(
+ choices=[
+ Mock(
+ delta=Mock(content=" response", tool_calls=None),
+ finish_reason="stop",
+ )
+ ]
+ ),
+ ]
+ )
+ else:
+ # Return non-streaming response
+ return Mock(
+ choices=[
+ Mock(
+ message=Mock(
+ content="test summarytag1tag2"
+ ),
+ finish_reason=None,
+ )
+ ]
+ )
+
+ client.chat.completions.create.side_effect = streaming_response
yield client
diff --git a/tests/memory/common/llms/__init__.py b/tests/memory/common/llms/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/memory/common/llms/test_anthropic_event_parsing.py b/tests/memory/common/llms/test_anthropic_event_parsing.py
new file mode 100644
index 0000000..45cbdeb
--- /dev/null
+++ b/tests/memory/common/llms/test_anthropic_event_parsing.py
@@ -0,0 +1,552 @@
+"""Comprehensive tests for Anthropic stream event parsing."""
+
+import pytest
+from unittest.mock import Mock
+
+from memory.common.llms.anthropic_provider import AnthropicProvider
+from memory.common.llms.base import StreamEvent
+
+
+@pytest.fixture
+def provider():
+ return AnthropicProvider(api_key="test-key", model="claude-3-opus-20240229")
+
+
+# Content Block Start Tests
+
+
+@pytest.mark.parametrize(
+ "block_type,block_attrs,expected_tool_use",
+ [
+ (
+ "tool_use",
+ {"id": "tool-1", "name": "search", "input": {}},
+ {
+ "id": "tool-1",
+ "name": "search",
+ "input": {},
+ "server_name": None,
+ "is_server_call": False,
+ },
+ ),
+ (
+ "mcp_tool_use",
+ {
+ "id": "mcp-1",
+ "name": "mcp_search",
+ "input": {},
+ "server_name": "mcp-server",
+ },
+ {
+ "id": "mcp-1",
+ "name": "mcp_search",
+ "input": {},
+ "server_name": "mcp-server",
+ "is_server_call": True,
+ },
+ ),
+ (
+ "server_tool_use",
+ {
+ "id": "srv-1",
+ "name": "server_action",
+ "input": {},
+ "server_name": "custom-server",
+ },
+ {
+ "id": "srv-1",
+ "name": "server_action",
+ "input": {},
+ "server_name": "custom-server",
+ "is_server_call": True,
+ },
+ ),
+ ],
+)
+def test_content_block_start_tool_types(
+ provider, block_type, block_attrs, expected_tool_use
+):
+ """Different tool types should be tracked correctly."""
+ block = Mock(spec=["type"] + list(block_attrs.keys()))
+ block.type = block_type
+ for key, value in block_attrs.items():
+ setattr(block, key, value)
+
+ event = Mock(spec=["type", "content_block"])
+ event.type = "content_block_start"
+ event.content_block = block
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event is None
+ assert tool_use == expected_tool_use
+
+
+def test_content_block_start_tool_without_input(provider):
+ """Tool use without input field should initialize as empty string."""
+ block = Mock(spec=["type", "id", "name"])
+ block.type = "tool_use"
+ block.id = "tool-2"
+ block.name = "calculate"
+
+ event = Mock(spec=["type", "content_block"])
+ event.type = "content_block_start"
+ event.content_block = block
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert tool_use["input"] == ""
+
+
+def test_content_block_start_tool_result(provider):
+ """Tool result blocks should emit tool_result event."""
+ block = Mock(spec=["tool_use_id", "content"])
+ block.tool_use_id = "tool-1"
+ block.content = "Result content"
+
+ event = Mock(spec=["type", "content_block"])
+ event.type = "content_block_start"
+ event.content_block = block
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event is not None
+ assert stream_event.type == "tool_result"
+ assert stream_event.data == {"id": "tool-1", "result": "Result content"}
+
+
+@pytest.mark.parametrize(
+ "has_content_block,block_type",
+ [
+ (False, None),
+ (True, "unknown_type"),
+ ],
+)
+def test_content_block_start_ignored_cases(provider, has_content_block, block_type):
+ """Events without content_block or with unknown types should be ignored."""
+ event = Mock(spec=["type", "content_block"] if has_content_block else ["type"])
+ event.type = "content_block_start"
+
+ if has_content_block:
+ block = Mock(spec=["type"])
+ block.type = block_type
+ event.content_block = block
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event is None
+ assert tool_use is None
+
+
+# Content Block Delta Tests
+
+
+@pytest.mark.parametrize(
+ "delta_type,delta_attr,attr_value,expected_type,expected_data",
+ [
+ ("text_delta", "text", "Hello world", "text", "Hello world"),
+ ("text_delta", "text", "", "text", ""),
+ (
+ "thinking_delta",
+ "thinking",
+ "Let me think...",
+ "thinking",
+ "Let me think...",
+ ),
+ ("signature_delta", "signature", "sig-12345", "thinking", None),
+ ],
+)
+def test_content_block_delta_types(
+ provider, delta_type, delta_attr, attr_value, expected_type, expected_data
+):
+ """Different delta types should emit appropriate events."""
+ delta = Mock(spec=["type", delta_attr])
+ delta.type = delta_type
+ setattr(delta, delta_attr, attr_value)
+
+ event = Mock(spec=["type", "delta"])
+ event.type = "content_block_delta"
+ event.delta = delta
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event.type == expected_type
+ if expected_type == "thinking" and delta_type == "signature_delta":
+ assert stream_event.signature == attr_value
+ else:
+ assert stream_event.data == expected_data
+
+
+@pytest.mark.parametrize(
+ "current_tool,partial_json,expected_input",
+ [
+ (
+ {"id": "t1", "name": "search", "input": '{"query": "'},
+ 'test"}',
+ '{"query": "test"}',
+ ),
+ (
+ {"id": "t1", "name": "search", "input": '{"'},
+ 'key": "value"}',
+ '{"key": "value"}',
+ ),
+ (
+ {"id": "t1", "name": "search", "input": ""},
+ '{"query": "test"}',
+ '{"query": "test"}',
+ ),
+ ],
+)
+def test_content_block_delta_input_json_accumulation(
+ provider, current_tool, partial_json, expected_input
+):
+ """JSON delta should accumulate to tool input."""
+ delta = Mock(spec=["type", "partial_json"])
+ delta.type = "input_json_delta"
+ delta.partial_json = partial_json
+
+ event = Mock(spec=["type", "delta"])
+ event.type = "content_block_delta"
+ event.delta = delta
+
+ stream_event, tool_use = provider._handle_stream_event(event, current_tool)
+
+ assert stream_event is None
+ assert tool_use["input"] == expected_input
+
+
+def test_content_block_delta_input_json_without_tool(provider):
+ """JSON delta without tool context should return None."""
+ delta = Mock(spec=["type", "partial_json"])
+ delta.type = "input_json_delta"
+ delta.partial_json = '{"key": "value"}'
+
+ event = Mock(spec=["type", "delta"])
+ event.type = "content_block_delta"
+ event.delta = delta
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event is None
+ assert tool_use is None
+
+
+def test_content_block_delta_input_json_with_dict_input(provider):
+ """JSON delta shouldn't modify if input is already a dict."""
+ current_tool = {"id": "t1", "name": "search", "input": {"query": "test"}}
+
+ delta = Mock(spec=["type", "partial_json"])
+ delta.type = "input_json_delta"
+ delta.partial_json = ', "extra": "data"'
+
+ event = Mock(spec=["type", "delta"])
+ event.type = "content_block_delta"
+ event.delta = delta
+
+ stream_event, tool_use = provider._handle_stream_event(event, current_tool)
+
+ assert tool_use["input"] == {"query": "test"}
+
+
+@pytest.mark.parametrize(
+ "has_delta,delta_type",
+ [
+ (False, None),
+ (True, "unknown_delta"),
+ ],
+)
+def test_content_block_delta_ignored_cases(provider, has_delta, delta_type):
+ """Events without delta or with unknown types should be ignored."""
+ event = Mock(spec=["type", "delta"] if has_delta else ["type"])
+ event.type = "content_block_delta"
+
+ if has_delta:
+ delta = Mock(spec=["type"])
+ delta.type = delta_type
+ event.delta = delta
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event is None
+
+
+# Content Block Stop Tests
+
+
+@pytest.mark.parametrize(
+ "input_value,has_content_block,expected_input",
+ [
+ ("", False, {}),
+ (" \n\t ", False, {}),
+ ('{"invalid": json}', False, {}),
+ ('{"query": "test", "limit": 10}', False, {"query": "test", "limit": 10}),
+ (
+ '{"filters": {"type": "user", "status": ["active", "pending"]}, "limit": 100}',
+ False,
+ {
+ "filters": {"type": "user", "status": ["active", "pending"]},
+ "limit": 100,
+ },
+ ),
+ ("", True, {"query": "test"}),
+ ],
+)
+def test_content_block_stop_tool_finalization(
+ provider, input_value, has_content_block, expected_input
+):
+ """Tool stop should parse or use provided input correctly."""
+ current_tool = {"id": "t1", "name": "search", "input": input_value}
+
+ event = Mock(spec=["type", "content_block"] if has_content_block else ["type"])
+ event.type = "content_block_stop"
+
+ if has_content_block:
+ block = Mock(spec=["input"])
+ block.input = {"query": "test"}
+ event.content_block = block
+
+ stream_event, tool_use = provider._handle_stream_event(event, current_tool)
+
+ assert stream_event.type == "tool_use"
+ assert stream_event.data["input"] == expected_input
+ assert tool_use is None
+
+
+def test_content_block_stop_with_server_info(provider):
+ """Server tool info should be included in final event."""
+ current_tool = {
+ "id": "t1",
+ "name": "mcp_search",
+ "input": '{"q": "test"}',
+ "server_name": "mcp-server",
+ "is_server_call": True,
+ }
+
+ event = Mock(spec=["type"])
+ event.type = "content_block_stop"
+
+ stream_event, tool_use = provider._handle_stream_event(event, current_tool)
+
+ assert stream_event.data["server_name"] == "mcp-server"
+ assert stream_event.data["is_server_call"] is True
+
+
+def test_content_block_stop_without_tool(provider):
+ """Stop without current tool should return None."""
+ event = Mock(spec=["type"])
+ event.type = "content_block_stop"
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event is None
+ assert tool_use is None
+
+
+# Message Delta Tests
+
+
+def test_message_delta_max_tokens(provider):
+ """Max tokens stop reason should emit error."""
+ delta = Mock(spec=["stop_reason"])
+ delta.stop_reason = "max_tokens"
+
+ event = Mock(spec=["type", "delta"])
+ event.type = "message_delta"
+ event.delta = delta
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event.type == "error"
+ assert "Max tokens" in stream_event.data
+
+
+@pytest.mark.parametrize("stop_reason", ["end_turn", "stop_sequence", None])
+def test_message_delta_other_stop_reasons(provider, stop_reason):
+ """Other stop reasons should not emit error."""
+ delta = Mock(spec=["stop_reason"])
+ delta.stop_reason = stop_reason
+
+ event = Mock(spec=["type", "delta"])
+ event.type = "message_delta"
+ event.delta = delta
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event is None
+
+
+def test_message_delta_token_usage(provider):
+ """Token usage should be logged but not emitted."""
+ usage = Mock(
+ spec=[
+ "input_tokens",
+ "output_tokens",
+ "cache_creation_input_tokens",
+ "cache_read_input_tokens",
+ ]
+ )
+ usage.input_tokens = 100
+ usage.output_tokens = 50
+ usage.cache_creation_input_tokens = 10
+ usage.cache_read_input_tokens = 20
+
+ event = Mock(spec=["type", "usage"])
+ event.type = "message_delta"
+ event.usage = usage
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event is None
+
+
+def test_message_delta_empty(provider):
+ """Message delta without delta or usage should return None."""
+ event = Mock(spec=["type"])
+ event.type = "message_delta"
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event is None
+
+
+# Message Stop Tests
+
+
+@pytest.mark.parametrize(
+ "current_tool",
+ [
+ None,
+ {"id": "t1", "name": "search", "input": '{"incomplete'},
+ ],
+)
+def test_message_stop(provider, current_tool):
+ """Message stop should emit done regardless of incomplete tools."""
+ event = Mock(spec=["type"])
+ event.type = "message_stop"
+
+ stream_event, tool_use = provider._handle_stream_event(event, current_tool)
+
+ assert stream_event.type == "done"
+ assert tool_use is None
+
+
+# Error Handling Tests
+
+
+@pytest.mark.parametrize(
+ "has_error,error_value,expected_message",
+ [
+ (True, "API rate limit exceeded", "rate limit"),
+ (False, None, "Unknown error"),
+ ],
+)
+def test_error_events(provider, has_error, error_value, expected_message):
+ """Error events should emit error StreamEvent."""
+ event = Mock(spec=["type", "error"] if has_error else ["type"])
+ event.type = "error"
+ if has_error:
+ event.error = error_value
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event.type == "error"
+ assert expected_message in stream_event.data
+
+
+# Unknown Event Tests
+
+
+@pytest.mark.parametrize(
+ "event_type",
+ ["message_start", "future_event_type", None],
+)
+def test_unknown_or_ignored_events(provider, event_type):
+ """Unknown event types should be logged but not fail."""
+ if event_type is None:
+ event = Mock(spec=[])
+ else:
+ event = Mock(spec=["type"])
+ event.type = event_type
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event is None
+
+
+# State Transition Tests
+
+
+def test_complete_tool_call_sequence(provider):
+ """Simulate a complete tool call from start to finish."""
+ # Start
+ block = Mock(spec=["type", "id", "name", "input"])
+ block.type = "tool_use"
+ block.id = "tool-1"
+ block.name = "search"
+ block.input = None
+
+ event1 = Mock(spec=["type", "content_block"])
+ event1.type = "content_block_start"
+ event1.content_block = block
+
+ _, tool_use = provider._handle_stream_event(event1, None)
+ assert tool_use["input"] == ""
+
+ # Delta 1
+ delta1 = Mock(spec=["type", "partial_json"])
+ delta1.type = "input_json_delta"
+ delta1.partial_json = '{"query":'
+
+ event2 = Mock(spec=["type", "delta"])
+ event2.type = "content_block_delta"
+ event2.delta = delta1
+
+ _, tool_use = provider._handle_stream_event(event2, tool_use)
+ assert tool_use["input"] == '{"query":'
+
+ # Delta 2
+ delta2 = Mock(spec=["type", "partial_json"])
+ delta2.type = "input_json_delta"
+ delta2.partial_json = ' "test"}'
+
+ event3 = Mock(spec=["type", "delta"])
+ event3.type = "content_block_delta"
+ event3.delta = delta2
+
+ _, tool_use = provider._handle_stream_event(event3, tool_use)
+ assert tool_use["input"] == '{"query": "test"}'
+
+ # Stop
+ event4 = Mock(spec=["type"])
+ event4.type = "content_block_stop"
+
+ stream_event, tool_use = provider._handle_stream_event(event4, tool_use)
+
+ assert stream_event.type == "tool_use"
+ assert stream_event.data["input"] == {"query": "test"}
+ assert tool_use is None
+
+
+def test_text_and_thinking_mixed(provider):
+ """Text and thinking deltas should be handled independently."""
+ delta1 = Mock(spec=["type", "text"])
+ delta1.type = "text_delta"
+ delta1.text = "Answer: "
+
+ event1 = Mock(spec=["type", "delta"])
+ event1.type = "content_block_delta"
+ event1.delta = delta1
+
+ event1_result, _ = provider._handle_stream_event(event1, None)
+ assert event1_result.type == "text"
+
+ delta2 = Mock(spec=["type", "thinking"])
+ delta2.type = "thinking_delta"
+ delta2.thinking = "reasoning..."
+
+ event2 = Mock(spec=["type", "delta"])
+ event2.type = "content_block_delta"
+ event2.delta = delta2
+
+ event2_result, _ = provider._handle_stream_event(event2, None)
+ assert event2_result.type == "thinking"
diff --git a/tests/memory/common/llms/test_anthropic_provider.py b/tests/memory/common/llms/test_anthropic_provider.py
new file mode 100644
index 0000000..6c76863
--- /dev/null
+++ b/tests/memory/common/llms/test_anthropic_provider.py
@@ -0,0 +1,440 @@
+import pytest
+from unittest.mock import Mock, patch, MagicMock
+from PIL import Image
+
+from memory.common.llms.anthropic_provider import AnthropicProvider
+from memory.common.llms.base import (
+ Message,
+ MessageRole,
+ TextContent,
+ ImageContent,
+ ToolUseContent,
+ ToolResultContent,
+ ThinkingContent,
+ LLMSettings,
+ StreamEvent,
+)
+from memory.common.llms.tools import ToolDefinition
+
+
+@pytest.fixture
+def provider():
+ return AnthropicProvider(api_key="test-key", model="claude-3-opus-20240229")
+
+
+@pytest.fixture
+def thinking_provider():
+ return AnthropicProvider(
+ api_key="test-key", model="claude-opus-4", enable_thinking=True
+ )
+
+
+def test_initialization(provider):
+ assert provider.api_key == "test-key"
+ assert provider.model == "claude-3-opus-20240229"
+ assert provider.enable_thinking is False
+
+
+def test_client_lazy_loading(provider):
+ assert provider._client is None
+ client = provider.client
+ assert client is not None
+ assert provider._client is not None
+ # Second call should return same instance
+ assert provider.client is client
+
+
+def test_async_client_lazy_loading(provider):
+ assert provider._async_client is None
+ client = provider.async_client
+ assert client is not None
+ assert provider._async_client is not None
+
+
+@pytest.mark.parametrize(
+ "model, expected",
+ [
+ ("claude-opus-4", True),
+ ("claude-opus-4-1", True),
+ ("claude-sonnet-4-0", True),
+ ("claude-sonnet-3-7", True),
+ ("claude-sonnet-4-5", True),
+ ("claude-3-opus-20240229", False),
+ ("claude-3-sonnet-20240229", False),
+ ("gpt-4", False),
+ ],
+)
+def test_supports_thinking(model, expected):
+ provider = AnthropicProvider(api_key="test-key", model=model)
+ assert provider._supports_thinking() == expected
+
+
+def test_convert_text_content(provider):
+ content = TextContent(text="hello world")
+ result = provider._convert_text_content(content)
+ assert result == {"type": "text", "text": "hello world"}
+
+
+def test_convert_image_content(provider):
+ image = Image.new("RGB", (100, 100), color="red")
+ content = ImageContent(image=image)
+ result = provider._convert_image_content(content)
+
+ assert result["type"] == "image"
+ assert result["source"]["type"] == "base64"
+ assert result["source"]["media_type"] == "image/jpeg"
+ assert isinstance(result["source"]["data"], str)
+
+
+def test_should_include_message_filters_system(provider):
+ system_msg = Message(role=MessageRole.SYSTEM, content="system prompt")
+ user_msg = Message(role=MessageRole.USER, content="user message")
+
+ assert provider._should_include_message(system_msg) is False
+ assert provider._should_include_message(user_msg) is True
+
+
+@pytest.mark.parametrize(
+ "messages, expected_count",
+ [
+ ([Message(role=MessageRole.USER, content="test")], 1),
+ ([Message(role=MessageRole.SYSTEM, content="test")], 0),
+ (
+ [
+ Message(role=MessageRole.SYSTEM, content="system"),
+ Message(role=MessageRole.USER, content="user"),
+ ],
+ 1,
+ ),
+ ],
+)
+def test_convert_messages(provider, messages, expected_count):
+ result = provider._convert_messages(messages)
+ assert len(result) == expected_count
+
+
+def test_convert_tool(provider):
+ tool = ToolDefinition(
+ name="test_tool",
+ description="A test tool",
+ input_schema={"type": "object", "properties": {}},
+ function=lambda x: "result",
+ )
+ result = provider._convert_tool(tool)
+
+ assert result["name"] == "test_tool"
+ assert result["description"] == "A test tool"
+ assert result["input_schema"] == {"type": "object", "properties": {}}
+
+
+def test_build_request_kwargs_basic(provider):
+ messages = [Message(role=MessageRole.USER, content="test")]
+ settings = LLMSettings(temperature=0.5, max_tokens=1000)
+
+ kwargs = provider._build_request_kwargs(messages, None, None, settings)
+
+ assert kwargs["model"] == "claude-3-opus-20240229"
+ assert kwargs["temperature"] == 0.5
+ assert kwargs["max_tokens"] == 1000
+ assert len(kwargs["messages"]) == 1
+
+
+def test_build_request_kwargs_with_system_prompt(provider):
+ messages = [Message(role=MessageRole.USER, content="test")]
+ settings = LLMSettings()
+
+ kwargs = provider._build_request_kwargs(messages, "system prompt", None, settings)
+
+ assert kwargs["system"] == "system prompt"
+
+
+def test_build_request_kwargs_with_tools(provider):
+ messages = [Message(role=MessageRole.USER, content="test")]
+ tools = [
+ ToolDefinition(
+ name="test",
+ description="test",
+ input_schema={},
+ function=lambda x: "result",
+ )
+ ]
+ settings = LLMSettings()
+
+ kwargs = provider._build_request_kwargs(messages, None, tools, settings)
+
+ assert "tools" in kwargs
+ assert len(kwargs["tools"]) == 1
+
+
+def test_build_request_kwargs_with_thinking(thinking_provider):
+ messages = [Message(role=MessageRole.USER, content="test")]
+ settings = LLMSettings(max_tokens=5000)
+
+ kwargs = thinking_provider._build_request_kwargs(messages, None, None, settings)
+
+ assert "thinking" in kwargs
+ assert kwargs["thinking"]["type"] == "enabled"
+ assert kwargs["thinking"]["budget_tokens"] == 3976
+ assert kwargs["temperature"] == 1.0
+ assert "top_p" not in kwargs
+
+
+def test_build_request_kwargs_thinking_insufficient_tokens(thinking_provider):
+ messages = [Message(role=MessageRole.USER, content="test")]
+ settings = LLMSettings(max_tokens=1000)
+
+ kwargs = thinking_provider._build_request_kwargs(messages, None, None, settings)
+
+ # Shouldn't enable thinking if not enough tokens
+ assert "thinking" not in kwargs
+
+
+def test_handle_stream_event_text_delta(provider):
+ event = Mock(
+ type="content_block_delta",
+ delta=Mock(type="text_delta", text="hello"),
+ )
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event is not None
+ assert stream_event.type == "text"
+ assert stream_event.data == "hello"
+ assert tool_use is None
+
+
+def test_handle_stream_event_thinking_delta(provider):
+ event = Mock(
+ type="content_block_delta",
+ delta=Mock(type="thinking_delta", thinking="reasoning..."),
+ )
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event is not None
+ assert stream_event.type == "thinking"
+ assert stream_event.data == "reasoning..."
+
+
+def test_handle_stream_event_tool_use_start(provider):
+ block = Mock(spec=["type", "id", "name", "input"])
+ block.type = "tool_use"
+ block.id = "tool-1"
+ block.name = "test_tool"
+ block.input = {}
+
+ event = Mock(spec=["type", "content_block"])
+ event.type = "content_block_start"
+ event.content_block = block
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event is None
+ assert tool_use is not None
+ assert tool_use["id"] == "tool-1"
+ assert tool_use["name"] == "test_tool"
+ assert tool_use["input"] == {}
+
+
+def test_handle_stream_event_tool_input_delta(provider):
+ current_tool = {"id": "tool-1", "name": "test", "input": '{"ke'}
+ event = Mock(
+ type="content_block_delta",
+ delta=Mock(type="input_json_delta", partial_json='y": "val'),
+ )
+
+ stream_event, tool_use = provider._handle_stream_event(event, current_tool)
+
+ assert stream_event is None
+ assert tool_use["input"] == '{"key": "val'
+
+
+def test_handle_stream_event_tool_use_complete(provider):
+ current_tool = {
+ "id": "tool-1",
+ "name": "test_tool",
+ "input": '{"key": "value"}',
+ }
+ event = Mock(
+ type="content_block_stop",
+ content_block=Mock(input={"key": "value"}),
+ )
+
+ stream_event, tool_use = provider._handle_stream_event(event, current_tool)
+
+ assert stream_event is not None
+ assert stream_event.type == "tool_use"
+ assert stream_event.data["id"] == "tool-1"
+ assert stream_event.data["name"] == "test_tool"
+ assert stream_event.data["input"] == {"key": "value"}
+ assert tool_use is None
+
+
+def test_handle_stream_event_message_stop(provider):
+ event = Mock(type="message_stop")
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event is not None
+ assert stream_event.type == "done"
+ assert tool_use is None
+
+
+def test_handle_stream_event_error(provider):
+ event = Mock(type="error", error="API error")
+
+ stream_event, tool_use = provider._handle_stream_event(event, None)
+
+ assert stream_event is not None
+ assert stream_event.type == "error"
+ assert "API error" in stream_event.data
+
+
+def test_generate_basic(provider, mock_anthropic_client):
+ messages = [Message(role=MessageRole.USER, content="test")]
+
+ # Mock the response properly
+ mock_block = Mock(spec=["type", "text"])
+ mock_block.type = "text"
+ mock_block.text = "test summary"
+
+ mock_response = Mock(spec=["content"])
+ mock_response.content = [mock_block]
+
+ provider.client.messages.create.return_value = mock_response
+
+ result = provider.generate(messages)
+
+ assert result == "test summary"
+ provider.client.messages.create.assert_called_once()
+
+
+def test_stream_basic(provider, mock_anthropic_client):
+ messages = [Message(role=MessageRole.USER, content="test")]
+
+ events = list(provider.stream(messages))
+
+ # Should get text event and done event
+ assert len(events) > 0
+ assert any(e.type == "text" for e in events)
+ provider.client.messages.stream.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_agenerate_basic(provider, mock_anthropic_client):
+ messages = [Message(role=MessageRole.USER, content="test")]
+
+ result = await provider.agenerate(messages)
+
+ assert result == "test summary"
+ provider.async_client.messages.create.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_astream_basic(provider, mock_anthropic_client):
+ messages = [Message(role=MessageRole.USER, content="test")]
+
+ events = []
+ async for event in provider.astream(messages):
+ events.append(event)
+
+ assert len(events) > 0
+ assert any(e.type == "text" for e in events)
+
+
+def test_convert_message_sorts_thinking_content(provider):
+    """Thinking content should be sorted so thinking comes before non-thinking content."""
+ message = Message.assistant(
+ ThinkingContent(thinking="reasoning", signature="sig"),
+ TextContent(text="response"),
+ )
+
+ result = provider._convert_message(message)
+
+ assert result["role"] == "assistant"
+ # The sort key (x["type"] != "thinking") sorts thinking type to beginning
+ # because "thinking" != "thinking" is False, which sorts before True
+ content_types = [c["type"] for c in result["content"]]
+ assert "text" in content_types
+ assert "thinking" in content_types
+ # Verify thinking comes before non-thinking (sorted by key)
+ thinking_idx = content_types.index("thinking")
+ text_idx = content_types.index("text")
+ assert thinking_idx < text_idx
+
+
+def test_execute_tool_success(provider):
+ tool_call = {"id": "t1", "name": "test", "input": {"arg": "value"}}
+ tools = {
+ "test": ToolDefinition(
+ name="test",
+ description="test",
+ input_schema={},
+ function=lambda x: f"result: {x['arg']}",
+ )
+ }
+
+ result = provider.execute_tool(tool_call, tools)
+
+ assert result.tool_use_id == "t1"
+ assert result.content == "result: value"
+ assert result.is_error is False
+
+
+def test_execute_tool_missing_name(provider):
+ tool_call = {"id": "t1", "input": {}}
+ tools = {}
+
+ result = provider.execute_tool(tool_call, tools)
+
+ assert result.tool_use_id == "t1"
+ assert "missing" in result.content.lower()
+ assert result.is_error is True
+
+
+def test_execute_tool_not_found(provider):
+ tool_call = {"id": "t1", "name": "nonexistent", "input": {}}
+ tools = {}
+
+ result = provider.execute_tool(tool_call, tools)
+
+ assert result.tool_use_id == "t1"
+ assert "not found" in result.content.lower()
+ assert result.is_error is True
+
+
+def test_execute_tool_exception(provider):
+ tool_call = {"id": "t1", "name": "test", "input": {}}
+ tools = {
+ "test": ToolDefinition(
+ name="test",
+ description="test",
+ input_schema={},
+ function=lambda x: 1 / 0, # Raises ZeroDivisionError
+ )
+ }
+
+ result = provider.execute_tool(tool_call, tools)
+
+ assert result.tool_use_id == "t1"
+ assert result.is_error is True
+ assert "division" in result.content.lower()
+
+
+def test_encode_image(provider):
+ image = Image.new("RGB", (10, 10), color="blue")
+
+ encoded = provider.encode_image(image)
+
+ assert isinstance(encoded, str)
+ assert len(encoded) > 0
+
+
+def test_encode_image_rgba(provider):
+ """RGBA images should be converted to RGB."""
+ image = Image.new("RGBA", (10, 10), color=(255, 0, 0, 128))
+
+ encoded = provider.encode_image(image)
+
+ assert isinstance(encoded, str)
+ assert len(encoded) > 0
diff --git a/tests/memory/common/llms/test_base.py b/tests/memory/common/llms/test_base.py
new file mode 100644
index 0000000..d0663d2
--- /dev/null
+++ b/tests/memory/common/llms/test_base.py
@@ -0,0 +1,270 @@
+import pytest
+from PIL import Image
+
+from memory.common.llms.base import (
+ Message,
+ MessageRole,
+ TextContent,
+ ImageContent,
+ ToolUseContent,
+ ToolResultContent,
+ ThinkingContent,
+ LLMSettings,
+ StreamEvent,
+ create_provider,
+)
+from memory.common.llms.anthropic_provider import AnthropicProvider
+from memory.common.llms.openai_provider import OpenAIProvider
+from memory.common import settings
+
+
+def test_message_role_enum():
+ assert MessageRole.SYSTEM == "system"
+ assert MessageRole.USER == "user"
+ assert MessageRole.ASSISTANT == "assistant"
+ assert MessageRole.TOOL == "tool"
+
+
+def test_text_content_creation():
+ content = TextContent(text="hello")
+ assert content.type == "text"
+ assert content.text == "hello"
+ assert content.valid
+
+
+def test_text_content_to_dict():
+ content = TextContent(text="hello")
+ result = content.to_dict()
+ assert result == {"type": "text", "text": "hello"}
+
+
+def test_text_content_empty_invalid():
+ content = TextContent(text="")
+ assert not content.valid
+
+
+def test_image_content_creation():
+ image = Image.new("RGB", (10, 10))
+ content = ImageContent(image=image)
+ assert content.type == "image"
+ assert content.image == image
+ assert content.valid
+
+
+def test_image_content_with_detail():
+ image = Image.new("RGB", (10, 10))
+ content = ImageContent(image=image, detail="high")
+ assert content.detail == "high"
+
+
+def test_tool_use_content_creation():
+ content = ToolUseContent(id="t1", name="test_tool", input={"arg": "value"})
+ assert content.type == "tool_use"
+ assert content.id == "t1"
+ assert content.name == "test_tool"
+ assert content.input == {"arg": "value"}
+ assert content.valid
+
+
+def test_tool_use_content_to_dict():
+ content = ToolUseContent(id="t1", name="test", input={"key": "val"})
+ result = content.to_dict()
+ assert result == {
+ "type": "tool_use",
+ "id": "t1",
+ "name": "test",
+ "input": {"key": "val"},
+ }
+
+
+def test_tool_result_content_creation():
+ content = ToolResultContent(
+ tool_use_id="t1",
+ content="result",
+ is_error=False,
+ )
+ assert content.type == "tool_result"
+ assert content.tool_use_id == "t1"
+ assert content.content == "result"
+ assert not content.is_error
+ assert content.valid
+
+
+def test_tool_result_content_with_error():
+ content = ToolResultContent(
+ tool_use_id="t1",
+ content="error message",
+ is_error=True,
+ )
+ assert content.is_error
+
+
+def test_thinking_content_creation():
+ content = ThinkingContent(thinking="reasoning...", signature="sig")
+ assert content.type == "thinking"
+ assert content.thinking == "reasoning..."
+ assert content.signature == "sig"
+ assert content.valid
+
+
+def test_thinking_content_invalid_without_signature():
+ content = ThinkingContent(thinking="reasoning...")
+ assert not content.valid
+
+
+def test_message_simple_string_content():
+ msg = Message(role=MessageRole.USER, content="hello")
+ assert msg.role == MessageRole.USER
+ assert msg.content == "hello"
+
+
+def test_message_list_content():
+ content_list = [TextContent(text="hello"), TextContent(text="world")]
+ msg = Message(role=MessageRole.USER, content=content_list)
+ assert msg.role == MessageRole.USER
+ assert len(msg.content) == 2
+
+
+def test_message_to_dict_string():
+ msg = Message(role=MessageRole.USER, content="hello")
+ result = msg.to_dict()
+ assert result == {"role": "user", "content": "hello"}
+
+
+def test_message_to_dict_list():
+ msg = Message(
+ role=MessageRole.USER,
+ content=[TextContent(text="hello"), TextContent(text="world")],
+ )
+ result = msg.to_dict()
+ assert result["role"] == "user"
+ assert len(result["content"]) == 2
+ assert result["content"][0] == {"type": "text", "text": "hello"}
+
+
+def test_message_assistant_factory():
+ msg = Message.assistant(
+ TextContent(text="response"),
+ ToolUseContent(id="t1", name="tool", input={}),
+ )
+ assert msg.role == MessageRole.ASSISTANT
+ assert len(msg.content) == 2
+
+
+def test_message_assistant_filters_invalid_content():
+ msg = Message.assistant(
+ TextContent(text="valid"),
+ TextContent(text=""), # Invalid - empty
+ )
+ assert len(msg.content) == 1
+ assert msg.content[0].text == "valid"
+
+
+def test_message_user_factory():
+ msg = Message.user(text="hello")
+ assert msg.role == MessageRole.USER
+ assert len(msg.content) == 1
+ assert isinstance(msg.content[0], TextContent)
+
+
+def test_message_user_with_tool_result():
+ tool_result = ToolResultContent(tool_use_id="t1", content="result")
+ msg = Message.user(text="hello", tool_result=tool_result)
+ assert len(msg.content) == 2
+
+
+def test_stream_event_creation():
+ event = StreamEvent(type="text", data="hello")
+ assert event.type == "text"
+ assert event.data == "hello"
+
+
+def test_stream_event_with_signature():
+ event = StreamEvent(type="thinking", signature="sig123")
+ assert event.signature == "sig123"
+
+
+def test_llm_settings_defaults():
+ settings = LLMSettings()
+ assert settings.temperature == 0.7
+ assert settings.max_tokens == 2048
+ assert settings.top_p is None
+ assert settings.stop_sequences is None
+ assert settings.stream is False
+
+
+def test_llm_settings_custom():
+ settings = LLMSettings(
+ temperature=0.5,
+ max_tokens=1000,
+ top_p=0.9,
+ stop_sequences=["STOP"],
+ stream=True,
+ )
+ assert settings.temperature == 0.5
+ assert settings.max_tokens == 1000
+ assert settings.top_p == 0.9
+ assert settings.stop_sequences == ["STOP"]
+ assert settings.stream is True
+
+
+def test_create_provider_anthropic():
+ provider = create_provider(
+ model="anthropic/claude-3-opus-20240229",
+ api_key="test-key",
+ )
+ assert isinstance(provider, AnthropicProvider)
+ assert provider.model == "claude-3-opus-20240229"
+
+
+def test_create_provider_openai():
+ provider = create_provider(
+ model="openai/gpt-4o",
+ api_key="test-key",
+ )
+ assert isinstance(provider, OpenAIProvider)
+ assert provider.model == "gpt-4o"
+
+
+def test_create_provider_unknown_raises():
+ with pytest.raises(ValueError, match="Unknown provider"):
+ create_provider(model="unknown/model", api_key="test-key")
+
+
+def test_create_provider_uses_default_model():
+ """If no model provided, should use SUMMARIZER_MODEL from settings."""
+ provider = create_provider(api_key="test-key")
+ # Should create a provider (type depends on settings.SUMMARIZER_MODEL)
+ assert provider is not None
+
+
+def test_create_provider_anthropic_with_thinking():
+ provider = create_provider(
+ model="anthropic/claude-opus-4",
+ api_key="test-key",
+ enable_thinking=True,
+ )
+ assert isinstance(provider, AnthropicProvider)
+ assert provider.enable_thinking is True
+
+
+def test_create_provider_missing_anthropic_key():
+ # Temporarily clear the API key from settings
+ original_key = settings.ANTHROPIC_API_KEY
+ try:
+ settings.ANTHROPIC_API_KEY = ""
+ with pytest.raises(ValueError, match="ANTHROPIC_API_KEY"):
+ create_provider(model="anthropic/claude-3-opus-20240229")
+ finally:
+ settings.ANTHROPIC_API_KEY = original_key
+
+
+def test_create_provider_missing_openai_key():
+ # Temporarily clear the API key from settings
+ original_key = settings.OPENAI_API_KEY
+ try:
+ settings.OPENAI_API_KEY = ""
+ with pytest.raises(ValueError, match="OPENAI_API_KEY"):
+ create_provider(model="openai/gpt-4o")
+ finally:
+ settings.OPENAI_API_KEY = original_key
diff --git a/tests/memory/common/llms/test_openai_event_parsing.py b/tests/memory/common/llms/test_openai_event_parsing.py
new file mode 100644
index 0000000..2574715
--- /dev/null
+++ b/tests/memory/common/llms/test_openai_event_parsing.py
@@ -0,0 +1,478 @@
+"""Comprehensive tests for OpenAI stream chunk parsing."""
+
+import pytest
+from unittest.mock import Mock
+
+from memory.common.llms.openai_provider import OpenAIProvider
+from memory.common.llms.base import StreamEvent
+
+
@pytest.fixture
def provider():
    """A provider with dummy credentials; no network calls in these tests."""
    return OpenAIProvider(model="gpt-4o", api_key="test-key")
+
+
+# Text Content Tests
+
+
@pytest.mark.parametrize(
    "content,expected_events",
    [
        ("Hello", 1),
        ("", 0),  # empty string is falsy, so no event
        (None, 0),
        ("Line 1\nLine 2\nLine 3", 1),
        ("Hello δΈη π", 1),
    ],
)
def test_text_content(provider, content, expected_events):
    """Text deltas emit a single text event; falsy content emits none."""
    delta = Mock(spec=["content", "tool_calls"])
    delta.content, delta.tool_calls = content, None
    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta, choice.finish_reason = delta, None
    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, pending = provider._handle_stream_chunk(chunk, None)

    assert pending is None
    assert len(events) == expected_events
    for event in events:
        assert event.type == "text"
        assert event.data == content
+
+
+# Tool Call Start Tests
+
+
def _start_chunk(tool_id, name, arguments):
    """Build a mock stream chunk whose delta carries one tool-call fragment."""
    function = Mock(spec=["name", "arguments"])
    function.name, function.arguments = name, arguments
    fragment = Mock(spec=["id", "function"])
    fragment.id, fragment.function = tool_id, function
    delta = Mock(spec=["content", "tool_calls"])
    delta.content, delta.tool_calls = None, [fragment]
    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta, choice.finish_reason = delta, None
    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]
    return chunk


def test_new_tool_call_basic(provider):
    """A fresh tool call seeds accumulation state without emitting events."""
    events, pending = provider._handle_stream_chunk(
        _start_chunk("call_123", "search", ""), None
    )
    assert len(events) == 0
    assert pending == {"id": "call_123", "name": "search", "arguments": ""}


@pytest.mark.parametrize(
    "name,arguments,expected_name,expected_args",
    [
        ("calculate", '{"operation":', "calculate", '{"operation":'),
        (None, "", "", ""),
        ("test", None, "test", ""),
    ],
)
def test_new_tool_call_variations(
    provider, name, arguments, expected_name, expected_args
):
    """Missing name/argument fields default to empty strings."""
    _, pending = provider._handle_stream_chunk(
        _start_chunk("call_123", name, arguments), None
    )
    assert pending["name"] == expected_name
    assert pending["arguments"] == expected_args


def test_new_tool_call_replaces_previous(provider):
    """Starting a new tool call finalizes and emits the one in flight."""
    in_flight = {"id": "call_old", "name": "old_tool", "arguments": '{"arg": "value"}'}

    events, pending = provider._handle_stream_chunk(
        _start_chunk("call_new", "new_tool", ""), in_flight
    )

    assert len(events) == 1
    assert events[0].type == "tool_use"
    assert events[0].data["id"] == "call_old"
    assert events[0].data["input"] == {"arg": "value"}
    assert pending["id"] == "call_new"
+
+
+# Tool Call Continuation Tests
+
+
def _args_chunk(arguments):
    """Mock chunk carrying only an argument fragment (no id, no name)."""
    function = Mock(spec=["name", "arguments"])
    function.name, function.arguments = None, arguments
    fragment = Mock(spec=["id", "function"])
    fragment.id, fragment.function = None, function
    delta = Mock(spec=["content", "tool_calls"])
    delta.content, delta.tool_calls = None, [fragment]
    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta, choice.finish_reason = delta, None
    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]
    return chunk


@pytest.mark.parametrize(
    "initial_args,new_args,expected_args",
    [
        ('{"query": "', 'test query"}', '{"query": "test query"}'),
        ('{"query"', ': "value"}', '{"query": "value"}'),
        ("", '{"full": "json"}', '{"full": "json"}'),
        ('{"partial"', "", '{"partial"'),  # empty fragment adds nothing
    ],
)
def test_tool_call_argument_accumulation(
    provider, initial_args, new_args, expected_args
):
    """Argument fragments append onto the in-flight call's buffer."""
    in_flight = {"id": "call_123", "name": "search", "arguments": initial_args}

    events, pending = provider._handle_stream_chunk(_args_chunk(new_args), in_flight)

    assert len(events) == 0
    assert pending["arguments"] == expected_args


def test_tool_call_accumulation_without_current_tool(provider):
    """Argument fragments with no call in flight are dropped."""
    events, pending = provider._handle_stream_chunk(
        _args_chunk('{"arg": "value"}'), None
    )
    assert len(events) == 0
    assert pending is None


def test_incremental_json_building(provider):
    """Realistic multi-chunk JSON assembly, checked after every fragment."""
    state = {"id": "c1", "name": "search", "arguments": ""}
    assembled = ""
    for fragment in ['{"', 'query":', ' "test"}']:
        assembled += fragment
        _, state = provider._handle_stream_chunk(_args_chunk(fragment), state)
        assert state["arguments"] == assembled
+
+
+# Finish Reason Tests
+
+
def _finish_chunk(reason):
    """Mock chunk with an empty delta and the given finish_reason."""
    delta = Mock(spec=["content", "tool_calls"])
    delta.content, delta.tool_calls = None, None
    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta, choice.finish_reason = delta, reason
    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]
    return chunk


def test_finish_reason_without_tool(provider):
    """Finishing with nothing in flight emits nothing."""
    events, pending = provider._handle_stream_chunk(_finish_chunk("stop"), None)
    assert len(events) == 0
    assert pending is None


@pytest.mark.parametrize(
    "arguments,expected_input",
    [
        ('{"query": "test"}', {"query": "test"}),
        ('{"invalid": json}', {}),  # malformed JSON falls back to {}
        ("", {}),
    ],
)
def test_finish_reason_with_tool(provider, arguments, expected_input):
    """Finishing finalizes the in-flight call and emits a tool_use event."""
    in_flight = {"id": "call_123", "name": "search", "arguments": arguments}

    events, pending = provider._handle_stream_chunk(
        _finish_chunk("tool_calls"), in_flight
    )

    assert len(events) == 1
    assert events[0].type == "tool_use"
    assert events[0].data["id"] == "call_123"
    assert events[0].data["input"] == expected_input
    assert pending is None


@pytest.mark.parametrize("reason", ["stop", "length", "content_filter", "tool_calls"])
def test_various_finish_reasons(provider, reason):
    """Any finish reason flushes an in-flight tool call."""
    in_flight = {"id": "call_123", "name": "test", "arguments": '{"a": 1}'}

    events, pending = provider._handle_stream_chunk(_finish_chunk(reason), in_flight)

    assert len(events) == 1
    assert pending is None
+
+
+# Edge Cases Tests
+
+
def _edge_tool(tool_id, name, arguments):
    """Mock tool-call fragment with explicit attribute assignment."""
    function = Mock(spec=["name", "arguments"])
    function.name, function.arguments = name, arguments
    fragment = Mock(spec=["id", "function"])
    fragment.id, fragment.function = tool_id, function
    return fragment


def _edge_chunk(content=None, tools=None, finish=None):
    """Mock chunk with arbitrary delta content / tool fragments / finish."""
    delta = Mock(spec=["content", "tool_calls"])
    delta.content, delta.tool_calls = content, tools
    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta, choice.finish_reason = delta, finish
    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]
    return chunk


def test_empty_choices(provider):
    """An empty choices list is a no-op."""
    chunk = Mock(spec=["choices"])
    chunk.choices = []
    events, pending = provider._handle_stream_chunk(chunk, None)
    assert len(events) == 0
    assert pending is None


def test_none_choices(provider):
    """choices=None is malformed input: either ignore it or raise."""
    chunk = Mock(spec=["choices"])
    chunk.choices = None
    try:
        events, _ = provider._handle_stream_chunk(chunk, None)
    except (TypeError, AttributeError):
        return  # raising on malformed input is also acceptable
    assert len(events) == 0


def test_multiple_chunks_in_sequence(provider):
    """start -> arguments -> finish across three chunks yields one call."""
    # Chunk 1: tool-call start seeds state.
    events, state = provider._handle_stream_chunk(
        _edge_chunk(tools=[_edge_tool("call_1", "search", "")]), None
    )
    assert len(events) == 0
    assert state is not None

    # Chunk 2: argument fragment accumulates.
    events, state = provider._handle_stream_chunk(
        _edge_chunk(tools=[_edge_tool(None, None, '{"q": "test"}')]), state
    )
    assert len(events) == 0
    assert state["arguments"] == '{"q": "test"}'

    # Chunk 3: finish flushes the call.
    events, state = provider._handle_stream_chunk(_edge_chunk(finish="stop"), state)
    assert len(events) == 1
    assert events[0].type == "tool_use"
    assert state is None


def test_text_and_tool_calls_mixed(provider):
    """Text in the same delta is emitted while the tool call starts."""
    chunk = _edge_chunk(
        content="Let me search for that.",
        tools=[_edge_tool("call_1", "search", "")],
    )

    events, pending = provider._handle_stream_chunk(chunk, None)

    assert len(events) == 1
    assert events[0].type == "text"
    assert events[0].data == "Let me search for that."
    assert pending is not None
+
+
+# JSON Parsing Tests
+
+
@pytest.mark.parametrize(
    "arguments,expected_input",
    [
        ('{"key": "value", "num": 42}', {"key": "value", "num": 42}),
        ("{}", {}),
        (
            '{"user": {"name": "John", "tags": ["a", "b"]}, "count": 10}',
            {"user": {"name": "John", "tags": ["a", "b"]}, "count": 10},
        ),
        ('{"invalid": json}', {}),  # malformed -> empty input
        ('{"key": "val', {}),  # truncated -> empty input
        ("", {}),
        ('{"text": "Hello δΈη π"}', {"text": "Hello δΈη π"}),
        (
            '{"text": "Line 1\\nLine 2\\t\\tTabbed"}',
            {"text": "Line 1\nLine 2\t\tTabbed"},
        ),
    ],
)
def test_json_parsing(provider, arguments, expected_input):
    """Finalizing replaces the raw arguments string with parsed input."""
    finalized = provider._parse_and_finalize_tool_call(
        {"id": "c1", "name": "test", "arguments": arguments}
    )
    assert finalized["input"] == expected_input
    assert "arguments" not in finalized
diff --git a/tests/memory/common/llms/test_openai_provider.py b/tests/memory/common/llms/test_openai_provider.py
new file mode 100644
index 0000000..896aa96
--- /dev/null
+++ b/tests/memory/common/llms/test_openai_provider.py
@@ -0,0 +1,561 @@
+import pytest
+from unittest.mock import Mock
+from PIL import Image
+
+from memory.common.llms.openai_provider import OpenAIProvider
+from memory.common.llms.base import (
+ Message,
+ MessageRole,
+ TextContent,
+ ImageContent,
+ ToolUseContent,
+ ToolResultContent,
+ LLMSettings,
+ StreamEvent,
+)
+from memory.common.llms.tools import ToolDefinition
+
+
@pytest.fixture
def provider():
    """Standard (non-reasoning) provider under test."""
    return OpenAIProvider(model="gpt-4o", api_key="test-key")


@pytest.fixture
def reasoning_provider():
    """Provider configured with a reasoning (o1-family) model."""
    return OpenAIProvider(model="o1-preview", api_key="test-key")
+
+
def test_initialization(provider):
    """Constructor stores credentials and model verbatim."""
    assert provider.model == "gpt-4o"
    assert provider.api_key == "test-key"


def test_client_lazy_loading(provider):
    """The sync client is created on first access, not in __init__."""
    assert provider._client is None
    assert provider.client is not None
    assert provider._client is not None


def test_async_client_lazy_loading(provider):
    """The async client is created on first access, not in __init__."""
    assert provider._async_client is None
    assert provider.async_client is not None
    assert provider._async_client is not None
+
+
@pytest.mark.parametrize(
    "model, expected",
    [
        ("gpt-4o", False),
        ("o1-preview", True),
        ("o1-mini", True),
        # NOTE(review): gpt-4-turbo / gpt-3.5-turbo mapping to True is
        # surprising for a "reasoning model" predicate — confirm this matches
        # the intended semantics of _is_reasoning_model.
        ("gpt-4-turbo", True),
        ("gpt-3.5-turbo", True),
    ],
)
def test_is_reasoning_model(model, expected):
    # The predicate drives request shaping (developer role, max_completion_tokens).
    provider = OpenAIProvider(api_key="test-key", model=model)
    assert provider._is_reasoning_model() == expected
+
+
def test_convert_text_content(provider):
    """Text content maps to OpenAI's {"type": "text"} part."""
    part = provider._convert_text_content(TextContent(text="hello world"))
    assert part == {"type": "text", "text": "hello world"}


def test_convert_image_content(provider):
    """Images are inlined as base64 JPEG data URLs."""
    pixels = Image.new("RGB", (100, 100), color="red")
    part = provider._convert_image_content(ImageContent(image=pixels))
    assert part["type"] == "image_url"
    assert part["image_url"]["url"].startswith("data:image/jpeg;base64,")


def test_convert_image_content_with_detail(provider):
    """An explicit detail level is passed through to the image part."""
    pixels = Image.new("RGB", (100, 100), color="red")
    part = provider._convert_image_content(ImageContent(image=pixels, detail="high"))
    assert part["image_url"]["detail"] == "high"


def test_convert_tool_use_content(provider):
    """Tool use maps to a function call with JSON-encoded arguments."""
    call = provider._convert_tool_use_content(
        ToolUseContent(id="t1", name="test_tool", input={"arg": "value"})
    )
    assert call["id"] == "t1"
    assert call["type"] == "function"
    assert call["function"]["name"] == "test_tool"
    assert '{"arg": "value"}' in call["function"]["arguments"]


def test_convert_tool_result_content(provider):
    """Tool results map to role="tool" messages keyed by tool_call_id."""
    message = provider._convert_tool_result_content(
        ToolResultContent(tool_use_id="t1", content="result content", is_error=False)
    )
    assert message["role"] == "tool"
    assert message["tool_call_id"] == "t1"
    assert message["content"] == "result content"
+
+
def test_convert_messages_simple(provider):
    """A plain string user message converts to a single OpenAI dict."""
    converted = provider._convert_messages(
        [Message(role=MessageRole.USER, content="test")]
    )
    assert len(converted) == 1
    assert converted[0]["role"] == "user"
    assert converted[0]["content"] == "test"


def test_convert_messages_with_tool_result(provider):
    """A tool result becomes its own message with the "tool" role."""
    source = [
        Message(
            role=MessageRole.USER,
            content=[ToolResultContent(tool_use_id="t1", content="result")],
        )
    ]
    converted = provider._convert_messages(source)
    assert len(converted) == 1
    assert converted[0]["role"] == "tool"
    assert converted[0]["tool_call_id"] == "t1"


def test_convert_messages_with_tool_use(provider):
    """Assistant tool use is carried in the tool_calls field."""
    source = [
        Message.assistant(
            TextContent(text="thinking..."),
            ToolUseContent(id="t1", name="test", input={}),
        )
    ]
    converted = provider._convert_messages(source)
    assert len(converted) == 1
    assert converted[0]["role"] == "assistant"
    assert "tool_calls" in converted[0]
    assert len(converted[0]["tool_calls"]) == 1


def test_convert_messages_mixed_content(provider):
    """Text plus tool result in one message splits into two messages."""
    source = [
        Message(
            role=MessageRole.USER,
            content=[
                TextContent(text="user text"),
                ToolResultContent(tool_use_id="t1", content="result"),
            ],
        )
    ]
    converted = provider._convert_messages(source)
    # Tool message is emitted first, then the remaining user content.
    assert [m["role"] for m in converted] == ["tool", "user"]
+
+
def test_convert_tools(provider):
    """Tool definitions map onto OpenAI's function-tool schema."""
    definition = ToolDefinition(
        name="test_tool",
        description="A test tool",
        input_schema={"type": "object", "properties": {"arg": {"type": "string"}}},
        function=lambda x: "result",
    )

    converted = provider._convert_tools([definition])

    assert len(converted) == 1
    assert converted[0]["type"] == "function"
    fn = converted[0]["function"]
    assert fn["name"] == "test_tool"
    assert fn["description"] == "A test tool"
    assert fn["parameters"] == definition.input_schema
+
+
def test_build_request_kwargs_basic(provider):
    """Model, sampling settings, and messages all land in the kwargs."""
    kwargs = provider._build_request_kwargs(
        [Message(role=MessageRole.USER, content="test")],
        None,
        None,
        LLMSettings(temperature=0.5, max_tokens=1000),
    )
    assert kwargs["model"] == "gpt-4o"
    assert kwargs["temperature"] == 0.5
    assert kwargs["max_tokens"] == 1000
    assert len(kwargs["messages"]) == 1


def test_build_request_kwargs_with_system_prompt_standard_model(provider):
    """Standard models receive the system prompt as a role="system" message."""
    kwargs = provider._build_request_kwargs(
        [Message(role=MessageRole.USER, content="test")],
        "system prompt",
        None,
        LLMSettings(),
    )
    first = kwargs["messages"][0]
    assert first["role"] == "system"
    assert first["content"] == "system prompt"


def test_build_request_kwargs_with_system_prompt_reasoning_model(
    reasoning_provider,
):
    """Reasoning (o1) models receive the system prompt as role="developer"."""
    kwargs = reasoning_provider._build_request_kwargs(
        [Message(role=MessageRole.USER, content="test")],
        "system prompt",
        None,
        LLMSettings(),
    )
    first = kwargs["messages"][0]
    assert first["role"] == "developer"
    assert first["content"] == "system prompt"


def test_build_request_kwargs_reasoning_model_uses_max_completion_tokens(
    reasoning_provider,
):
    """Reasoning models use max_completion_tokens instead of max_tokens."""
    kwargs = reasoning_provider._build_request_kwargs(
        [Message(role=MessageRole.USER, content="test")],
        None,
        None,
        LLMSettings(max_tokens=2000),
    )
    assert kwargs["max_completion_tokens"] == 2000
    assert "max_tokens" not in kwargs


def test_build_request_kwargs_reasoning_model_no_temperature(reasoning_provider):
    """Reasoning models drop unsupported sampling parameters."""
    kwargs = reasoning_provider._build_request_kwargs(
        [Message(role=MessageRole.USER, content="test")],
        None,
        None,
        LLMSettings(temperature=0.7),
    )
    assert "temperature" not in kwargs
    assert "top_p" not in kwargs


def test_build_request_kwargs_with_tools(provider):
    """Tools are converted and tool_choice defaults to "auto"."""
    definition = ToolDefinition(
        name="test", description="test", input_schema={}, function=lambda x: "result"
    )
    kwargs = provider._build_request_kwargs(
        [Message(role=MessageRole.USER, content="test")],
        None,
        [definition],
        LLMSettings(),
    )
    assert len(kwargs["tools"]) == 1
    assert kwargs["tool_choice"] == "auto"


def test_build_request_kwargs_with_stream(provider):
    """stream=True is forwarded into the request kwargs."""
    kwargs = provider._build_request_kwargs(
        [Message(role=MessageRole.USER, content="test")],
        None,
        None,
        LLMSettings(),
        stream=True,
    )
    assert kwargs["stream"] is True
+
+
def test_parse_and_finalize_tool_call(provider):
    """Valid JSON arguments are parsed into the input field."""
    finalized = provider._parse_and_finalize_tool_call(
        {"id": "t1", "name": "test", "arguments": '{"key": "value"}'}
    )
    assert finalized["id"] == "t1"
    assert finalized["name"] == "test"
    assert finalized["input"] == {"key": "value"}
    assert "arguments" not in finalized


def test_parse_and_finalize_tool_call_invalid_json(provider):
    """Unparseable arguments fall back to an empty input dict."""
    finalized = provider._parse_and_finalize_tool_call(
        {"id": "t1", "name": "test", "arguments": '{"invalid json'}
    )
    assert finalized["input"] == {}
+
+
def test_handle_stream_chunk_text_content(provider):
    """A content delta yields one text event and leaves no tool state."""
    chunk = Mock(
        choices=[
            Mock(delta=Mock(content="hello", tool_calls=None), finish_reason=None)
        ]
    )

    events, pending = provider._handle_stream_chunk(chunk, None)

    assert [(e.type, e.data) for e in events] == [("text", "hello")]
    assert pending is None


def test_handle_stream_chunk_tool_call_start(provider):
    """A tool-call delta seeds accumulation state without emitting events."""
    # spec + attribute assignment: Mock(name=...) would NOT set a .name attribute.
    function = Mock(spec=["name", "arguments"])
    function.name, function.arguments = "test_tool", ""
    fragment = Mock(spec=["id", "function"])
    fragment.id, fragment.function = "t1", function
    delta = Mock(spec=["content", "tool_calls"])
    delta.content, delta.tool_calls = None, [fragment]
    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta, choice.finish_reason = delta, None
    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, pending = provider._handle_stream_chunk(chunk, None)

    assert len(events) == 0
    assert pending is not None
    assert pending["id"] == "t1"
    assert pending["name"] == "test_tool"
+
+
def test_handle_stream_chunk_tool_call_arguments(provider):
    """Argument fragments are appended to the in-flight tool call.

    The function mock must set ``.name`` by attribute assignment:
    ``Mock(name=None)`` configures the mock's *own* name (a reserved
    constructor kwarg), so ``.name`` would resolve to a truthy child mock
    instead of None and could be mistaken for a new tool-call start.
    """
    current_tool = {"id": "t1", "name": "test", "arguments": '{"ke'}

    function = Mock(spec=["name", "arguments"])
    function.name = None
    function.arguments = 'y": "val"}'

    fragment = Mock(spec=["id", "function"])
    fragment.id = None
    fragment.function = function

    chunk = Mock(
        choices=[
            Mock(delta=Mock(content=None, tool_calls=[fragment]), finish_reason=None)
        ]
    )

    events, tool_call = provider._handle_stream_chunk(chunk, current_tool)

    assert len(events) == 0
    assert tool_call["arguments"] == '{"key": "val"}'
+
+
def test_handle_stream_chunk_finish_with_tool_call(provider):
    """finish_reason="tool_calls" flushes the accumulated call as an event."""
    in_flight = {"id": "t1", "name": "test", "arguments": '{"key": "value"}'}
    chunk = Mock(
        choices=[
            Mock(delta=Mock(content=None, tool_calls=None), finish_reason="tool_calls")
        ]
    )

    events, pending = provider._handle_stream_chunk(chunk, in_flight)

    assert len(events) == 1
    assert events[0].type == "tool_use"
    assert events[0].data["id"] == "t1"
    assert events[0].data["input"] == {"key": "value"}
    assert pending is None


def test_handle_stream_chunk_empty_choices(provider):
    """A chunk without choices is a no-op."""
    events, pending = provider._handle_stream_chunk(Mock(choices=[]), None)
    assert len(events) == 0
    assert pending is None
+
+
def test_generate_basic(provider, mock_openai_client):
    """generate() returns non-empty text via exactly one completions call."""
    # The conftest fixture already configures the mocked response.
    text = provider.generate([Message(role=MessageRole.USER, content="test")])

    assert isinstance(text, str)
    assert text
    provider.client.chat.completions.create.assert_called_once()


def test_stream_basic(provider, mock_openai_client):
    """stream() yields at least one text event and ends with "done"."""
    events = list(provider.stream([Message(role=MessageRole.USER, content="test")]))

    assert events
    assert any(e.type == "text" for e in events)
    assert events[-1].type == "done"
+
+
@pytest.mark.asyncio
async def test_agenerate_basic(provider, mock_openai_client):
    """agenerate() should return the mocked completion text."""
    messages = [Message(role=MessageRole.USER, content="test")]

    # Mock the async client.
    # NOTE(review): a plain Mock is not awaitable — this only works if
    # agenerate() does not `await` the create() call's return directly; if it
    # does, this should be unittest.mock.AsyncMock. Confirm against the provider.
    mock_response = Mock(choices=[Mock(message=Mock(content="async response"))])
    provider.async_client.chat.completions.create = Mock(return_value=mock_response)

    result = await provider.agenerate(messages)

    assert result == "async response"


@pytest.mark.asyncio
async def test_astream_basic(provider, mock_openai_client):
    """astream() should yield text events from an async chunk stream."""
    messages = [Message(role=MessageRole.USER, content="test")]

    # Mock async streaming: two text chunks, the second carrying the stop.
    async def async_stream():
        yield Mock(
            choices=[
                Mock(delta=Mock(content="async", tool_calls=None), finish_reason=None)
            ]
        )
        yield Mock(
            choices=[
                Mock(delta=Mock(content=" test", tool_calls=None), finish_reason="stop")
            ]
        )

    # NOTE(review): same awaitability caveat as above — AsyncMock may be
    # required if astream() awaits create() before iterating the stream.
    provider.async_client.chat.completions.create = Mock(return_value=async_stream())

    events = []
    async for event in provider.astream(messages):
        events.append(event)

    assert len(events) > 0
    text_events = [e for e in events if e.type == "text"]
    assert len(text_events) > 0
+
+
def test_stream_with_tool_call(provider, mock_openai_client):
    """A start/args/finish chunk sequence surfaces one finalized tool_use."""

    def make_tool(tool_id, name, arguments):
        # spec + attribute assignment (Mock(name=...) would not set .name).
        function = Mock(spec=["name", "arguments"])
        function.name, function.arguments = name, arguments
        fragment = Mock(spec=["id", "function"])
        fragment.id, fragment.function = tool_id, function
        return fragment

    def make_chunk(tool=None, finish=None):
        delta = Mock(spec=["content", "tool_calls"])
        delta.content = None
        delta.tool_calls = [tool] if tool is not None else None
        choice = Mock(spec=["delta", "finish_reason"])
        choice.delta, choice.finish_reason = delta, finish
        chunk = Mock(spec=["choices"])
        chunk.choices = [choice]
        return chunk

    def stream_with_tool(*args, **kwargs):
        if kwargs.get("stream"):
            return iter(
                [
                    make_chunk(tool=make_tool("t1", "test_tool", "")),  # start
                    make_chunk(tool=make_tool(None, None, '{"arg": "val"}')),  # args
                    make_chunk(finish="tool_calls"),  # finish
                ]
            )

    provider.client.chat.completions.create.side_effect = stream_with_tool

    events = list(provider.stream([Message(role=MessageRole.USER, content="test")]))

    tool_events = [e for e in events if e.type == "tool_use"]
    assert len(tool_events) == 1
    assert tool_events[0].data["id"] == "t1"
    assert tool_events[0].data["name"] == "test_tool"
    assert tool_events[0].data["input"] == {"arg": "val"}
+
+
def test_encode_image(provider):
    """encode_image returns a non-empty base64 string for RGB input."""
    encoded = provider.encode_image(Image.new("RGB", (10, 10), color="blue"))
    assert isinstance(encoded, str)
    assert encoded


def test_encode_image_rgba(provider):
    """RGBA input is accepted (converted to RGB before encoding)."""
    encoded = provider.encode_image(
        Image.new("RGBA", (10, 10), color=(255, 0, 0, 128))
    )
    assert isinstance(encoded, str)
    assert encoded
diff --git a/tools/add_user.py b/tools/add_user.py
index 4272f45..8991edc 100644
--- a/tools/add_user.py
+++ b/tools/add_user.py
@@ -2,21 +2,41 @@
import argparse
from memory.common.db.connection import make_session
-from memory.common.db.models.users import User
+from memory.common.db.models.users import HumanUser, BotUser
if __name__ == "__main__":
    args = argparse.ArgumentParser()
    args.add_argument("--email", type=str, required=True)
    args.add_argument("--name", type=str, required=True)
    args.add_argument("--password", type=str, required=False)
    args.add_argument("--bot", action="store_true", help="Create a bot user")
    args.add_argument(
        "--api-key",
        type=str,
        required=False,
        help="API key for bot user (auto-generated if not provided)",
    )
    args = args.parse_args()

    with make_session() as session:
        if args.bot:
            user = BotUser.create_with_api_key(
                name=args.name, email=args.email, api_key=args.api_key
            )
        else:
            if not args.password:
                raise ValueError("Password required for human users")
            user = HumanUser.create_with_password(
                email=args.email, password=args.password, name=args.name
            )

        session.add(user)
        session.commit()
        # Report exactly once, after the commit succeeds (the bot API key may
        # only be final once persisted); the patch previously printed the
        # success message both before and after the commit.
        if args.bot:
            print(f"Bot user {args.email} created with API key: {user.api_key}")
        else:
            print(f"Human user {args.email} created")