Discord integration

This commit is contained in:
Daniel O'Connell 2025-10-20 03:47:13 +02:00
parent e68671deb4
commit 1606348d8b
32 changed files with 3472 additions and 212 deletions

View File

@ -1,8 +1,8 @@
"""add_discord_models
Revision ID: a8c8e8b17179
Revision ID: 7c6169fba146
Revises: c86079073c1d
Create Date: 2025-10-12 22:28:27.856164
Create Date: 2025-10-13 14:21:01.080948
"""
@ -13,7 +13,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "a8c8e8b17179"
revision: str = "7c6169fba146"
down_revision: Union[str, None] = "c86079073c1d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@ -26,12 +26,6 @@ def upgrade() -> None:
sa.Column("name", sa.Text(), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("member_count", sa.Integer(), nullable=True),
sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
sa.Column(
"ignore_messages", sa.Boolean(), server_default="false", nullable=True
),
sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
@ -45,6 +39,17 @@ def upgrade() -> None:
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"track_messages", sa.Boolean(), server_default="true", nullable=False
),
sa.Column("ignore_messages", sa.Boolean(), nullable=True),
sa.Column(
"allowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
),
sa.Column(
"disallowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
),
sa.Column("summary", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
@ -59,12 +64,6 @@ def upgrade() -> None:
sa.Column("server_id", sa.BigInteger(), nullable=True),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("channel_type", sa.Text(), nullable=False),
sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
sa.Column(
"ignore_messages", sa.Boolean(), server_default="false", nullable=True
),
sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
@ -77,6 +76,17 @@ def upgrade() -> None:
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"track_messages", sa.Boolean(), server_default="true", nullable=False
),
sa.Column("ignore_messages", sa.Boolean(), nullable=True),
sa.Column(
"allowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
),
sa.Column(
"disallowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
),
sa.Column("summary", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(
["server_id"],
["discord_servers.id"],
@ -92,12 +102,6 @@ def upgrade() -> None:
sa.Column("username", sa.Text(), nullable=False),
sa.Column("display_name", sa.Text(), nullable=True),
sa.Column("system_user_id", sa.Integer(), nullable=True),
sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
sa.Column(
"ignore_messages", sa.Boolean(), server_default="false", nullable=True
),
sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
@ -110,6 +114,17 @@ def upgrade() -> None:
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"track_messages", sa.Boolean(), server_default="true", nullable=False
),
sa.Column("ignore_messages", sa.Boolean(), nullable=True),
sa.Column(
"allowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
),
sa.Column(
"disallowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
),
sa.Column("summary", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(
["system_user_id"],
["users.id"],

View File

@ -0,0 +1,145 @@
"""seperate_user__models
Revision ID: 35a2c1b610b6
Revises: 7c6169fba146
Create Date: 2025-10-20 01:48:58.537881
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "35a2c1b610b6"
down_revision: Union[str, None] = "7c6169fba146"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"discord_message", sa.Column("from_id", sa.BigInteger(), nullable=False)
)
op.add_column(
"discord_message", sa.Column("recipient_id", sa.BigInteger(), nullable=False)
)
op.drop_index("discord_message_user_idx", table_name="discord_message")
op.create_index(
"discord_message_from_idx", "discord_message", ["from_id"], unique=False
)
op.create_index(
"discord_message_recipient_idx",
"discord_message",
["recipient_id"],
unique=False,
)
op.drop_constraint(
"discord_message_discord_user_id_fkey", "discord_message", type_="foreignkey"
)
op.create_foreign_key(
"discord_message_from_id_fkey",
"discord_message",
"discord_users",
["from_id"],
["id"],
)
op.create_foreign_key(
"discord_message_recipient_id_fkey",
"discord_message",
"discord_users",
["recipient_id"],
["id"],
)
op.drop_column("discord_message", "discord_user_id")
op.add_column(
"scheduled_llm_calls",
sa.Column("discord_channel_id", sa.BigInteger(), nullable=True),
)
op.add_column(
"scheduled_llm_calls",
sa.Column("discord_user_id", sa.BigInteger(), nullable=True),
)
op.create_foreign_key(
"scheduled_llm_calls_discord_user_id_fkey",
"scheduled_llm_calls",
"discord_users",
["discord_user_id"],
["id"],
)
op.create_foreign_key(
"scheduled_llm_calls_discord_channel_id_fkey",
"scheduled_llm_calls",
"discord_channels",
["discord_channel_id"],
["id"],
)
op.drop_column("scheduled_llm_calls", "discord_user")
op.drop_column("scheduled_llm_calls", "discord_channel")
op.add_column(
"users",
sa.Column("user_type", sa.String(), nullable=False, server_default="human"),
)
op.add_column("users", sa.Column("api_key", sa.String(), nullable=True))
op.alter_column("users", "password_hash", existing_type=sa.VARCHAR(), nullable=True)
op.create_unique_constraint("users_api_key_key", "users", ["api_key"])
op.drop_column("users", "discord_user_id")
def downgrade() -> None:
op.add_column(
"users",
sa.Column("discord_user_id", sa.VARCHAR(), autoincrement=False, nullable=True),
)
op.drop_constraint("users_api_key_key", "users", type_="unique")
op.alter_column(
"users", "password_hash", existing_type=sa.VARCHAR(), nullable=False
)
op.drop_column("users", "api_key")
op.drop_column("users", "user_type")
op.add_column(
"scheduled_llm_calls",
sa.Column("discord_channel", sa.VARCHAR(), autoincrement=False, nullable=True),
)
op.add_column(
"scheduled_llm_calls",
sa.Column("discord_user", sa.VARCHAR(), autoincrement=False, nullable=True),
)
op.drop_constraint(
"scheduled_llm_calls_discord_user_id_fkey",
"scheduled_llm_calls",
type_="foreignkey",
)
op.drop_constraint(
"scheduled_llm_calls_discord_channel_id_fkey",
"scheduled_llm_calls",
type_="foreignkey",
)
op.drop_column("scheduled_llm_calls", "discord_user_id")
op.drop_column("scheduled_llm_calls", "discord_channel_id")
op.add_column(
"discord_message",
sa.Column("discord_user_id", sa.BIGINT(), autoincrement=False, nullable=False),
)
op.drop_constraint(
"discord_message_from_id_fkey", "discord_message", type_="foreignkey"
)
op.drop_constraint(
"discord_message_recipient_id_fkey", "discord_message", type_="foreignkey"
)
op.create_foreign_key(
"discord_message_discord_user_id_fkey",
"discord_message",
"discord_users",
["discord_user_id"],
["id"],
)
op.drop_index("discord_message_recipient_idx", table_name="discord_message")
op.drop_index("discord_message_from_idx", table_name="discord_message")
op.create_index(
"discord_message_user_idx", "discord_message", ["discord_user_id"], unique=False
)
op.drop_column("discord_message", "recipient_id")
op.drop_column("discord_message", "from_id")

View File

@ -50,6 +50,8 @@ x-worker-base: &worker-base
OPENAI_API_KEY_FILE: /run/secrets/openai_key
ANTHROPIC_API_KEY_FILE: /run/secrets/anthropic_key
VOYAGE_API_KEY: ${VOYAGE_API_KEY}
DISCORD_COLLECTOR_SERVER_URL: ingest-hub
DISCORD_COLLECTOR_PORT: 8003
secrets: [ postgres_password, openai_key, anthropic_key, ssh_private_key, ssh_public_key, ssh_known_hosts ]
read_only: true
tmpfs:
@ -183,7 +185,6 @@ services:
dockerfile: docker/ingest_hub/Dockerfile
environment:
<<: *worker-env
DISCORD_API_PORT: 8000
DISCORD_BOT_TOKEN: ${DISCORD_BOT_TOKEN}
DISCORD_NOTIFICATIONS_ENABLED: true
DISCORD_COLLECTOR_ENABLED: true

View File

@ -25,7 +25,7 @@ from memory.api.MCP.oauth_provider import (
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models import OAuthState, UserSession
from memory.common.db.models.users import User
from memory.common.db.models.users import HumanUser
logger = logging.getLogger(__name__)
@ -126,7 +126,11 @@ async def handle_login(request: Request):
key: value for key, value in form.items() if key not in ["email", "password"]
}
with make_session() as session:
user = session.query(User).filter(User.email == form.get("email")).first()
user = (
session.query(HumanUser)
.filter(HumanUser.email == form.get("email"))
.first()
)
if not user or not user.is_valid_password(str(form.get("password", ""))):
logger.warning("Login failed - invalid credentials")
return login_form(request, oauth_params, "Invalid email or password")
@ -144,11 +148,7 @@ def get_current_user() -> dict:
return {"authenticated": False}
with make_session() as session:
user_session = (
session.query(UserSession)
.filter(UserSession.id == access_token.token)
.first()
)
user_session = session.query(UserSession).get(access_token.token)
if user_session and user_session.user:
user_info = user_session.user.serialize()

View File

@ -21,6 +21,7 @@ from memory.common.db.models.users import (
OAuthRefreshToken,
OAuthState,
User,
BotUser,
UserSession,
)
from memory.common.db.models.users import (
@ -92,7 +93,7 @@ def create_oauth_token(
"""Create an OAuth token response."""
return OAuthToken(
access_token=access_token,
token_type="bearer",
token_type="Bearer",
expires_in=ACCESS_TOKEN_LIFETIME,
refresh_token=refresh_token,
scope=" ".join(scopes),
@ -310,26 +311,37 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
return token
async def load_access_token(self, token: str) -> Optional[AccessToken]:
"""Load and validate an access token."""
"""Load and validate an access token (or bot API key)."""
with make_session() as session:
now = datetime.now(timezone.utc).replace(
tzinfo=None
) # Make naive for DB comparison
# Query for active (non-expired) session
# Try as OAuth access token first
user_session = session.query(UserSession).get(token)
if not user_session:
return None
if user_session:
now = datetime.now(timezone.utc).replace(
tzinfo=None
) # Make naive for DB comparison
if user_session.expires_at < now:
return None
if user_session.expires_at < now:
return None
return AccessToken(
token=token,
client_id=user_session.oauth_state.client_id,
scopes=user_session.oauth_state.scopes,
expires_at=int(user_session.expires_at.timestamp()),
)
return AccessToken(
token=token,
client_id=user_session.oauth_state.client_id,
scopes=user_session.oauth_state.scopes,
expires_at=int(user_session.expires_at.timestamp()),
)
# Try as bot API key
bot = session.query(User).filter(User.api_key == token).first()
if bot:
logger.info(f"Bot {bot.name} (id={bot.id}) authenticated via API key")
return AccessToken(
token=token,
client_id=cast(str, bot.name or bot.email),
scopes=["read", "write"], # Bots get full access
expires_at=2147483647, # Far future (2038)
)
return None
async def load_refresh_token(
self, client: OAuthClientInformationFull, refresh_token: str

View File

@ -9,7 +9,9 @@ from typing import Any
from memory.api.MCP.base import get_current_user
from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall
from memory.common.db.models.discord import DiscordChannel, DiscordUser
from memory.api.MCP.base import mcp
from memory.discord.schedule import schedule_discord_message
logger = logging.getLogger(__name__)
@ -17,7 +19,7 @@ logger = logging.getLogger(__name__)
@mcp.tool()
async def schedule_message(
scheduled_time: str,
message: str | None = None,
message: str,
model: str | None = None,
topic: str | None = None,
discord_channel: str | None = None,
@ -56,7 +58,8 @@ async def schedule_message(
if not user_id:
raise ValueError("User not found")
discord_user = current_user.get("user", {}).get("discord_user_id")
discord_users = current_user.get("user", {}).get("discord_users")
discord_user = discord_users and next(iter(discord_users.keys()), None)
if not discord_user and not discord_channel:
raise ValueError("Either discord_user or discord_channel must be provided")
@ -69,27 +72,20 @@ async def schedule_message(
except ValueError:
raise ValueError("Invalid datetime format")
# Validate that the scheduled time is in the future
# Compare with naive datetime since we store naive in the database
current_time_naive = datetime.now(timezone.utc).replace(tzinfo=None)
if scheduled_dt <= current_time_naive:
raise ValueError("Scheduled time must be in the future")
with make_session() as session:
# Create the scheduled call
scheduled_call = ScheduledLLMCall(
user_id=user_id,
scheduled_call = schedule_discord_message(
session=session,
scheduled_time=scheduled_dt,
message=message,
topic=topic,
user_id=current_user.get("user", {}).get("user_id"),
model=model,
system_prompt=system_prompt,
topic=topic,
discord_channel=discord_channel,
discord_user=discord_user,
data=metadata or {},
system_prompt=system_prompt,
metadata=metadata,
)
session.add(scheduled_call)
session.commit()
return {

View File

@ -7,7 +7,7 @@ from memory.common import settings
from sqlalchemy.orm import Session as DBSession, scoped_session
from memory.common.db.connection import get_session, make_session
from memory.common.db.models.users import User, UserSession
from memory.common.db.models.users import User, HumanUser, BotUser, UserSession
logger = logging.getLogger(__name__)
@ -91,14 +91,14 @@ def get_current_user(request: Request, db: DBSession = Depends(get_session)) ->
return user
def create_user(email: str, password: str, name: str, db: DBSession) -> User:
"""Create a new user"""
def create_user(email: str, password: str, name: str, db: DBSession) -> HumanUser:
"""Create a new human user"""
# Check if user already exists
existing_user = db.query(User).filter(User.email == email).first()
if existing_user:
raise HTTPException(status_code=400, detail="User already exists")
user = User.create_with_password(email, name, password)
user = HumanUser.create_with_password(email, name, password)
db.add(user)
db.commit()
db.refresh(user)
@ -106,14 +106,19 @@ def create_user(email: str, password: str, name: str, db: DBSession) -> User:
return user
def authenticate_user(email: str, password: str, db: DBSession) -> User | None:
"""Authenticate a user by email and password"""
user = db.query(User).filter(User.email == email).first()
def authenticate_user(email: str, password: str, db: DBSession) -> HumanUser | None:
"""Authenticate a human user by email and password"""
user = db.query(HumanUser).filter(HumanUser.email == email).first()
if user and user.is_valid_password(password):
return user
return None
def authenticate_bot(api_key: str, db: DBSession) -> BotUser | None:
"""Authenticate a bot by API key"""
return db.query(BotUser).filter(BotUser.api_key == api_key).first()
@router.api_route("/logout", methods=["GET", "POST"])
def logout(request: Request, db: DBSession = Depends(get_session)):
"""Logout and clear session"""

View File

@ -30,6 +30,11 @@ from memory.common.db.models.source_items import (
NotePayload,
ForumPostPayload,
)
from memory.common.db.models.discord import (
DiscordServer,
DiscordChannel,
DiscordUser,
)
from memory.common.db.models.observations import (
ObservationContradiction,
ReactionPattern,
@ -41,12 +46,12 @@ from memory.common.db.models.sources import (
Book,
ArticleFeed,
EmailAccount,
DiscordServer,
DiscordChannel,
DiscordUser,
)
from memory.common.db.models.users import (
User,
HumanUser,
BotUser,
DiscordBotUser,
UserSession,
OAuthClientInformation,
OAuthState,
@ -103,6 +108,9 @@ __all__ = [
"DiscordUser",
# Users
"User",
"HumanUser",
"BotUser",
"DiscordBotUser",
"UserSession",
"OAuthClientInformation",
"OAuthState",

View File

@ -0,0 +1,121 @@
"""
Database models for the Discord system.
"""
import textwrap
from sqlalchemy import (
ARRAY,
BigInteger,
Boolean,
Column,
DateTime,
ForeignKey,
Index,
Integer,
Text,
func,
)
from sqlalchemy.orm import relationship
from memory.common.db.models.base import Base
class MessageProcessor:
track_messages = Column(Boolean, nullable=False, server_default="true")
ignore_messages = Column(Boolean, nullable=True, default=False)
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
summary = Column(
Text,
nullable=True,
doc=textwrap.dedent(
"""
A summary of this processor, made by and for AI systems.
The idea here is that AI systems can use this summary to keep notes on the given processor.
These should automatically be injected into the context of the messages that are processed by this processor.
"""
),
)
def as_xml(self) -> str:
return (
textwrap.dedent("""
<{type}>
<name>{name}</name>
<summary>{summary}</summary>
</{type}>
""")
.format(
type=self.__class__.__tablename__[8:], # type: ignore
name=getattr(self, "name", None) or getattr(self, "username", None),
summary=self.summary,
)
.strip()
)
class DiscordServer(Base, MessageProcessor):
"""Discord server configuration and metadata"""
__tablename__ = "discord_servers"
id = Column(BigInteger, primary_key=True) # Discord guild snowflake ID
name = Column(Text, nullable=False)
description = Column(Text)
member_count = Column(Integer)
# Collection settings
last_sync_at = Column(DateTime(timezone=True))
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
channels = relationship(
"DiscordChannel", back_populates="server", cascade="all, delete-orphan"
)
__table_args__ = (
Index("discord_servers_active_idx", "track_messages", "last_sync_at"),
)
class DiscordChannel(Base, MessageProcessor):
"""Discord channel metadata and configuration"""
__tablename__ = "discord_channels"
id = Column(BigInteger, primary_key=True) # Discord channel snowflake ID
server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
name = Column(Text, nullable=False)
channel_type = Column(Text, nullable=False) # "text", "voice", "dm", "group_dm"
# Collection settings (null = inherit from server)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
server = relationship("DiscordServer", back_populates="channels")
__table_args__ = (Index("discord_channels_server_idx", "server_id"),)
class DiscordUser(Base, MessageProcessor):
"""Discord user metadata and preferences"""
__tablename__ = "discord_users"
id = Column(BigInteger, primary_key=True) # Discord user snowflake ID
username = Column(Text, nullable=False)
display_name = Column(Text)
# Link to system user if registered
system_user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
# Basic DM settings
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
system_user = relationship("User", back_populates="discord_users")
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)

View File

@ -7,6 +7,7 @@ from sqlalchemy import (
String,
DateTime,
ForeignKey,
BigInteger,
JSON,
Text,
)
@ -37,8 +38,10 @@ class ScheduledLLMCall(Base):
allowed_tools = Column(JSON, nullable=True) # List of allowed tool names
# Discord configuration
discord_channel = Column(String, nullable=True)
discord_user = Column(String, nullable=True)
discord_channel_id = Column(
BigInteger, ForeignKey("discord_channels.id"), nullable=True
)
discord_user_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=True)
# Execution status and results
status = Column(
@ -55,6 +58,8 @@ class ScheduledLLMCall(Base):
# Relationships
user = relationship("User")
discord_channel = relationship("DiscordChannel", foreign_keys=[discord_channel_id])
discord_user = relationship("DiscordUser", foreign_keys=[discord_user_id])
def serialize(self) -> Dict[str, Any]:
def print_datetime(dt: datetime | None) -> str | None:
@ -73,8 +78,8 @@ class ScheduledLLMCall(Base):
"message": self.message,
"system_prompt": self.system_prompt,
"allowed_tools": self.allowed_tools,
"discord_channel": self.discord_channel,
"discord_user": self.discord_user,
"discord_channel": self.discord_channel and self.discord_channel.name,
"discord_user": self.discord_user and self.discord_user.username,
"status": self.status,
"response": self.response,
"error_message": self.error_message,

View File

@ -286,7 +286,8 @@ class DiscordMessage(SourceItem):
sent_at = Column(DateTime(timezone=True), nullable=False)
server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
channel_id = Column(BigInteger, ForeignKey("discord_channels.id"), nullable=False)
discord_user_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False)
from_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False)
recipient_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False)
message_id = Column(BigInteger, nullable=False) # Discord message snowflake ID
# Discord-specific metadata
@ -303,11 +304,33 @@ class DiscordMessage(SourceItem):
channel = relationship("DiscordChannel", foreign_keys=[channel_id])
server = relationship("DiscordServer", foreign_keys=[server_id])
discord_user = relationship("DiscordUser", foreign_keys=[discord_user_id])
from_user = relationship("DiscordUser", foreign_keys=[from_id])
recipient_user = relationship("DiscordUser", foreign_keys=[recipient_id])
@property
def allowed_tools(self) -> list[str]:
return (
(self.channel.allowed_tools if self.channel else [])
+ (self.from_user.allowed_tools if self.from_user else [])
+ (self.server.allowed_tools if self.server else [])
)
@property
def disallowed_tools(self) -> list[str]:
return (
(self.channel.disallowed_tools if self.channel else [])
+ (self.from_user.disallowed_tools if self.from_user else [])
+ (self.server.disallowed_tools if self.server else [])
)
def tool_allowed(self, tool: str) -> bool:
return not (self.disallowed_tools and tool in self.disallowed_tools) and (
not self.allowed_tools or tool in self.allowed_tools
)
@property
def title(self) -> str:
return f"{self.discord_user.username}: {self.content}"
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
__mapper_args__ = {
"polymorphic_identity": "discord_message",
@ -320,7 +343,8 @@ class DiscordMessage(SourceItem):
"server_id",
"channel_id",
),
Index("discord_message_user_idx", "discord_user_id"),
Index("discord_message_from_idx", "from_id"),
Index("discord_message_recipient_idx", "recipient_id"),
)
def _chunk_contents(self) -> Sequence[extract.DataChunk]:

View File

@ -125,74 +125,3 @@ class EmailAccount(Base):
Index("email_accounts_active_idx", "active", "last_sync_at"),
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
)
class MessageProcessor:
track_messages = Column(Boolean, nullable=False, server_default="true")
ignore_messages = Column(Boolean, nullable=True, default=False)
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
class DiscordServer(Base, MessageProcessor):
"""Discord server configuration and metadata"""
__tablename__ = "discord_servers"
id = Column(BigInteger, primary_key=True) # Discord guild snowflake ID
name = Column(Text, nullable=False)
description = Column(Text)
member_count = Column(Integer)
# Collection settings
last_sync_at = Column(DateTime(timezone=True))
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
channels = relationship(
"DiscordChannel", back_populates="server", cascade="all, delete-orphan"
)
__table_args__ = (
Index("discord_servers_active_idx", "track_messages", "last_sync_at"),
)
class DiscordChannel(Base, MessageProcessor):
"""Discord channel metadata and configuration"""
__tablename__ = "discord_channels"
id = Column(BigInteger, primary_key=True) # Discord channel snowflake ID
server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
name = Column(Text, nullable=False)
channel_type = Column(Text, nullable=False) # "text", "voice", "dm", "group_dm"
# Collection settings (null = inherit from server)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
server = relationship("DiscordServer", back_populates="channels")
__table_args__ = (Index("discord_channels_server_idx", "server_id"),)
class DiscordUser(Base, MessageProcessor):
"""Discord user metadata and preferences"""
__tablename__ = "discord_users"
id = Column(BigInteger, primary_key=True) # Discord user snowflake ID
username = Column(Text, nullable=False)
display_name = Column(Text)
# Link to system user if registered
system_user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
# Basic DM settings
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
system_user = relationship("User", back_populates="discord_users")
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)

View File

@ -2,7 +2,6 @@ import hashlib
import secrets
from typing import cast
import uuid
from datetime import datetime, timezone
from sqlalchemy.orm import Session
from memory.common.db.models.base import Base
from sqlalchemy import (
@ -14,6 +13,7 @@ from sqlalchemy import (
Boolean,
ARRAY,
Numeric,
CheckConstraint,
)
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
@ -36,12 +36,21 @@ def verify_password(password: str, password_hash: str) -> bool:
class User(Base):
__tablename__ = "users"
__table_args__ = (
CheckConstraint(
"password_hash IS NOT NULL OR api_key IS NOT NULL",
name="user_has_auth_method",
),
)
id = Column(Integer, primary_key=True)
name = Column(String, nullable=False)
email = Column(String, nullable=False, unique=True)
password_hash = Column(String, nullable=False)
discord_user_id = Column(String, nullable=True)
user_type = Column(String, nullable=False) # Discriminator column
# Make these nullable since subclasses will use them selectively
password_hash = Column(String, nullable=True)
api_key = Column(String, nullable=True, unique=True)
# Relationship to sessions
sessions = relationship(
@ -52,22 +61,86 @@ class User(Base):
)
discord_users = relationship("DiscordUser", back_populates="system_user")
__mapper_args__ = {
"polymorphic_on": user_type,
"polymorphic_identity": "user",
}
def serialize(self) -> dict:
return {
"user_id": self.id,
"name": self.name,
"email": self.email,
"discord_user_id": self.discord_user_id,
"user_type": self.user_type,
"discord_users": {
discord_user.id: discord_user.username
for discord_user in self.discord_users
},
}
class HumanUser(User):
"""Human user with password authentication"""
__mapper_args__ = {
"polymorphic_identity": "human",
}
def is_valid_password(self, password: str) -> bool:
"""Check if the provided password is valid for this user"""
return verify_password(password, cast(str, self.password_hash))
@classmethod
def create_with_password(cls, email: str, name: str, password: str) -> "User":
"""Create a new user with a hashed password"""
return cls(email=email, name=name, password_hash=hash_password(password))
def create_with_password(cls, email: str, name: str, password: str) -> "HumanUser":
"""Create a new human user with a hashed password"""
return cls(
email=email,
name=name,
password_hash=hash_password(password),
user_type="human",
)
class BotUser(User):
"""Bot user with API key authentication"""
__mapper_args__ = {
"polymorphic_identity": "bot",
}
@classmethod
def create_with_api_key(
cls, name: str, email: str, api_key: str | None = None
) -> "BotUser":
"""Create a new bot user with an API key"""
if api_key is None:
api_key = f"bot_{secrets.token_hex(32)}"
return cls(
name=name,
email=email,
api_key=api_key,
user_type=cls.__mapper_args__["polymorphic_identity"],
)
class DiscordBotUser(BotUser):
"""Bot user with API key authentication"""
__mapper_args__ = {
"polymorphic_identity": "discord_bot",
}
@classmethod
def create_with_api_key(
cls,
discord_users: list,
name: str,
email: str,
api_key: str | None = None,
) -> "DiscordBotUser":
bot = super().create_with_api_key(name, email, api_key)
bot.discord_users = discord_users
return bot
class UserSession(Base):

View File

@ -25,7 +25,7 @@ def send_dm(user_identifier: str, message: str) -> bool:
try:
response = requests.post(
f"{get_api_url()}/send_dm",
json={"user_identifier": user_identifier, "message": message},
json={"user": user_identifier, "message": message},
timeout=10,
)
response.raise_for_status()
@ -37,6 +37,24 @@ def send_dm(user_identifier: str, message: str) -> bool:
return False
def send_to_channel(channel_name: str, message: str) -> bool:
"""Send a DM via the Discord collector API"""
try:
response = requests.post(
f"{get_api_url()}/send_channel",
json={"channel_name": channel_name, "message": message},
timeout=10,
)
response.raise_for_status()
result = response.json()
print("Result", result)
return result.get("success", False)
except requests.RequestException as e:
logger.error(f"Failed to send to channel {channel_name}: {e}")
return False
def broadcast_message(channel_name: str, message: str) -> bool:
"""Send a message to a channel via the Discord collector API"""
try:

View File

@ -81,3 +81,28 @@ def truncate(content: str, target_tokens: int) -> str:
if len(content) > target_chars:
return content[:target_chars].rsplit(" ", 1)[0] + "..."
return content
# TODO(review): commented-out REPL scratch below — remove before release
# from memory.common.llms import *
# from memory.common.llms.tools.discord import make_discord_tools
# from memory.common.db.connection import make_session
# from memory.common.db.models import *
# model = "anthropic/claude-sonnet-4-5"
# provider = create_provider(model=model)
# with make_session() as session:
# bot = session.query(DiscordBotUser).first()
# server = session.query(DiscordServer).first()
# channel = server.channels[0]
# tools = make_discord_tools(bot, None, channel, model)
# def demo(msg: str):
# messages = [
# Message(
# role=MessageRole.USER,
# content=msg,
# )
# ]
# for m in provider.stream_with_tools(messages, tools):
# print(m)

View File

@ -333,7 +333,6 @@ class AnthropicProvider(BaseLLMProvider):
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
print(kwargs)
try:
with self.client.messages.stream(**kwargs) as stream:
current_tool_use: dict[str, Any] | None = None

View File

@ -599,6 +599,9 @@ class BaseLLMProvider(ABC):
tool_calls=tool_calls or None,
)
def as_messages(self, messages) -> list[Message]:
return [Message.user(text=msg) for msg in messages]
def create_provider(
model: str | None = None,

View File

@ -150,7 +150,7 @@ class OpenAIProvider(BaseLLMProvider):
def _convert_tools(
self, tools: list[ToolDefinition] | None
) -> Optional[list[dict[str, Any]]]:
) -> list[dict[str, Any]] | None:
"""
Convert our tool definitions to OpenAI format.
@ -179,7 +179,7 @@ class OpenAIProvider(BaseLLMProvider):
self,
messages: list[Message],
system_prompt: str | None,
tools: Optional[list[ToolDefinition]],
tools: list[ToolDefinition] | None,
settings: LLMSettings,
stream: bool = False,
) -> dict[str, Any]:
@ -270,7 +270,7 @@ class OpenAIProvider(BaseLLMProvider):
self,
chunk: Any,
current_tool_call: dict[str, Any] | None,
) -> tuple[list[StreamEvent], Optional[dict[str, Any]]]:
) -> tuple[list[StreamEvent], dict[str, Any] | None]:
"""
Handle a single streaming chunk and return events and updated tool state.
@ -325,9 +325,9 @@ class OpenAIProvider(BaseLLMProvider):
def generate(
self,
messages: list[Message],
system_prompt: Optional[str] = None,
tools: Optional[list[ToolDefinition]] = None,
settings: Optional[LLMSettings] = None,
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""Generate a non-streaming response."""
settings = settings or LLMSettings()
@ -374,9 +374,9 @@ class OpenAIProvider(BaseLLMProvider):
async def agenerate(
self,
messages: list[Message],
system_prompt: Optional[str] = None,
tools: Optional[list[ToolDefinition]] = None,
settings: Optional[LLMSettings] = None,
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""Generate a non-streaming response asynchronously."""
settings = settings or LLMSettings()
@ -394,9 +394,9 @@ class OpenAIProvider(BaseLLMProvider):
async def astream(
self,
messages: list[Message],
system_prompt: Optional[str] = None,
tools: Optional[list[ToolDefinition]] = None,
settings: Optional[LLMSettings] = None,
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
settings: LLMSettings | None = None,
) -> AsyncIterator[StreamEvent]:
"""Generate a streaming response asynchronously."""
settings = settings or LLMSettings()
@ -406,7 +406,7 @@ class OpenAIProvider(BaseLLMProvider):
try:
stream = await self.async_client.chat.completions.create(**kwargs)
current_tool_call: Optional[dict[str, Any]] = None
current_tool_call: dict[str, Any] | None = None
async for chunk in stream:
events, current_tool_call = self._handle_stream_chunk(

View File

@ -0,0 +1,231 @@
"""Discord tool for interacting with Discord."""
import textwrap
from datetime import datetime
from typing import Literal, cast
from memory.discord.messages import (
upsert_scheduled_message,
comm_channel_prompt,
previous_messages,
)
from sqlalchemy import BigInteger
from memory.common.db.connection import make_session
from memory.common.db.models import (
DiscordServer,
DiscordChannel,
DiscordUser,
BotUser,
)
from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler
UpdateSummaryType = Literal["server", "channel", "user"]
def handle_update_summary_call(
    type: UpdateSummaryType, item_id: BigInteger
) -> ToolHandler:
    """Build a tool handler that overwrites the summary of a Discord entity.

    Args:
        type: Which kind of entity the summary belongs to ("server",
            "channel", or "user").
        item_id: Primary key of the row to update.

    Returns:
        A handler that accepts the tool input and persists the new summary,
        returning a human-readable status string.
    """
    models = {
        "server": DiscordServer,
        "channel": DiscordChannel,
        "user": DiscordUser,
    }

    def handler(input: ToolInput = None) -> str:
        # Accept either {"summary": "..."} or a bare value; fall back to str().
        if isinstance(input, dict):
            summary = input.get("summary") or str(input)
        else:
            summary = str(input)
        try:
            with make_session() as session:
                record = session.get(models[type], item_id)
                if record is None:
                    # Previously a missing row crashed with AttributeError on
                    # None; report it as a tool error string instead.
                    return f"Error updating summary: {type} {item_id} not found"
                record.summary = summary  # type: ignore
                session.commit()
        except Exception as e:
            return f"Error updating summary: {e}"
        return "Updated summary"

    handler.__doc__ = textwrap.dedent("""
        Handle a {type} summary update tool call.
        Args:
            summary: The new summary of the Discord {type}
        Returns:
            Response string
        """).format(type=type)
    return handler
def make_summary_tool(type: UpdateSummaryType, item_id: BigInteger) -> ToolDefinition:
    """Create the tool definition for overwriting a Discord entity's summary."""
    description = textwrap.dedent("""
        Use this to update the summary of this Discord {type} that is added to your context.
        This will overwrite the previous summary.
        """).format(type=type)
    schema = {
        "type": "object",
        "properties": {
            "summary": {
                "type": "string",
                "description": f"The new summary of the Discord {type}",
            }
        },
        "required": [],
    }
    return ToolDefinition(
        name=f"update_{type}_summary",
        description=description,
        input_schema=schema,
        function=handle_update_summary_call(type, item_id),
    )
def schedule_message(
    user_id: int,
    user: int | None,
    channel: int | None,
    model: str,
    message: str,
    date_time: datetime,
) -> str:
    """Persist a scheduled self-message and return the new call's id."""
    with make_session() as session:
        prompt = comm_channel_prompt(session, user, channel)
        scheduled = upsert_scheduled_message(
            session,
            scheduled_time=date_time,
            message=message,
            user_id=user_id,
            model=model,
            discord_user=user,
            discord_channel=channel,
            system_prompt=prompt,
        )
        session.commit()
        return cast(str, scheduled.id)
def make_message_scheduler(
    bot: BotUser, user: int | None, channel: int | None, model: str
) -> ToolDefinition:
    """Create the schedule_message tool bound to a bot and a user/channel context.

    Args:
        bot: The bot the scheduled call will run as.
        user: Discord user id for DM contexts, if any.
        channel: Discord channel id for server contexts, if any.
        model: Model identifier to use when the scheduled call fires.

    Raises:
        ValueError: If neither ``user`` nor ``channel`` is provided.
    """
    bot_id = cast(int, bot.id)
    if user:
        channel_type = "from your chat with this user"
    elif channel:
        channel_type = "in this channel"
    else:
        raise ValueError("Either user or channel must be provided")

    def handler(input: ToolInput) -> str:
        if not isinstance(input, dict):
            raise ValueError("Input must be a dictionary")
        # Previously a missing "message" key surfaced as a bare KeyError from
        # the input["message"] lookup below; validate it explicitly.
        if "message" not in input:
            raise ValueError("Message is required")
        try:
            time = datetime.fromisoformat(input["date_time"])
        except ValueError:
            raise ValueError("Invalid date time format")
        except KeyError:
            raise ValueError("Date time is required")
        return schedule_message(bot_id, user, channel, model, input["message"], time)

    return ToolDefinition(
        name="schedule_message",
        description=textwrap.dedent("""
            Use this to schedule a message to be sent to yourself.
            At the specified date and time, your message will be sent to you, along with the most
            recent messages {channel_type}.
            Normally you will be called with any incoming messages. But sometimes you might want to be
            able to trigger a call to yourself at a specific time, rather than waiting for the next call.
            This tool allows you to do that.
            So for example, if you were chatting with a Discord user, and you ask a question which needs to
            be answered right away, you can use this tool to schedule a check in 5 minutes time, to remind
            the user to answer the question.
            """).format(channel_type=channel_type),
        input_schema={
            "type": "object",
            "properties": {
                "message": {
                    "type": "string",
                    "description": "The message to send",
                },
                "date_time": {
                    "type": "string",
                    "description": "The date and time to send the message in ISO format (e.g., 2025-01-01T00:00:00Z)",
                },
            },
            # Both fields are mandatory; the schema previously omitted this,
            # letting callers send incomplete inputs.
            "required": ["message", "date_time"],
        },
        function=handler,
    )
def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefinition:
    """Create a tool that fetches the previous N messages for this context."""
    if user:
        channel_type = "from your chat with this user"
    elif channel:
        channel_type = "in this channel"
    else:
        raise ValueError("Either user or channel must be provided")

    def handler(input: ToolInput) -> str:
        if not isinstance(input, dict):
            raise ValueError("Input must be a dictionary")
        try:
            max_messages = int(input.get("max_messages") or 10)
            offset = int(input.get("offset") or 0)
        except ValueError:
            raise ValueError("Max messages and offset must be integers")
        if max_messages <= 0:
            raise ValueError("Max messages must be greater than 0")
        if offset < 0:
            raise ValueError("Offset must be greater than or equal to 0")
        with make_session() as session:
            fetched = previous_messages(session, user, channel, max_messages, offset)
            return "\n\n".join(msg.title for msg in fetched)

    schema = {
        "type": "object",
        "properties": {
            "max_messages": {
                "type": "number",
                "description": "The maximum number of messages to return",
                "default": 10,
            },
            "offset": {
                "type": "number",
                "description": "The number of messages to offset the result by",
                "default": 0,
            },
        },
    }
    return ToolDefinition(
        name="previous_messages",
        description=f"Get the previous N messages {channel_type}.",
        input_schema=schema,
        function=handler,
    )
def make_discord_tools(
    bot: BotUser,
    author: DiscordUser | None,
    channel: DiscordChannel | None,
    model: str,
) -> dict[str, ToolDefinition]:
    """Assemble the Discord toolset for this bot/author/channel context.

    Only summary tools whose target entity is actually present are included.

    Returns:
        Mapping of tool name -> ToolDefinition.
    """
    author_id = author and author.id
    channel_id = channel and channel.id
    tools = [
        make_message_scheduler(bot, author_id, channel_id, model),
        make_prev_messages_tool(author_id, channel_id),
    ]
    if channel:
        # Previously added unconditionally, which produced a broken
        # update_channel_summary tool (item_id=None) in DM contexts.
        tools.append(make_summary_tool("channel", channel_id))
    if author:
        tools.append(make_summary_tool("user", author_id))
    if channel and channel.server:
        tools.append(make_summary_tool("server", cast(BigInteger, channel.server_id)))
    return {tool.name: tool for tool in tools}

View File

@ -133,7 +133,7 @@ if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):
ANTHROPIC_API_KEY = pathlib.Path(anthropic_key_file).read_text().strip()
else:
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-haiku-4-5")
RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
@ -173,11 +173,14 @@ DISCORD_NOTIFICATIONS_ENABLED = bool(
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
)
DISCORD_PROCESS_MESSAGES = boolean_env("DISCORD_PROCESS_MESSAGES", True)
DISCORD_MODEL = os.getenv("DISCORD_MODEL", "anthropic/claude-sonnet-4-5")
DISCORD_MAX_TOOL_CALLS = int(os.getenv("DISCORD_MAX_TOOL_CALLS", 10))
# Discord collector settings
DISCORD_COLLECTOR_ENABLED = boolean_env("DISCORD_COLLECTOR_ENABLED", True)
DISCORD_COLLECT_DMS = boolean_env("DISCORD_COLLECT_DMS", True)
DISCORD_COLLECT_BOTS = boolean_env("DISCORD_COLLECT_BOTS", True)
DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8000))
DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "127.0.0.1")
DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8003))
DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "0.0.0.0")
DISCORD_CONTEXT_WINDOW = int(os.getenv("DISCORD_CONTEXT_WINDOW", 10))

View File

@ -13,7 +13,7 @@ from sqlalchemy.orm import Session, scoped_session
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models.sources import (
from memory.common.db.models import (
DiscordServer,
DiscordChannel,
DiscordUser,
@ -227,6 +227,7 @@ class MessageCollector(commands.Bot):
message_id=message.id,
channel_id=message.channel.id,
author_id=message.author.id,
recipient_id=self.user and self.user.id,
server_id=message.guild.id if message.guild else None,
content=message.content or "",
sent_at=message.created_at.isoformat(),

View File

@ -0,0 +1,205 @@
import logging
import textwrap
from datetime import datetime, timezone
from typing import Any, cast
from sqlalchemy.orm import Session, scoped_session
from memory.common.db.models import (
DiscordChannel,
DiscordUser,
ScheduledLLMCall,
DiscordMessage,
)
logger = logging.getLogger(__name__)
DiscordEntity = DiscordChannel | DiscordUser | str | int | None
def resolve_discord_user(
    session: Session | scoped_session, entity: DiscordEntity
) -> DiscordUser | None:
    """Resolve a user reference (model, id, or username) to a DiscordUser.

    Falsy input resolves to None; an unknown username gets a new row added
    to the session (not yet flushed).
    """
    if not entity:
        return None
    if isinstance(entity, DiscordUser):
        return entity
    if isinstance(entity, int):
        return session.get(DiscordUser, entity)
    # Anything else is treated as a username. Keep the original value: the
    # previous code overwrote `entity` with the (possibly None) query result
    # and then built DiscordUser(id=None, username=None) for unknown names.
    username = str(entity)
    user = session.query(DiscordUser).filter(DiscordUser.username == username).first()
    if not user:
        # NOTE(review): id is left unset — Discord user ids are snowflakes,
        # not usernames; confirm how new rows should obtain their id.
        user = DiscordUser(username=username)
        session.add(user)
    return user
def resolve_discord_channel(
    session: Session | scoped_session, entity: DiscordEntity
) -> DiscordChannel | None:
    """Resolve a channel reference (model, id, or name) to a DiscordChannel."""
    if not entity:
        return None
    if isinstance(entity, DiscordChannel):
        return entity
    if isinstance(entity, int):
        return session.get(DiscordChannel, entity)
    by_name = session.query(DiscordChannel).filter(DiscordChannel.name == entity)
    return by_name.first()
def schedule_discord_message(
    session: Session | scoped_session,
    scheduled_time: datetime,
    message: str,
    user_id: int,
    model: str | None = None,
    topic: str | None = None,
    discord_user: DiscordEntity = None,
    discord_channel: DiscordEntity = None,
    system_prompt: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> ScheduledLLMCall:
    """Create a ScheduledLLMCall targeting a Discord user or channel.

    Args:
        session: Active database session; the new call is added but not committed.
        scheduled_time: When the call should fire; must be in the future.
        message: The message content for the scheduled call.
        user_id: Owning (bot) user id.
        model / topic / system_prompt / metadata: Optional call configuration.
        discord_user / discord_channel: Target entity references.

    Raises:
        ValueError: If neither target resolves, or the time is not in the future.
    """
    user = resolve_discord_user(session, discord_user)
    channel = resolve_discord_channel(session, discord_channel)
    if not user and not channel:
        raise ValueError("Either discord_user or discord_channel must be provided")

    # Validate that the scheduled time is in the future.
    # Compare with naive datetime since we store naive in the database.
    current_time_naive = datetime.now(timezone.utc).replace(tzinfo=None)
    if scheduled_time.replace(tzinfo=None) <= current_time_naive:
        raise ValueError("Scheduled time must be in the future")

    # Create the scheduled call, reusing the entities resolved above instead
    # of resolving (and re-querying / potentially re-creating) a second time.
    scheduled_call = ScheduledLLMCall(
        user_id=user_id,
        scheduled_time=scheduled_time,
        message=message,
        topic=topic,
        model=model,
        system_prompt=system_prompt,
        discord_channel=channel,
        discord_user=user,
        data=metadata or {},
    )
    session.add(scheduled_call)
    return scheduled_call
def upsert_scheduled_message(
    session: Session | scoped_session,
    scheduled_time: datetime,
    message: str,
    user_id: int,
    model: str | None = None,
    topic: str | None = None,
    discord_user: DiscordEntity = None,
    discord_channel: DiscordEntity = None,
    system_prompt: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> ScheduledLLMCall:
    """Schedule a message, cancelling a later-scheduled duplicate if present.

    If an existing call for the same (user_id, model, target) is scheduled
    LATER than the new time, it is marked cancelled before the new call is
    created. A new ScheduledLLMCall is always created and returned.
    """
    user = resolve_discord_user(session, discord_user)
    channel = resolve_discord_channel(session, discord_channel)
    prev_call = (
        session.query(ScheduledLLMCall)
        .filter(
            ScheduledLLMCall.user_id == user_id,
            ScheduledLLMCall.model == model,
            ScheduledLLMCall.discord_user_id == (user and user.id),
            ScheduledLLMCall.discord_channel_id == (channel and channel.id),
        )
        .first()
    )
    # Stored times are naive; strip tzinfo before comparing.
    # (Debug print() statements that leaked into this path were removed.)
    naive_scheduled_time = scheduled_time.replace(tzinfo=None)
    if prev_call and cast(datetime, prev_call.scheduled_time) > naive_scheduled_time:
        prev_call.status = "cancelled"  # type: ignore
    return schedule_discord_message(
        session,
        scheduled_time,
        message,
        user_id=user_id,
        model=model,
        topic=topic,
        discord_user=user,
        discord_channel=channel,
        system_prompt=system_prompt,
        metadata=metadata,
    )
def previous_messages(
    session: Session | scoped_session,
    user_id: int | None,
    channel_id: int | None,
    max_messages: int = 10,
    offset: int = 0,
) -> list[DiscordMessage]:
    """Return up to `max_messages` most recent messages, oldest first."""
    query = session.query(DiscordMessage)
    if user_id:
        query = query.filter(DiscordMessage.recipient_id == user_id)
    if channel_id:
        query = query.filter(DiscordMessage.channel_id == channel_id)
    newest_first = (
        query.order_by(DiscordMessage.sent_at.desc())
        .offset(offset)
        .limit(max_messages)
        .all()
    )
    # Query fetches newest-first; flip to chronological order for callers.
    return newest_first[::-1]
def comm_channel_prompt(
    session: Session | scoped_session,
    user: DiscordEntity,
    channel: DiscordEntity,
    max_messages: int = 10,
) -> str:
    """Build the system prompt for a Discord conversation context.

    Stitches together prior notes on the server, channel, and message authors,
    then wraps them in the standing bot instructions.
    """
    user = resolve_discord_user(session, user)
    channel = resolve_discord_channel(session, channel)
    messages = previous_messages(
        session, user and user.id, channel and channel.id, max_messages
    )
    server_context = ""
    if channel and channel.server:
        server_context = textwrap.dedent("""
            Here are your previous notes on the server:
            <server_context>
            {summary}
            </server_context>
            """).format(summary=channel.server.summary)
    if channel:
        server_context += textwrap.dedent("""
            Here are your previous notes on the channel:
            <channel_context>
            {summary}
            </channel_context>
            """).format(summary=channel.summary)
    if messages:
        # Deduplicate per-author notes, skip messages with no author row
        # (previously an AttributeError), and sort so the prompt text is
        # deterministic — a raw set iterates in arbitrary order.
        user_notes = sorted({msg.from_user.as_xml() for msg in messages if msg.from_user})
        server_context += textwrap.dedent("""
            Here are your previous notes on the users:
            <user_notes>
            {users}
            </user_notes>
            """).format(users="\n".join(user_notes))
    return textwrap.dedent("""
        You are a bot communicating on Discord.
        {server_context}
        Whenever something worth remembering is said, you should add a note to the appropriate context - use
        this to track your understanding of the conversation and those taking part in it.
        You will be given the last {max_messages} messages in the conversation.
        Please react to them appropriately. You can return an empty response if you don't have anything to say.
        """).format(server_context=server_context, max_messages=max_messages)

View File

@ -4,25 +4,31 @@ Celery tasks for Discord message processing.
import hashlib
import logging
import textwrap
from datetime import datetime
from typing import Any
from memory.common.celery_app import app
from memory.common.db.connection import make_session
from memory.common.db.models import DiscordMessage, DiscordUser
from memory.workers.tasks.content_processing import (
safe_task_execution,
check_content_exists,
create_task_result,
process_content_item,
)
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy.orm import Session, scoped_session
from memory.common import discord, settings
from memory.common.celery_app import (
ADD_DISCORD_MESSAGE,
EDIT_DISCORD_MESSAGE,
PROCESS_DISCORD_MESSAGE,
app,
)
from memory.common.db.connection import make_session
from memory.common.db.models import DiscordMessage, DiscordUser
from memory.common.llms.base import create_provider
from memory.common.llms.tools.discord import make_discord_tools
from memory.discord.messages import comm_channel_prompt, previous_messages
from memory.workers.tasks.content_processing import (
check_content_exists,
create_task_result,
process_content_item,
safe_task_execution,
)
from memory.common import settings
from sqlalchemy.orm import Session, scoped_session
logger = logging.getLogger(__name__)
@ -32,7 +38,7 @@ def get_prev(
) -> list[str]:
prev = (
session.query(DiscordUser.username, DiscordMessage.content)
.join(DiscordUser, DiscordMessage.discord_user_id == DiscordUser.id)
.join(DiscordUser, DiscordMessage.from_id == DiscordUser.id)
.filter(
DiscordMessage.channel_id == channel_id,
DiscordMessage.sent_at < sent_at,
@ -45,20 +51,54 @@ def get_prev(
def should_process(message: DiscordMessage) -> bool:
return (
if not (
settings.DISCORD_PROCESS_MESSAGES
and settings.DISCORD_NOTIFICATIONS_ENABLED
and not (
(message.server and message.server.ignore_messages)
or (message.channel and message.channel.ignore_messages)
or (message.discord_user and message.discord_user.ignore_messages)
or (message.from_user and message.from_user.ignore_messages)
)
)
):
return False
provider = create_provider(model=settings.SUMMARIZER_MODEL)
with make_session() as session:
system_prompt = comm_channel_prompt(
session, message.recipient_user, message.channel
)
messages = previous_messages(
session,
message.recipient_user and message.recipient_user.id,
message.channel and message.channel.id,
max_messages=10,
)
msg = textwrap.dedent("""
Should you continue the conversation with the user?
Please return "yes" or "no" as:
<response>yes</response>
or
<response>no</response>
""")
response = provider.generate(
messages=provider.as_messages([m.title for m in messages] + [msg]),
system_prompt=system_prompt,
)
return "<response>yes</response>" in "".join(response.lower().split())
@app.task(name=PROCESS_DISCORD_MESSAGE)
@safe_task_execution
def process_discord_message(message_id: int) -> dict[str, Any]:
"""
Process a Discord message.
This task is queued by the Discord collector when messages are received.
"""
logger.info(f"Processing Discord message {message_id}")
with make_session() as session:
@ -71,7 +111,39 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
"message_id": message_id,
}
print("Processing message", discord_message.id, discord_message.content)
tools = make_discord_tools(
discord_message.recipient_user,
discord_message.from_user,
discord_message.channel,
model=settings.DISCORD_MODEL,
)
tools = {
name: tool
for name, tool in tools.items()
if discord_message.tool_allowed(name)
}
system_prompt = comm_channel_prompt(
session, discord_message.recipient_user, discord_message.channel
)
messages = previous_messages(
session,
discord_message.recipient_user and discord_message.recipient_user.id,
discord_message.channel and discord_message.channel.id,
max_messages=10,
)
provider = create_provider(model=settings.DISCORD_MODEL)
turn = provider.run_with_tools(
messages=provider.as_messages([m.title for m in messages]),
tools=tools,
system_prompt=system_prompt,
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
)
if not turn.response:
pass
elif discord_message.channel.server:
discord.send_to_channel(discord_message.channel.name, turn.response)
else:
discord.send_dm(discord_message.from_user.username, turn.response)
return {
"status": "processed",
@ -88,6 +160,7 @@ def add_discord_message(
content: str,
sent_at: str,
server_id: int | None = None,
recipient_id: int | None = None,
message_reference_id: int | None = None,
) -> dict[str, Any]:
"""
@ -108,7 +181,8 @@ def add_discord_message(
channel_id=channel_id,
sent_at=sent_at_dt,
server_id=server_id,
discord_user_id=author_id,
from_id=author_id,
recipient_id=recipient_id,
message_id=message_id,
message_type="reply" if message_reference_id else "default",
reply_to_message_id=message_reference_id,
@ -125,7 +199,15 @@ def add_discord_message(
if channel_id:
discord_message.messages_before = get_prev(session, channel_id, sent_at_dt)
result = process_content_item(discord_message, session)
try:
result = process_content_item(discord_message, session)
except sqlalchemy_exc.IntegrityError as e:
logger.error(f"Integrity error adding Discord message {message_id}: {e}")
return {
"status": "error",
"error": "Integrity error",
"message_id": message_id,
}
if should_process(discord_message):
process_discord_message.delay(discord_message.id)

View File

@ -37,12 +37,12 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
if len(message) > 1900: # Leave some buffer
message = message[:1900] + "\n\n... (response truncated)"
if discord_user := cast(str, scheduled_call.discord_user):
logger.info(f"Sending DM to {discord_user}: {message}")
discord.send_dm(discord_user, message)
elif discord_channel := cast(str, scheduled_call.discord_channel):
logger.info(f"Broadcasting message to {discord_channel}: {message}")
discord.broadcast_message(discord_channel, message)
if discord_user := scheduled_call.discord_user:
logger.info(f"Sending DM to {discord_user.username}: {message}")
discord.send_dm(discord_user.username, message)
elif discord_channel := scheduled_call.discord_channel:
logger.info(f"Broadcasting message to {discord_channel.name}: {message}")
discord.broadcast_message(discord_channel.name, message)
else:
logger.warning(
f"No Discord user or channel found for scheduled call {scheduled_call.id}"
@ -62,11 +62,7 @@ def execute_scheduled_call(self, scheduled_call_id: str):
with make_session() as session:
# Fetch the scheduled call
scheduled_call = (
session.query(ScheduledLLMCall)
.filter(ScheduledLLMCall.id == scheduled_call_id)
.first()
)
scheduled_call = session.query(ScheduledLLMCall).get(scheduled_call_id)
if not scheduled_call:
logger.error(f"Scheduled call {scheduled_call_id} not found")

View File

@ -254,17 +254,59 @@ def mock_openai_client():
with patch.object(openai, "OpenAI", autospec=True) as mock_client:
client = mock_client()
client.chat = Mock()
# Mock non-streaming response
client.chat.completions.create = Mock(
return_value=Mock(
choices=[
Mock(
message=Mock(
content="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>"
)
),
finish_reason=None,
)
]
)
)
# Store original side_effect for potential override
def streaming_response(*args, **kwargs):
if kwargs.get("stream"):
# Return mock streaming chunks
return iter(
[
Mock(
choices=[
Mock(
delta=Mock(content="test", tool_calls=None),
finish_reason=None,
)
]
),
Mock(
choices=[
Mock(
delta=Mock(content=" response", tool_calls=None),
finish_reason="stop",
)
]
),
]
)
else:
# Return non-streaming response
return Mock(
choices=[
Mock(
message=Mock(
content="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>"
),
finish_reason=None,
)
]
)
client.chat.completions.create.side_effect = streaming_response
yield client

View File

View File

@ -0,0 +1,552 @@
"""Comprehensive tests for Anthropic stream event parsing."""
import pytest
from unittest.mock import Mock
from memory.common.llms.anthropic_provider import AnthropicProvider
from memory.common.llms.base import StreamEvent
@pytest.fixture
def provider():
    """Anthropic provider wired with a dummy key; no network calls are made."""
    return AnthropicProvider(api_key="test-key", model="claude-3-opus-20240229")


# Content Block Start Tests


@pytest.mark.parametrize(
    "block_type,block_attrs,expected_tool_use",
    [
        (
            "tool_use",
            {"id": "tool-1", "name": "search", "input": {}},
            {
                "id": "tool-1",
                "name": "search",
                "input": {},
                "server_name": None,
                "is_server_call": False,
            },
        ),
        (
            "mcp_tool_use",
            {
                "id": "mcp-1",
                "name": "mcp_search",
                "input": {},
                "server_name": "mcp-server",
            },
            {
                "id": "mcp-1",
                "name": "mcp_search",
                "input": {},
                "server_name": "mcp-server",
                "is_server_call": True,
            },
        ),
        (
            "server_tool_use",
            {
                "id": "srv-1",
                "name": "server_action",
                "input": {},
                "server_name": "custom-server",
            },
            {
                "id": "srv-1",
                "name": "server_action",
                "input": {},
                "server_name": "custom-server",
                "is_server_call": True,
            },
        ),
    ],
)
def test_content_block_start_tool_types(
    provider, block_type, block_attrs, expected_tool_use
):
    """Different tool types should be tracked correctly."""
    # Mock(spec=[...]) restricts attribute access to exactly these names so the
    # parser cannot silently rely on attributes the real event lacks.
    block = Mock(spec=["type"] + list(block_attrs.keys()))
    block.type = block_type
    for key, value in block_attrs.items():
        setattr(block, key, value)
    event = Mock(spec=["type", "content_block"])
    event.type = "content_block_start"
    event.content_block = block
    stream_event, tool_use = provider._handle_stream_event(event, None)
    # Starting a tool block emits no stream event; it only seeds tool state.
    assert stream_event is None
    assert tool_use == expected_tool_use


def test_content_block_start_tool_without_input(provider):
    """Tool use without input field should initialize as empty string."""
    block = Mock(spec=["type", "id", "name"])
    block.type = "tool_use"
    block.id = "tool-2"
    block.name = "calculate"
    event = Mock(spec=["type", "content_block"])
    event.type = "content_block_start"
    event.content_block = block
    stream_event, tool_use = provider._handle_stream_event(event, None)
    # Input accumulates as a JSON string from later deltas, hence "".
    assert tool_use["input"] == ""


def test_content_block_start_tool_result(provider):
    """Tool result blocks should emit tool_result event."""
    block = Mock(spec=["tool_use_id", "content"])
    block.tool_use_id = "tool-1"
    block.content = "Result content"
    event = Mock(spec=["type", "content_block"])
    event.type = "content_block_start"
    event.content_block = block
    stream_event, tool_use = provider._handle_stream_event(event, None)
    assert stream_event is not None
    assert stream_event.type == "tool_result"
    assert stream_event.data == {"id": "tool-1", "result": "Result content"}


@pytest.mark.parametrize(
    "has_content_block,block_type",
    [
        (False, None),
        (True, "unknown_type"),
    ],
)
def test_content_block_start_ignored_cases(provider, has_content_block, block_type):
    """Events without content_block or with unknown types should be ignored."""
    event = Mock(spec=["type", "content_block"] if has_content_block else ["type"])
    event.type = "content_block_start"
    if has_content_block:
        block = Mock(spec=["type"])
        block.type = block_type
        event.content_block = block
    stream_event, tool_use = provider._handle_stream_event(event, None)
    assert stream_event is None
    assert tool_use is None
# Content Block Delta Tests


@pytest.mark.parametrize(
    "delta_type,delta_attr,attr_value,expected_type,expected_data",
    [
        ("text_delta", "text", "Hello world", "text", "Hello world"),
        ("text_delta", "text", "", "text", ""),
        (
            "thinking_delta",
            "thinking",
            "Let me think...",
            "thinking",
            "Let me think...",
        ),
        ("signature_delta", "signature", "sig-12345", "thinking", None),
    ],
)
def test_content_block_delta_types(
    provider, delta_type, delta_attr, attr_value, expected_type, expected_data
):
    """Different delta types should emit appropriate events."""
    delta = Mock(spec=["type", delta_attr])
    delta.type = delta_type
    setattr(delta, delta_attr, attr_value)
    event = Mock(spec=["type", "delta"])
    event.type = "content_block_delta"
    event.delta = delta
    stream_event, tool_use = provider._handle_stream_event(event, None)
    assert stream_event.type == expected_type
    if expected_type == "thinking" and delta_type == "signature_delta":
        # Signatures ride on the thinking event's `signature` field, not `data`.
        assert stream_event.signature == attr_value
    else:
        assert stream_event.data == expected_data


@pytest.mark.parametrize(
    "current_tool,partial_json,expected_input",
    [
        (
            {"id": "t1", "name": "search", "input": '{"query": "'},
            'test"}',
            '{"query": "test"}',
        ),
        (
            {"id": "t1", "name": "search", "input": '{"'},
            'key": "value"}',
            '{"key": "value"}',
        ),
        (
            {"id": "t1", "name": "search", "input": ""},
            '{"query": "test"}',
            '{"query": "test"}',
        ),
    ],
)
def test_content_block_delta_input_json_accumulation(
    provider, current_tool, partial_json, expected_input
):
    """JSON delta should accumulate to tool input."""
    delta = Mock(spec=["type", "partial_json"])
    delta.type = "input_json_delta"
    delta.partial_json = partial_json
    event = Mock(spec=["type", "delta"])
    event.type = "content_block_delta"
    event.delta = delta
    stream_event, tool_use = provider._handle_stream_event(event, current_tool)
    # Accumulation is internal: no event is emitted until the block stops.
    assert stream_event is None
    assert tool_use["input"] == expected_input


def test_content_block_delta_input_json_without_tool(provider):
    """JSON delta without tool context should return None."""
    delta = Mock(spec=["type", "partial_json"])
    delta.type = "input_json_delta"
    delta.partial_json = '{"key": "value"}'
    event = Mock(spec=["type", "delta"])
    event.type = "content_block_delta"
    event.delta = delta
    stream_event, tool_use = provider._handle_stream_event(event, None)
    assert stream_event is None
    assert tool_use is None


def test_content_block_delta_input_json_with_dict_input(provider):
    """JSON delta shouldn't modify if input is already a dict."""
    current_tool = {"id": "t1", "name": "search", "input": {"query": "test"}}
    delta = Mock(spec=["type", "partial_json"])
    delta.type = "input_json_delta"
    delta.partial_json = ', "extra": "data"'
    event = Mock(spec=["type", "delta"])
    event.type = "content_block_delta"
    event.delta = delta
    stream_event, tool_use = provider._handle_stream_event(event, current_tool)
    assert tool_use["input"] == {"query": "test"}


@pytest.mark.parametrize(
    "has_delta,delta_type",
    [
        (False, None),
        (True, "unknown_delta"),
    ],
)
def test_content_block_delta_ignored_cases(provider, has_delta, delta_type):
    """Events without delta or with unknown types should be ignored."""
    event = Mock(spec=["type", "delta"] if has_delta else ["type"])
    event.type = "content_block_delta"
    if has_delta:
        delta = Mock(spec=["type"])
        delta.type = delta_type
        event.delta = delta
    stream_event, tool_use = provider._handle_stream_event(event, None)
    assert stream_event is None
# Content Block Stop Tests


@pytest.mark.parametrize(
    "input_value,has_content_block,expected_input",
    [
        ("", False, {}),
        (" \n\t ", False, {}),
        ('{"invalid": json}', False, {}),
        ('{"query": "test", "limit": 10}', False, {"query": "test", "limit": 10}),
        (
            '{"filters": {"type": "user", "status": ["active", "pending"]}, "limit": 100}',
            False,
            {
                "filters": {"type": "user", "status": ["active", "pending"]},
                "limit": 100,
            },
        ),
        ("", True, {"query": "test"}),
    ],
)
def test_content_block_stop_tool_finalization(
    provider, input_value, has_content_block, expected_input
):
    """Tool stop should parse or use provided input correctly."""
    # Empty/whitespace/invalid accumulated JSON degrades to {}; a complete
    # content_block on the stop event takes precedence over accumulation.
    current_tool = {"id": "t1", "name": "search", "input": input_value}
    event = Mock(spec=["type", "content_block"] if has_content_block else ["type"])
    event.type = "content_block_stop"
    if has_content_block:
        block = Mock(spec=["input"])
        block.input = {"query": "test"}
        event.content_block = block
    stream_event, tool_use = provider._handle_stream_event(event, current_tool)
    assert stream_event.type == "tool_use"
    assert stream_event.data["input"] == expected_input
    # Tool state is cleared once the final tool_use event is emitted.
    assert tool_use is None


def test_content_block_stop_with_server_info(provider):
    """Server tool info should be included in final event."""
    current_tool = {
        "id": "t1",
        "name": "mcp_search",
        "input": '{"q": "test"}',
        "server_name": "mcp-server",
        "is_server_call": True,
    }
    event = Mock(spec=["type"])
    event.type = "content_block_stop"
    stream_event, tool_use = provider._handle_stream_event(event, current_tool)
    assert stream_event.data["server_name"] == "mcp-server"
    assert stream_event.data["is_server_call"] is True


def test_content_block_stop_without_tool(provider):
    """Stop without current tool should return None."""
    event = Mock(spec=["type"])
    event.type = "content_block_stop"
    stream_event, tool_use = provider._handle_stream_event(event, None)
    assert stream_event is None
    assert tool_use is None
# Message Delta Tests
def test_message_delta_max_tokens(provider):
    """Max tokens stop reason should emit error."""
    event = Mock(spec=["type", "delta"])
    event.type = "message_delta"
    event.delta = Mock(spec=["stop_reason"])
    event.delta.stop_reason = "max_tokens"
    emitted, _ = provider._handle_stream_event(event, None)
    assert emitted.type == "error"
    assert "Max tokens" in emitted.data
@pytest.mark.parametrize("stop_reason", ["end_turn", "stop_sequence", None])
def test_message_delta_other_stop_reasons(provider, stop_reason):
    """Other stop reasons should not emit error."""
    event = Mock(spec=["type", "delta"])
    event.type = "message_delta"
    event.delta = Mock(spec=["stop_reason"])
    event.delta.stop_reason = stop_reason
    emitted, _ = provider._handle_stream_event(event, None)
    assert emitted is None
def test_message_delta_token_usage(provider):
    """Token usage should be logged but not emitted."""
    usage_fields = {
        "input_tokens": 100,
        "output_tokens": 50,
        "cache_creation_input_tokens": 10,
        "cache_read_input_tokens": 20,
    }
    usage = Mock(spec=list(usage_fields))
    for field, value in usage_fields.items():
        setattr(usage, field, value)
    event = Mock(spec=["type", "usage"])
    event.type = "message_delta"
    event.usage = usage
    emitted, _ = provider._handle_stream_event(event, None)
    assert emitted is None
def test_message_delta_empty(provider):
    """Message delta without delta or usage should return None."""
    bare_event = Mock(spec=["type"])
    bare_event.type = "message_delta"
    emitted, _ = provider._handle_stream_event(bare_event, None)
    assert emitted is None
# Message Stop Tests
@pytest.mark.parametrize(
    "current_tool",
    [
        None,
        {"id": "t1", "name": "search", "input": '{"incomplete'},
    ],
)
def test_message_stop(provider, current_tool):
    """Message stop should emit done regardless of incomplete tools."""
    stop = Mock(spec=["type"])
    stop.type = "message_stop"
    emitted, remaining = provider._handle_stream_event(stop, current_tool)
    assert emitted.type == "done"
    assert remaining is None
# Error Handling Tests
@pytest.mark.parametrize(
    "has_error,error_value,expected_message",
    [
        (True, "API rate limit exceeded", "rate limit"),
        (False, None, "Unknown error"),
    ],
)
def test_error_events(provider, has_error, error_value, expected_message):
    """Error events should emit error StreamEvent."""
    attrs = ["type", "error"] if has_error else ["type"]
    event = Mock(spec=attrs)
    event.type = "error"
    if has_error:
        event.error = error_value
    emitted, _ = provider._handle_stream_event(event, None)
    assert emitted.type == "error"
    assert expected_message in emitted.data
# Unknown Event Tests
@pytest.mark.parametrize(
    "event_type",
    ["message_start", "future_event_type", None],
)
def test_unknown_or_ignored_events(provider, event_type):
    """Unknown event types should be logged but not fail."""
    # None means an event object with no `type` attribute at all.
    attrs = [] if event_type is None else ["type"]
    event = Mock(spec=attrs)
    if event_type is not None:
        event.type = event_type
    emitted, _ = provider._handle_stream_event(event, None)
    assert emitted is None
# State Transition Tests
def test_complete_tool_call_sequence(provider):
    """Simulate a complete tool call from start to finish."""

    def json_delta_event(fragment):
        # One content_block_delta carrying a partial-JSON fragment.
        part = Mock(spec=["type", "partial_json"])
        part.type = "input_json_delta"
        part.partial_json = fragment
        wrapper = Mock(spec=["type", "delta"])
        wrapper.type = "content_block_delta"
        wrapper.delta = part
        return wrapper

    # Start of a tool_use content block.
    start_block = Mock(spec=["type", "id", "name", "input"])
    start_block.type = "tool_use"
    start_block.id = "tool-1"
    start_block.name = "search"
    start_block.input = None
    start_event = Mock(spec=["type", "content_block"])
    start_event.type = "content_block_start"
    start_event.content_block = start_block
    _, state = provider._handle_stream_event(start_event, None)
    assert state["input"] == ""

    # Two JSON fragments accumulate into the pending input string.
    _, state = provider._handle_stream_event(json_delta_event('{"query":'), state)
    assert state["input"] == '{"query":'
    _, state = provider._handle_stream_event(json_delta_event(' "test"}'), state)
    assert state["input"] == '{"query": "test"}'

    # Stop finalizes: parsed input is emitted and the state is cleared.
    stop_event = Mock(spec=["type"])
    stop_event.type = "content_block_stop"
    final, state = provider._handle_stream_event(stop_event, state)
    assert final.type == "tool_use"
    assert final.data["input"] == {"query": "test"}
    assert state is None
def test_text_and_thinking_mixed(provider):
    """Text and thinking deltas should be handled independently."""

    def block_delta(delta_type, attr, value):
        inner = Mock(spec=["type", attr])
        inner.type = delta_type
        setattr(inner, attr, value)
        outer = Mock(spec=["type", "delta"])
        outer.type = "content_block_delta"
        outer.delta = inner
        return outer

    text_event, _ = provider._handle_stream_event(
        block_delta("text_delta", "text", "Answer: "), None
    )
    assert text_event.type == "text"

    thinking_event, _ = provider._handle_stream_event(
        block_delta("thinking_delta", "thinking", "reasoning..."), None
    )
    assert thinking_event.type == "thinking"

View File

@ -0,0 +1,440 @@
import pytest
from unittest.mock import Mock, patch, MagicMock
from PIL import Image
from memory.common.llms.anthropic_provider import AnthropicProvider
from memory.common.llms.base import (
Message,
MessageRole,
TextContent,
ImageContent,
ToolUseContent,
ToolResultContent,
ThinkingContent,
LLMSettings,
StreamEvent,
)
from memory.common.llms.tools import ToolDefinition
@pytest.fixture
def provider():
    # Default provider: thinking disabled, fixed model string.
    return AnthropicProvider(api_key="test-key", model="claude-3-opus-20240229")
@pytest.fixture
def thinking_provider():
    # Provider on a thinking-capable model with extended thinking turned on.
    return AnthropicProvider(
        api_key="test-key", model="claude-opus-4", enable_thinking=True
    )
def test_initialization(provider):
    # Constructor stores credentials/model verbatim; thinking defaults to off.
    assert provider.api_key == "test-key"
    assert provider.model == "claude-3-opus-20240229"
    assert provider.enable_thinking is False
def test_client_lazy_loading(provider):
    # Sync client is created on first property access and cached.
    assert provider._client is None
    client = provider.client
    assert client is not None
    assert provider._client is not None
    # Second call should return same instance
    assert provider.client is client
def test_async_client_lazy_loading(provider):
    # Async client follows the same lazy-instantiation pattern.
    assert provider._async_client is None
    client = provider.async_client
    assert client is not None
    assert provider._async_client is not None
@pytest.mark.parametrize(
    "model, expected",
    [
        ("claude-opus-4", True),
        ("claude-opus-4-1", True),
        ("claude-sonnet-4-0", True),
        ("claude-sonnet-3-7", True),
        ("claude-sonnet-4-5", True),
        ("claude-3-opus-20240229", False),
        ("claude-3-sonnet-20240229", False),
        ("gpt-4", False),
    ],
)
def test_supports_thinking(model, expected):
    # Thinking support is decided purely from the model name string.
    provider = AnthropicProvider(api_key="test-key", model=model)
    assert provider._supports_thinking() == expected
def test_convert_text_content(provider):
    # Text content maps to Anthropic's {"type": "text", "text": ...} shape.
    content = TextContent(text="hello world")
    result = provider._convert_text_content(content)
    assert result == {"type": "text", "text": "hello world"}
def test_convert_image_content(provider):
    # Images become base64-encoded JPEG sources.
    image = Image.new("RGB", (100, 100), color="red")
    content = ImageContent(image=image)
    result = provider._convert_image_content(content)
    assert result["type"] == "image"
    assert result["source"]["type"] == "base64"
    assert result["source"]["media_type"] == "image/jpeg"
    assert isinstance(result["source"]["data"], str)
def test_should_include_message_filters_system(provider):
    # System messages are excluded here; the system prompt is sent via the
    # separate "system" kwarg (see test_build_request_kwargs_with_system_prompt).
    system_msg = Message(role=MessageRole.SYSTEM, content="system prompt")
    user_msg = Message(role=MessageRole.USER, content="user message")
    assert provider._should_include_message(system_msg) is False
    assert provider._should_include_message(user_msg) is True
@pytest.mark.parametrize(
    "messages, expected_count",
    [
        ([Message(role=MessageRole.USER, content="test")], 1),
        ([Message(role=MessageRole.SYSTEM, content="test")], 0),
        (
            [
                Message(role=MessageRole.SYSTEM, content="system"),
                Message(role=MessageRole.USER, content="user"),
            ],
            1,
        ),
    ],
)
def test_convert_messages(provider, messages, expected_count):
    # Conversion drops system messages and keeps the rest.
    result = provider._convert_messages(messages)
    assert len(result) == expected_count
def test_convert_tool(provider):
    # ToolDefinition maps onto Anthropic's tool schema fields 1:1.
    tool = ToolDefinition(
        name="test_tool",
        description="A test tool",
        input_schema={"type": "object", "properties": {}},
        function=lambda x: "result",
    )
    result = provider._convert_tool(tool)
    assert result["name"] == "test_tool"
    assert result["description"] == "A test tool"
    assert result["input_schema"] == {"type": "object", "properties": {}}
def test_build_request_kwargs_basic(provider):
    # Core request fields: model, sampling settings, converted messages.
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings(temperature=0.5, max_tokens=1000)
    kwargs = provider._build_request_kwargs(messages, None, None, settings)
    assert kwargs["model"] == "claude-3-opus-20240229"
    assert kwargs["temperature"] == 0.5
    assert kwargs["max_tokens"] == 1000
    assert len(kwargs["messages"]) == 1
def test_build_request_kwargs_with_system_prompt(provider):
    # The system prompt travels in the dedicated "system" kwarg.
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings()
    kwargs = provider._build_request_kwargs(messages, "system prompt", None, settings)
    assert kwargs["system"] == "system prompt"
def test_build_request_kwargs_with_tools(provider):
    # Tools are converted and attached under "tools".
    messages = [Message(role=MessageRole.USER, content="test")]
    tools = [
        ToolDefinition(
            name="test",
            description="test",
            input_schema={},
            function=lambda x: "result",
        )
    ]
    settings = LLMSettings()
    kwargs = provider._build_request_kwargs(messages, None, tools, settings)
    assert "tools" in kwargs
    assert len(kwargs["tools"]) == 1
def test_build_request_kwargs_with_thinking(thinking_provider):
    # With thinking on, a token budget is allocated out of max_tokens,
    # temperature is forced to 1.0, and top_p must be absent.
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings(max_tokens=5000)
    kwargs = thinking_provider._build_request_kwargs(messages, None, None, settings)
    assert "thinking" in kwargs
    assert kwargs["thinking"]["type"] == "enabled"
    assert kwargs["thinking"]["budget_tokens"] == 3976
    assert kwargs["temperature"] == 1.0
    assert "top_p" not in kwargs
def test_build_request_kwargs_thinking_insufficient_tokens(thinking_provider):
    # Too-small max_tokens leaves no room for a thinking budget.
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings(max_tokens=1000)
    kwargs = thinking_provider._build_request_kwargs(messages, None, None, settings)
    # Shouldn't enable thinking if not enough tokens
    assert "thinking" not in kwargs
def test_handle_stream_event_text_delta(provider):
    # text_delta fragments surface as "text" StreamEvents.
    event = Mock(
        type="content_block_delta",
        delta=Mock(type="text_delta", text="hello"),
    )
    stream_event, tool_use = provider._handle_stream_event(event, None)
    assert stream_event is not None
    assert stream_event.type == "text"
    assert stream_event.data == "hello"
    assert tool_use is None
def test_handle_stream_event_thinking_delta(provider):
    # thinking_delta fragments surface as "thinking" StreamEvents.
    event = Mock(
        type="content_block_delta",
        delta=Mock(type="thinking_delta", thinking="reasoning..."),
    )
    stream_event, tool_use = provider._handle_stream_event(event, None)
    assert stream_event is not None
    assert stream_event.type == "thinking"
    assert stream_event.data == "reasoning..."
def test_handle_stream_event_tool_use_start(provider):
    # A tool_use content_block_start initializes accumulation state
    # without emitting any event yet.
    block = Mock(spec=["type", "id", "name", "input"])
    block.type = "tool_use"
    block.id = "tool-1"
    block.name = "test_tool"
    block.input = {}
    event = Mock(spec=["type", "content_block"])
    event.type = "content_block_start"
    event.content_block = block
    stream_event, tool_use = provider._handle_stream_event(event, None)
    assert stream_event is None
    assert tool_use is not None
    assert tool_use["id"] == "tool-1"
    assert tool_use["name"] == "test_tool"
    assert tool_use["input"] == {}
def test_handle_stream_event_tool_input_delta(provider):
    # input_json_delta fragments append to the in-progress input string.
    current_tool = {"id": "tool-1", "name": "test", "input": '{"ke'}
    event = Mock(
        type="content_block_delta",
        delta=Mock(type="input_json_delta", partial_json='y": "val'),
    )
    stream_event, tool_use = provider._handle_stream_event(event, current_tool)
    assert stream_event is None
    assert tool_use["input"] == '{"key": "val'
def test_handle_stream_event_tool_use_complete(provider):
    # content_block_stop finalizes the tool call and clears the state.
    current_tool = {
        "id": "tool-1",
        "name": "test_tool",
        "input": '{"key": "value"}',
    }
    event = Mock(
        type="content_block_stop",
        content_block=Mock(input={"key": "value"}),
    )
    stream_event, tool_use = provider._handle_stream_event(event, current_tool)
    assert stream_event is not None
    assert stream_event.type == "tool_use"
    assert stream_event.data["id"] == "tool-1"
    assert stream_event.data["name"] == "test_tool"
    assert stream_event.data["input"] == {"key": "value"}
    assert tool_use is None
def test_handle_stream_event_message_stop(provider):
    # message_stop emits the terminal "done" event.
    event = Mock(type="message_stop")
    stream_event, tool_use = provider._handle_stream_event(event, None)
    assert stream_event is not None
    assert stream_event.type == "done"
    assert tool_use is None
def test_handle_stream_event_error(provider):
    # Error events carry the error text inside the StreamEvent data.
    event = Mock(type="error", error="API error")
    stream_event, tool_use = provider._handle_stream_event(event, None)
    assert stream_event is not None
    assert stream_event.type == "error"
    assert "API error" in stream_event.data
def test_generate_basic(provider, mock_anthropic_client):
    # NOTE(review): mock_anthropic_client is defined outside this file
    # (presumably conftest) — it appears to patch the Anthropic SDK client;
    # confirm against the fixture definition.
    messages = [Message(role=MessageRole.USER, content="test")]
    # Mock the response properly
    mock_block = Mock(spec=["type", "text"])
    mock_block.type = "text"
    mock_block.text = "test summary"
    mock_response = Mock(spec=["content"])
    mock_response.content = [mock_block]
    provider.client.messages.create.return_value = mock_response
    result = provider.generate(messages)
    assert result == "test summary"
    provider.client.messages.create.assert_called_once()
def test_stream_basic(provider, mock_anthropic_client):
    # Streaming over the mocked client should yield at least one text event.
    messages = [Message(role=MessageRole.USER, content="test")]
    events = list(provider.stream(messages))
    # Should get text event and done event
    assert len(events) > 0
    assert any(e.type == "text" for e in events)
    provider.client.messages.stream.assert_called_once()
@pytest.mark.asyncio
async def test_agenerate_basic(provider, mock_anthropic_client):
    # Async generate mirrors the sync path but goes through async_client.
    messages = [Message(role=MessageRole.USER, content="test")]
    result = await provider.agenerate(messages)
    assert result == "test summary"
    provider.async_client.messages.create.assert_called_once()
@pytest.mark.asyncio
async def test_astream_basic(provider, mock_anthropic_client):
    # Async streaming should also yield text events.
    messages = [Message(role=MessageRole.USER, content="test")]
    events = []
    async for event in provider.astream(messages):
        events.append(event)
    assert len(events) > 0
    assert any(e.type == "text" for e in events)
def test_convert_message_sorts_thinking_content(provider):
    """Thinking content should be sorted so thinking comes before non-thinking."""
    message = Message.assistant(
        ThinkingContent(thinking="reasoning", signature="sig"),
        TextContent(text="response"),
    )
    result = provider._convert_message(message)
    assert result["role"] == "assistant"
    # The sort key (x["type"] != "thinking") sorts thinking type to beginning
    # because "thinking" != "thinking" is False, which sorts before True
    content_types = [c["type"] for c in result["content"]]
    assert "text" in content_types
    assert "thinking" in content_types
    # Verify thinking comes before non-thinking (sorted by key)
    thinking_idx = content_types.index("thinking")
    text_idx = content_types.index("text")
    assert thinking_idx < text_idx
def test_execute_tool_success(provider):
    # Happy path: tool found, input forwarded, result wrapped without error.
    tool_call = {"id": "t1", "name": "test", "input": {"arg": "value"}}
    tools = {
        "test": ToolDefinition(
            name="test",
            description="test",
            input_schema={},
            function=lambda x: f"result: {x['arg']}",
        )
    }
    result = provider.execute_tool(tool_call, tools)
    assert result.tool_use_id == "t1"
    assert result.content == "result: value"
    assert result.is_error is False
def test_execute_tool_missing_name(provider):
    # A call without a tool name yields an error result, not an exception.
    tool_call = {"id": "t1", "input": {}}
    tools = {}
    result = provider.execute_tool(tool_call, tools)
    assert result.tool_use_id == "t1"
    assert "missing" in result.content.lower()
    assert result.is_error is True
def test_execute_tool_not_found(provider):
    # Unknown tool names also come back as error results.
    tool_call = {"id": "t1", "name": "nonexistent", "input": {}}
    tools = {}
    result = provider.execute_tool(tool_call, tools)
    assert result.tool_use_id == "t1"
    assert "not found" in result.content.lower()
    assert result.is_error is True
def test_execute_tool_exception(provider):
    # Exceptions raised inside the tool function are captured as error results.
    tool_call = {"id": "t1", "name": "test", "input": {}}
    tools = {
        "test": ToolDefinition(
            name="test",
            description="test",
            input_schema={},
            function=lambda x: 1 / 0,  # Raises ZeroDivisionError
        )
    }
    result = provider.execute_tool(tool_call, tools)
    assert result.tool_use_id == "t1"
    assert result.is_error is True
    assert "division" in result.content.lower()
def test_encode_image(provider):
    """encode_image should return a non-empty base64 string."""
    sample = Image.new("RGB", (10, 10), color="blue")
    result = provider.encode_image(sample)
    assert isinstance(result, str)
    assert len(result) > 0
def test_encode_image_rgba(provider):
    """RGBA images should be converted to RGB."""
    sample = Image.new("RGBA", (10, 10), color=(255, 0, 0, 128))
    result = provider.encode_image(sample)
    assert isinstance(result, str)
    assert len(result) > 0

View File

@ -0,0 +1,270 @@
import pytest
from PIL import Image
from memory.common.llms.base import (
Message,
MessageRole,
TextContent,
ImageContent,
ToolUseContent,
ToolResultContent,
ThinkingContent,
LLMSettings,
StreamEvent,
create_provider,
)
from memory.common.llms.anthropic_provider import AnthropicProvider
from memory.common.llms.openai_provider import OpenAIProvider
from memory.common import settings
def test_message_role_enum():
    # Roles compare equal to their plain string values.
    assert MessageRole.SYSTEM == "system"
    assert MessageRole.USER == "user"
    assert MessageRole.ASSISTANT == "assistant"
    assert MessageRole.TOOL == "tool"
def test_text_content_creation():
    content = TextContent(text="hello")
    assert content.type == "text"
    assert content.text == "hello"
    assert content.valid
def test_text_content_to_dict():
    content = TextContent(text="hello")
    result = content.to_dict()
    assert result == {"type": "text", "text": "hello"}
def test_text_content_empty_invalid():
    # Empty text is considered invalid content.
    content = TextContent(text="")
    assert not content.valid
def test_image_content_creation():
    image = Image.new("RGB", (10, 10))
    content = ImageContent(image=image)
    assert content.type == "image"
    assert content.image == image
    assert content.valid
def test_image_content_with_detail():
    # Optional detail hint (e.g. "high") is stored verbatim.
    image = Image.new("RGB", (10, 10))
    content = ImageContent(image=image, detail="high")
    assert content.detail == "high"
def test_tool_use_content_creation():
    content = ToolUseContent(id="t1", name="test_tool", input={"arg": "value"})
    assert content.type == "tool_use"
    assert content.id == "t1"
    assert content.name == "test_tool"
    assert content.input == {"arg": "value"}
    assert content.valid
def test_tool_use_content_to_dict():
    content = ToolUseContent(id="t1", name="test", input={"key": "val"})
    result = content.to_dict()
    assert result == {
        "type": "tool_use",
        "id": "t1",
        "name": "test",
        "input": {"key": "val"},
    }
def test_tool_result_content_creation():
    content = ToolResultContent(
        tool_use_id="t1",
        content="result",
        is_error=False,
    )
    assert content.type == "tool_result"
    assert content.tool_use_id == "t1"
    assert content.content == "result"
    assert not content.is_error
    assert content.valid
def test_tool_result_content_with_error():
    content = ToolResultContent(
        tool_use_id="t1",
        content="error message",
        is_error=True,
    )
    assert content.is_error
def test_thinking_content_creation():
    content = ThinkingContent(thinking="reasoning...", signature="sig")
    assert content.type == "thinking"
    assert content.thinking == "reasoning..."
    assert content.signature == "sig"
    assert content.valid
def test_thinking_content_invalid_without_signature():
    # Thinking blocks require a signature to count as valid.
    content = ThinkingContent(thinking="reasoning...")
    assert not content.valid
def test_message_simple_string_content():
    msg = Message(role=MessageRole.USER, content="hello")
    assert msg.role == MessageRole.USER
    assert msg.content == "hello"
def test_message_list_content():
    content_list = [TextContent(text="hello"), TextContent(text="world")]
    msg = Message(role=MessageRole.USER, content=content_list)
    assert msg.role == MessageRole.USER
    assert len(msg.content) == 2
def test_message_to_dict_string():
    # String content serializes unchanged alongside the stringified role.
    msg = Message(role=MessageRole.USER, content="hello")
    result = msg.to_dict()
    assert result == {"role": "user", "content": "hello"}
def test_message_to_dict_list():
    # List content serializes each part via its own to_dict.
    msg = Message(
        role=MessageRole.USER,
        content=[TextContent(text="hello"), TextContent(text="world")],
    )
    result = msg.to_dict()
    assert result["role"] == "user"
    assert len(result["content"]) == 2
    assert result["content"][0] == {"type": "text", "text": "hello"}
def test_message_assistant_factory():
    msg = Message.assistant(
        TextContent(text="response"),
        ToolUseContent(id="t1", name="tool", input={}),
    )
    assert msg.role == MessageRole.ASSISTANT
    assert len(msg.content) == 2
def test_message_assistant_filters_invalid_content():
    # The factory drops invalid (e.g. empty-text) content parts.
    msg = Message.assistant(
        TextContent(text="valid"),
        TextContent(text=""),  # Invalid - empty
    )
    assert len(msg.content) == 1
    assert msg.content[0].text == "valid"
def test_message_user_factory():
    msg = Message.user(text="hello")
    assert msg.role == MessageRole.USER
    assert len(msg.content) == 1
    assert isinstance(msg.content[0], TextContent)
def test_message_user_with_tool_result():
    # A tool result rides along as an extra content part.
    tool_result = ToolResultContent(tool_use_id="t1", content="result")
    msg = Message.user(text="hello", tool_result=tool_result)
    assert len(msg.content) == 2
def test_stream_event_creation():
    event = StreamEvent(type="text", data="hello")
    assert event.type == "text"
    assert event.data == "hello"
def test_stream_event_with_signature():
    event = StreamEvent(type="thinking", signature="sig123")
    assert event.signature == "sig123"
def test_llm_settings_defaults():
    # Note: the local name deliberately shadows the imported `settings`
    # module within this test's scope.
    settings = LLMSettings()
    assert settings.temperature == 0.7
    assert settings.max_tokens == 2048
    assert settings.top_p is None
    assert settings.stop_sequences is None
    assert settings.stream is False
def test_llm_settings_custom():
    settings = LLMSettings(
        temperature=0.5,
        max_tokens=1000,
        top_p=0.9,
        stop_sequences=["STOP"],
        stream=True,
    )
    assert settings.temperature == 0.5
    assert settings.max_tokens == 1000
    assert settings.top_p == 0.9
    assert settings.stop_sequences == ["STOP"]
    assert settings.stream is True
def test_create_provider_anthropic():
    # "provider/model" strings route to the matching provider class.
    provider = create_provider(
        model="anthropic/claude-3-opus-20240229",
        api_key="test-key",
    )
    assert isinstance(provider, AnthropicProvider)
    assert provider.model == "claude-3-opus-20240229"
def test_create_provider_openai():
    provider = create_provider(
        model="openai/gpt-4o",
        api_key="test-key",
    )
    assert isinstance(provider, OpenAIProvider)
    assert provider.model == "gpt-4o"
def test_create_provider_unknown_raises():
    with pytest.raises(ValueError, match="Unknown provider"):
        create_provider(model="unknown/model", api_key="test-key")
def test_create_provider_uses_default_model():
    """If no model provided, should use SUMMARIZER_MODEL from settings."""
    provider = create_provider(api_key="test-key")
    # Should create a provider (type depends on settings.SUMMARIZER_MODEL)
    assert provider is not None
def test_create_provider_anthropic_with_thinking():
    # enable_thinking is forwarded through the factory.
    provider = create_provider(
        model="anthropic/claude-opus-4",
        api_key="test-key",
        enable_thinking=True,
    )
    assert isinstance(provider, AnthropicProvider)
    assert provider.enable_thinking is True
def test_create_provider_missing_anthropic_key():
    # Temporarily clear the API key from settings
    # NOTE(review): pytest's monkeypatch fixture would restore this
    # automatically; the manual save/restore works but is easy to get wrong.
    original_key = settings.ANTHROPIC_API_KEY
    try:
        settings.ANTHROPIC_API_KEY = ""
        with pytest.raises(ValueError, match="ANTHROPIC_API_KEY"):
            create_provider(model="anthropic/claude-3-opus-20240229")
    finally:
        settings.ANTHROPIC_API_KEY = original_key
def test_create_provider_missing_openai_key():
    # Temporarily clear the API key from settings
    # NOTE(review): same manual save/restore pattern as above — consider
    # monkeypatch.setattr(settings, "OPENAI_API_KEY", "").
    original_key = settings.OPENAI_API_KEY
    try:
        settings.OPENAI_API_KEY = ""
        with pytest.raises(ValueError, match="OPENAI_API_KEY"):
            create_provider(model="openai/gpt-4o")
    finally:
        settings.OPENAI_API_KEY = original_key

View File

@ -0,0 +1,478 @@
"""Comprehensive tests for OpenAI stream chunk parsing."""
import pytest
from unittest.mock import Mock
from memory.common.llms.openai_provider import OpenAIProvider
from memory.common.llms.base import StreamEvent
@pytest.fixture
def provider():
    # Fresh OpenAIProvider instance per test.
    return OpenAIProvider(api_key="test-key", model="gpt-4o")
# Text Content Tests
@pytest.mark.parametrize(
    "content,expected_events",
    [
        ("Hello", 1),
        ("", 0),  # Empty string is falsy
        (None, 0),
        ("Line 1\nLine 2\nLine 3", 1),
        ("Hello 世界 🌍", 1),
    ],
)
def test_text_content(provider, content, expected_events):
    """Text content should emit text events appropriately."""
    piece = Mock(spec=["content", "tool_calls"])
    piece.content = content
    piece.tool_calls = None
    selection = Mock(spec=["delta", "finish_reason"])
    selection.delta = piece
    selection.finish_reason = None
    payload = Mock(spec=["choices"])
    payload.choices = [selection]
    emitted, pending = provider._handle_stream_chunk(payload, None)
    assert pending is None
    assert len(emitted) == expected_events
    if expected_events:
        assert emitted[0].type == "text"
        assert emitted[0].data == content
# Tool Call Start Tests
def test_new_tool_call_basic(provider):
"""New tool call should initialize state."""
function = Mock(spec=["name", "arguments"])
function.name = "search"
function.arguments = ""
tool = Mock(spec=["id", "function"])
tool.id = "call_123"
tool.function = function
delta = Mock(spec=["content", "tool_calls"])
delta.content = None
delta.tool_calls = [tool]
choice = Mock(spec=["delta", "finish_reason"])
choice.delta = delta
choice.finish_reason = None
chunk = Mock(spec=["choices"])
chunk.choices = [choice]
events, tool_call = provider._handle_stream_chunk(chunk, None)
assert len(events) == 0
assert tool_call == {"id": "call_123", "name": "search", "arguments": ""}
@pytest.mark.parametrize(
"name,arguments,expected_name,expected_args",
[
("calculate", '{"operation":', "calculate", '{"operation":'),
(None, "", "", ""),
("test", None, "test", ""),
],
)
def test_new_tool_call_variations(
provider, name, arguments, expected_name, expected_args
):
"""Tool calls with various name/argument combinations."""
function = Mock(spec=["name", "arguments"])
function.name = name
function.arguments = arguments
tool = Mock(spec=["id", "function"])
tool.id = "call_123"
tool.function = function
delta = Mock(spec=["content", "tool_calls"])
delta.content = None
delta.tool_calls = [tool]
choice = Mock(spec=["delta", "finish_reason"])
choice.delta = delta
choice.finish_reason = None
chunk = Mock(spec=["choices"])
chunk.choices = [choice]
events, tool_call = provider._handle_stream_chunk(chunk, None)
assert tool_call["name"] == expected_name
assert tool_call["arguments"] == expected_args
def test_new_tool_call_replaces_previous(provider):
"""New tool call should finalize and replace previous."""
current = {"id": "call_old", "name": "old_tool", "arguments": '{"arg": "value"}'}
function = Mock(spec=["name", "arguments"])
function.name = "new_tool"
function.arguments = ""
tool = Mock(spec=["id", "function"])
tool.id = "call_new"
tool.function = function
delta = Mock(spec=["content", "tool_calls"])
delta.content = None
delta.tool_calls = [tool]
choice = Mock(spec=["delta", "finish_reason"])
choice.delta = delta
choice.finish_reason = None
chunk = Mock(spec=["choices"])
chunk.choices = [choice]
events, tool_call = provider._handle_stream_chunk(chunk, current)
assert len(events) == 1
assert events[0].type == "tool_use"
assert events[0].data["id"] == "call_old"
assert events[0].data["input"] == {"arg": "value"}
assert tool_call["id"] == "call_new"
# Tool Call Continuation Tests
@pytest.mark.parametrize(
"initial_args,new_args,expected_args",
[
('{"query": "', 'test query"}', '{"query": "test query"}'),
('{"query"', ': "value"}', '{"query": "value"}'),
("", '{"full": "json"}', '{"full": "json"}'),
('{"partial"', "", '{"partial"'), # Empty doesn't accumulate
],
)
def test_tool_call_argument_accumulation(
provider, initial_args, new_args, expected_args
):
"""Arguments should accumulate correctly."""
current = {"id": "call_123", "name": "search", "arguments": initial_args}
function = Mock(spec=["name", "arguments"])
function.name = None
function.arguments = new_args
tool = Mock(spec=["id", "function"])
tool.id = None
tool.function = function
delta = Mock(spec=["content", "tool_calls"])
delta.content = None
delta.tool_calls = [tool]
choice = Mock(spec=["delta", "finish_reason"])
choice.delta = delta
choice.finish_reason = None
chunk = Mock(spec=["choices"])
chunk.choices = [choice]
events, tool_call = provider._handle_stream_chunk(chunk, current)
assert len(events) == 0
assert tool_call["arguments"] == expected_args
def test_tool_call_accumulation_without_current_tool(provider):
"""Arguments without current tool should be ignored."""
function = Mock(spec=["name", "arguments"])
function.name = None
function.arguments = '{"arg": "value"}'
tool = Mock(spec=["id", "function"])
tool.id = None
tool.function = function
delta = Mock(spec=["content", "tool_calls"])
delta.content = None
delta.tool_calls = [tool]
choice = Mock(spec=["delta", "finish_reason"])
choice.delta = delta
choice.finish_reason = None
chunk = Mock(spec=["choices"])
chunk.choices = [choice]
events, tool_call = provider._handle_stream_chunk(chunk, None)
assert len(events) == 0
assert tool_call is None
def test_incremental_json_building(provider):
"""Test realistic incremental JSON building across multiple chunks."""
current = {"id": "c1", "name": "search", "arguments": ""}
increments = ['{"', 'query":', ' "test"}']
expected_states = ['{"', '{"query":', '{"query": "test"}']
for increment, expected in zip(increments, expected_states):
function = Mock(spec=["name", "arguments"])
function.name = None
function.arguments = increment
tool = Mock(spec=["id", "function"])
tool.id = None
tool.function = function
delta = Mock(spec=["content", "tool_calls"])
delta.content = None
delta.tool_calls = [tool]
choice = Mock(spec=["delta", "finish_reason"])
choice.delta = delta
choice.finish_reason = None
chunk = Mock(spec=["choices"])
chunk.choices = [choice]
_, current = provider._handle_stream_chunk(chunk, current)
assert current["arguments"] == expected
# Finish Reason Tests
def test_finish_reason_without_tool(provider):
"""Stop finish without tool should not emit events."""
delta = Mock(spec=["content", "tool_calls"])
delta.content = None
delta.tool_calls = None
choice = Mock(spec=["delta", "finish_reason"])
choice.delta = delta
choice.finish_reason = "stop"
chunk = Mock(spec=["choices"])
chunk.choices = [choice]
events, tool_call = provider._handle_stream_chunk(chunk, None)
assert len(events) == 0
assert tool_call is None
@pytest.mark.parametrize(
"arguments,expected_input",
[
('{"query": "test"}', {"query": "test"}),
('{"invalid": json}', {}),
("", {}),
],
)
def test_finish_reason_with_tool(provider, arguments, expected_input):
"""Finish with tool call should finalize and emit."""
current = {"id": "call_123", "name": "search", "arguments": arguments}
delta = Mock(spec=["content", "tool_calls"])
delta.content = None
delta.tool_calls = None
choice = Mock(spec=["delta", "finish_reason"])
choice.delta = delta
choice.finish_reason = "tool_calls"
chunk = Mock(spec=["choices"])
chunk.choices = [choice]
events, tool_call = provider._handle_stream_chunk(chunk, current)
assert len(events) == 1
assert events[0].type == "tool_use"
assert events[0].data["id"] == "call_123"
assert events[0].data["input"] == expected_input
assert tool_call is None
@pytest.mark.parametrize("reason", ["stop", "length", "content_filter", "tool_calls"])
def test_various_finish_reasons(provider, reason):
"""Various finish reasons with active tool should finalize."""
current = {"id": "call_123", "name": "test", "arguments": '{"a": 1}'}
delta = Mock(spec=["content", "tool_calls"])
delta.content = None
delta.tool_calls = None
choice = Mock(spec=["delta", "finish_reason"])
choice.delta = delta
choice.finish_reason = reason
chunk = Mock(spec=["choices"])
chunk.choices = [choice]
events, tool_call = provider._handle_stream_chunk(chunk, current)
assert len(events) == 1
assert tool_call is None
# Edge Cases Tests
def test_empty_choices(provider):
    """A chunk whose choices list is empty should produce no events."""
    empty_chunk = Mock(spec=["choices"])
    empty_chunk.choices = []
    events, pending = provider._handle_stream_chunk(empty_chunk, None)
    assert len(events) == 0
    assert pending is None


def test_none_choices(provider):
    """A chunk with choices=None must not crash the handler."""
    bad_chunk = Mock(spec=["choices"])
    bad_chunk.choices = None
    try:
        events, pending = provider._handle_stream_chunk(bad_chunk, None)
        assert len(events) == 0
    except (TypeError, AttributeError):
        pass  # Rejecting malformed input with an error is also acceptable
def test_multiple_chunks_in_sequence(provider):
    """Start -> arguments -> finish chunks accumulate and finalize one tool call."""

    def make_tool(tool_id, name, arguments):
        # spec + attribute assignment is deliberate: Mock(name=...) would set
        # the mock's own name, not a .name attribute.
        fn = Mock(spec=["name", "arguments"])
        fn.name = name
        fn.arguments = arguments
        call = Mock(spec=["id", "function"])
        call.id = tool_id
        call.function = fn
        return call

    def make_chunk(content, tool_calls, finish_reason):
        delta = Mock(spec=["content", "tool_calls"])
        delta.content = content
        delta.tool_calls = tool_calls
        choice = Mock(spec=["delta", "finish_reason"])
        choice.delta = delta
        choice.finish_reason = finish_reason
        chunk = Mock(spec=["choices"])
        chunk.choices = [choice]
        return chunk

    # 1) The tool call opens with an id and a name but no arguments yet.
    events, state = provider._handle_stream_chunk(
        make_chunk(None, [make_tool("call_1", "search", "")], None), None
    )
    assert len(events) == 0
    assert state is not None

    # 2) An argument fragment is appended to the pending call.
    events, state = provider._handle_stream_chunk(
        make_chunk(None, [make_tool(None, None, '{"q": "test"}')], None), state
    )
    assert len(events) == 0
    assert state["arguments"] == '{"q": "test"}'

    # 3) The finish chunk flushes the completed tool_use event.
    events, state = provider._handle_stream_chunk(
        make_chunk(None, None, "stop"), state
    )
    assert len(events) == 1
    assert events[0].type == "tool_use"
    assert state is None
def test_text_and_tool_calls_mixed(provider):
    """Text arriving alongside a tool-call start is emitted before the call opens."""
    fn = Mock(spec=["name", "arguments"])
    fn.name = "search"
    fn.arguments = ""
    call = Mock(spec=["id", "function"])
    call.id = "call_1"
    call.function = fn
    mixed_delta = Mock(spec=["content", "tool_calls"])
    mixed_delta.content = "Let me search for that."
    mixed_delta.tool_calls = [call]
    mixed_choice = Mock(spec=["delta", "finish_reason"])
    mixed_choice.delta = mixed_delta
    mixed_choice.finish_reason = None
    mixed_chunk = Mock(spec=["choices"])
    mixed_chunk.choices = [mixed_choice]

    events, pending = provider._handle_stream_chunk(mixed_chunk, None)

    # The text event comes out immediately; the tool call stays pending.
    assert len(events) == 1
    assert events[0].type == "text"
    assert events[0].data == "Let me search for that."
    assert pending is not None
# JSON Parsing Tests
@pytest.mark.parametrize(
    "arguments,expected_input",
    [
        ('{"key": "value", "num": 42}', {"key": "value", "num": 42}),
        ("{}", {}),
        (
            '{"user": {"name": "John", "tags": ["a", "b"]}, "count": 10}',
            {"user": {"name": "John", "tags": ["a", "b"]}, "count": 10},
        ),
        ('{"invalid": json}', {}),
        ('{"key": "val', {}),
        ("", {}),
        ('{"text": "Hello 世界 🌍"}', {"text": "Hello 世界 🌍"}),
        (
            '{"text": "Line 1\\nLine 2\\t\\tTabbed"}',
            {"text": "Line 1\nLine 2\t\tTabbed"},
        ),
    ],
)
def test_json_parsing(provider, arguments, expected_input):
    """Valid JSON parses fully; broken, truncated, or empty JSON falls back to {}."""
    finalized = provider._parse_and_finalize_tool_call(
        {"id": "c1", "name": "test", "arguments": arguments}
    )
    assert finalized["input"] == expected_input
    assert "arguments" not in finalized

# View File (diff-view artifact: a new test file begins below)
# @ -0,0 +1,561 @@
import pytest
from unittest.mock import Mock
from PIL import Image
from memory.common.llms.openai_provider import OpenAIProvider
from memory.common.llms.base import (
Message,
MessageRole,
TextContent,
ImageContent,
ToolUseContent,
ToolResultContent,
LLMSettings,
StreamEvent,
)
from memory.common.llms.tools import ToolDefinition
@pytest.fixture
def provider():
    """Standard (non-reasoning) provider under test."""
    return OpenAIProvider(model="gpt-4o", api_key="test-key")


@pytest.fixture
def reasoning_provider():
    """Provider configured with a reasoning (o1-family) model."""
    return OpenAIProvider(model="o1-preview", api_key="test-key")
def test_initialization(provider):
    """Constructor should store the API key and model name verbatim."""
    assert provider.model == "gpt-4o"
    assert provider.api_key == "test-key"


def test_client_lazy_loading(provider):
    """The sync client is created on first access, not in __init__."""
    assert provider._client is None
    assert provider.client is not None
    assert provider._client is not None


def test_async_client_lazy_loading(provider):
    """The async client is created on first access, not in __init__."""
    assert provider._async_client is None
    assert provider.async_client is not None
    assert provider._async_client is not None
@pytest.mark.parametrize(
    "model, expected",
    [
        ("gpt-4o", False),
        # NOTE(review): gpt-4-turbo / gpt-3.5-turbo expecting True looks
        # surprising for a "reasoning model" check — confirm this matches
        # the intended classification in OpenAIProvider._is_reasoning_model.
        ("o1-preview", True),
        ("o1-mini", True),
        ("gpt-4-turbo", True),
        ("gpt-3.5-turbo", True),
    ],
)
def test_is_reasoning_model(model, expected):
    """_is_reasoning_model should classify providers by model name."""
    candidate = OpenAIProvider(api_key="test-key", model=model)
    assert candidate._is_reasoning_model() == expected
def test_convert_text_content(provider):
    """Text content maps to OpenAI's {'type': 'text'} part."""
    part = provider._convert_text_content(TextContent(text="hello world"))
    assert part == {"type": "text", "text": "hello world"}


def test_convert_image_content(provider):
    """Images are inlined as base64 JPEG data URLs."""
    img = Image.new("RGB", (100, 100), color="red")
    part = provider._convert_image_content(ImageContent(image=img))
    assert part["type"] == "image_url"
    assert "image_url" in part
    assert part["image_url"]["url"].startswith("data:image/jpeg;base64,")


def test_convert_image_content_with_detail(provider):
    """The optional detail hint is forwarded into the image_url payload."""
    img = Image.new("RGB", (100, 100), color="red")
    part = provider._convert_image_content(ImageContent(image=img, detail="high"))
    assert part["image_url"]["detail"] == "high"
def test_convert_tool_use_content(provider):
    """Tool-use content becomes an OpenAI function-call entry."""
    converted = provider._convert_tool_use_content(
        ToolUseContent(id="t1", name="test_tool", input={"arg": "value"})
    )
    assert converted["id"] == "t1"
    assert converted["type"] == "function"
    assert converted["function"]["name"] == "test_tool"
    assert '{"arg": "value"}' in converted["function"]["arguments"]


def test_convert_tool_result_content(provider):
    """Tool results map onto a 'tool' role message referencing the call id."""
    converted = provider._convert_tool_result_content(
        ToolResultContent(tool_use_id="t1", content="result content", is_error=False)
    )
    assert converted["role"] == "tool"
    assert converted["tool_call_id"] == "t1"
    assert converted["content"] == "result content"
def test_convert_messages_simple(provider):
    """A plain string user message passes through unchanged."""
    converted = provider._convert_messages(
        [Message(role=MessageRole.USER, content="test")]
    )
    assert len(converted) == 1
    assert converted[0]["role"] == "user"
    assert converted[0]["content"] == "test"


def test_convert_messages_with_tool_result(provider):
    """Tool results should become separate messages with the 'tool' role."""
    source = [
        Message(
            role=MessageRole.USER,
            content=[ToolResultContent(tool_use_id="t1", content="result")],
        )
    ]
    converted = provider._convert_messages(source)
    assert len(converted) == 1
    assert converted[0]["role"] == "tool"
    assert converted[0]["tool_call_id"] == "t1"


def test_convert_messages_with_tool_use(provider):
    """Assistant tool-use content should surface as a tool_calls field."""
    source = [
        Message.assistant(
            TextContent(text="thinking..."),
            ToolUseContent(id="t1", name="test", input={}),
        )
    ]
    converted = provider._convert_messages(source)
    assert len(converted) == 1
    assert converted[0]["role"] == "assistant"
    assert "tool_calls" in converted[0]
    assert len(converted[0]["tool_calls"]) == 1


def test_convert_messages_mixed_content(provider):
    """Mixed text + tool-result content is split into two messages."""
    source = [
        Message(
            role=MessageRole.USER,
            content=[
                TextContent(text="user text"),
                ToolResultContent(tool_use_id="t1", content="result"),
            ],
        )
    ]
    converted = provider._convert_messages(source)
    # The tool message is emitted first, then the remaining user content.
    assert len(converted) == 2
    assert converted[0]["role"] == "tool"
    assert converted[1]["role"] == "user"
def test_convert_tools(provider):
    """Tool definitions map to OpenAI 'function' tool specs."""
    definition = ToolDefinition(
        name="test_tool",
        description="A test tool",
        input_schema={"type": "object", "properties": {"arg": {"type": "string"}}},
        function=lambda x: "result",
    )
    converted = provider._convert_tools([definition])
    assert len(converted) == 1
    assert converted[0]["type"] == "function"
    assert converted[0]["function"]["name"] == "test_tool"
    assert converted[0]["function"]["description"] == "A test tool"
    assert converted[0]["function"]["parameters"] == definition.input_schema
def test_build_request_kwargs_basic(provider):
    """Model, sampling settings, and messages all land in the kwargs."""
    kwargs = provider._build_request_kwargs(
        [Message(role=MessageRole.USER, content="test")],
        None,
        None,
        LLMSettings(temperature=0.5, max_tokens=1000),
    )
    assert kwargs["model"] == "gpt-4o"
    assert kwargs["temperature"] == 0.5
    assert kwargs["max_tokens"] == 1000
    assert len(kwargs["messages"]) == 1


def test_build_request_kwargs_with_system_prompt_standard_model(provider):
    """On standard models the system prompt becomes a leading 'system' message."""
    kwargs = provider._build_request_kwargs(
        [Message(role=MessageRole.USER, content="test")],
        "system prompt",
        None,
        LLMSettings(),
    )
    assert kwargs["messages"][0]["role"] == "system"
    assert kwargs["messages"][0]["content"] == "system prompt"


def test_build_request_kwargs_with_system_prompt_reasoning_model(
    reasoning_provider,
):
    """On o1-family models the system prompt uses the 'developer' role instead."""
    kwargs = reasoning_provider._build_request_kwargs(
        [Message(role=MessageRole.USER, content="test")],
        "system prompt",
        None,
        LLMSettings(),
    )
    assert kwargs["messages"][0]["role"] == "developer"
    assert kwargs["messages"][0]["content"] == "system prompt"


def test_build_request_kwargs_reasoning_model_uses_max_completion_tokens(
    reasoning_provider,
):
    """Reasoning models take max_completion_tokens rather than max_tokens."""
    kwargs = reasoning_provider._build_request_kwargs(
        [Message(role=MessageRole.USER, content="test")],
        None,
        None,
        LLMSettings(max_tokens=2000),
    )
    assert "max_completion_tokens" in kwargs
    assert kwargs["max_completion_tokens"] == 2000
    assert "max_tokens" not in kwargs


def test_build_request_kwargs_reasoning_model_no_temperature(reasoning_provider):
    """Sampling knobs are dropped entirely for reasoning models."""
    kwargs = reasoning_provider._build_request_kwargs(
        [Message(role=MessageRole.USER, content="test")],
        None,
        None,
        LLMSettings(temperature=0.7),
    )
    assert "temperature" not in kwargs
    assert "top_p" not in kwargs


def test_build_request_kwargs_with_tools(provider):
    """Supplying tools adds converted specs plus tool_choice='auto'."""
    tool = ToolDefinition(
        name="test",
        description="test",
        input_schema={},
        function=lambda x: "result",
    )
    kwargs = provider._build_request_kwargs(
        [Message(role=MessageRole.USER, content="test")],
        None,
        [tool],
        LLMSettings(),
    )
    assert "tools" in kwargs
    assert len(kwargs["tools"]) == 1
    assert kwargs["tool_choice"] == "auto"


def test_build_request_kwargs_with_stream(provider):
    """stream=True must be forwarded into the request kwargs."""
    kwargs = provider._build_request_kwargs(
        [Message(role=MessageRole.USER, content="test")],
        None,
        None,
        LLMSettings(),
        stream=True,
    )
    assert kwargs["stream"] is True
def test_parse_and_finalize_tool_call(provider):
    """Finalizing replaces the raw arguments string with parsed input."""
    finalized = provider._parse_and_finalize_tool_call(
        {"id": "t1", "name": "test", "arguments": '{"key": "value"}'}
    )
    assert finalized["id"] == "t1"
    assert finalized["name"] == "test"
    assert finalized["input"] == {"key": "value"}
    assert "arguments" not in finalized


def test_parse_and_finalize_tool_call_invalid_json(provider):
    """Unparseable argument JSON degrades to an empty input dict."""
    finalized = provider._parse_and_finalize_tool_call(
        {"id": "t1", "name": "test", "arguments": '{"invalid json'}
    )
    assert finalized["input"] == {}
def test_handle_stream_chunk_text_content(provider):
    """A text delta yields one text event and no pending tool call."""
    choice = Mock(delta=Mock(content="hello", tool_calls=None), finish_reason=None)
    events, pending = provider._handle_stream_chunk(Mock(choices=[choice]), None)
    assert len(events) == 1
    assert events[0].type == "text"
    assert events[0].data == "hello"
    assert pending is None


def test_handle_stream_chunk_tool_call_start(provider):
    """The first tool-call delta opens a pending call but emits nothing yet."""
    # spec + attribute assignment is required here: Mock(name=...) would set
    # the mock's own name rather than a .name attribute.
    fn = Mock(spec=["name", "arguments"])
    fn.name = "test_tool"
    fn.arguments = ""
    call = Mock(spec=["id", "function"])
    call.id = "t1"
    call.function = fn
    start_delta = Mock(spec=["content", "tool_calls"])
    start_delta.content = None
    start_delta.tool_calls = [call]
    start_choice = Mock(spec=["delta", "finish_reason"])
    start_choice.delta = start_delta
    start_choice.finish_reason = None
    start_chunk = Mock(spec=["choices"])
    start_chunk.choices = [start_choice]

    events, pending = provider._handle_stream_chunk(start_chunk, None)

    assert len(events) == 0
    assert pending is not None
    assert pending["id"] == "t1"
    assert pending["name"] == "test_tool"


def test_handle_stream_chunk_tool_call_arguments(provider):
    """Argument fragments are appended onto the pending call's buffer."""
    pending = {"id": "t1", "name": "test", "arguments": '{"ke'}
    fragment = Mock(id=None, function=Mock(name=None, arguments='y": "val"}'))
    choice = Mock(
        delta=Mock(content=None, tool_calls=[fragment]),
        finish_reason=None,
    )
    events, updated = provider._handle_stream_chunk(Mock(choices=[choice]), pending)
    assert len(events) == 0
    assert updated["arguments"] == '{"key": "val"}'


def test_handle_stream_chunk_finish_with_tool_call(provider):
    """The finish chunk flushes the pending call as a tool_use event."""
    pending = {"id": "t1", "name": "test", "arguments": '{"key": "value"}'}
    choice = Mock(
        delta=Mock(content=None, tool_calls=None),
        finish_reason="tool_calls",
    )
    events, remaining = provider._handle_stream_chunk(Mock(choices=[choice]), pending)
    assert len(events) == 1
    assert events[0].type == "tool_use"
    assert events[0].data["id"] == "t1"
    assert events[0].data["input"] == {"key": "value"}
    assert remaining is None


def test_handle_stream_chunk_empty_choices(provider):
    """An empty choices list is a no-op."""
    events, pending = provider._handle_stream_chunk(Mock(choices=[]), None)
    assert len(events) == 0
    assert pending is None
def test_generate_basic(provider, mock_openai_client):
    """generate() returns non-empty text and hits the completions endpoint once."""
    # The conftest fixture already sets up the mock response.
    output = provider.generate([Message(role=MessageRole.USER, content="test")])
    assert isinstance(output, str)
    assert len(output) > 0
    provider.client.chat.completions.create.assert_called_once()


def test_stream_basic(provider, mock_openai_client):
    """stream() yields at least one text event and terminates with 'done'."""
    collected = list(provider.stream([Message(role=MessageRole.USER, content="test")]))
    assert len(collected) > 0
    assert len([e for e in collected if e.type == "text"]) > 0
    assert collected[-1].type == "done"
@pytest.mark.asyncio
async def test_agenerate_basic(provider, mock_openai_client):
    """agenerate() should return the mocked completion text."""
    messages = [Message(role=MessageRole.USER, content="test")]
    # Mock the async client
    # NOTE(review): a plain, non-awaitable Mock is installed as .create, which
    # suggests agenerate() does not await .create() directly (or wraps it) —
    # confirm against the provider implementation.
    mock_response = Mock(choices=[Mock(message=Mock(content="async response"))])
    provider.async_client.chat.completions.create = Mock(return_value=mock_response)
    result = await provider.agenerate(messages)
    assert result == "async response"
@pytest.mark.asyncio
async def test_astream_basic(provider, mock_openai_client):
    """astream() should yield text events from an async chunk iterator."""
    messages = [Message(role=MessageRole.USER, content="test")]
    # Mock async streaming
    async def async_stream():
        # Two chunks: one mid-stream text delta, one final delta with the
        # "stop" finish reason.
        yield Mock(
            choices=[
                Mock(delta=Mock(content="async", tool_calls=None), finish_reason=None)
            ]
        )
        yield Mock(
            choices=[
                Mock(delta=Mock(content=" test", tool_calls=None), finish_reason="stop")
            ]
        )
    # NOTE(review): .create is again a sync Mock returning the async generator
    # object — assumes astream() calls .create without awaiting it; verify.
    provider.async_client.chat.completions.create = Mock(return_value=async_stream())
    events = []
    async for event in provider.astream(messages):
        events.append(event)
    assert len(events) > 0
    text_events = [e for e in events if e.type == "text"]
    assert len(text_events) > 0
def test_stream_with_tool_call(provider, mock_openai_client):
    """A start/args/finish chunk sequence should surface exactly one tool_use event."""

    def build_tool(call_id, name, arguments):
        # spec + assignment so .name is a real attribute, not the mock's name.
        fn = Mock(spec=["name", "arguments"])
        fn.name = name
        fn.arguments = arguments
        call = Mock(spec=["id", "function"])
        call.id = call_id
        call.function = fn
        return call

    def build_chunk(tool_calls, finish_reason):
        delta = Mock(spec=["content", "tool_calls"])
        delta.content = None
        delta.tool_calls = tool_calls
        choice = Mock(spec=["delta", "finish_reason"])
        choice.delta = delta
        choice.finish_reason = finish_reason
        chunk = Mock(spec=["choices"])
        chunk.choices = [choice]
        return chunk

    def fake_create(*args, **kwargs):
        # Only streaming requests get the chunk sequence (None otherwise,
        # matching the original side effect).
        if kwargs.get("stream"):
            return iter(
                [
                    build_chunk([build_tool("t1", "test_tool", "")], None),
                    build_chunk([build_tool(None, None, '{"arg": "val"}')], None),
                    build_chunk(None, "tool_calls"),
                ]
            )

    provider.client.chat.completions.create.side_effect = fake_create

    emitted = list(provider.stream([Message(role=MessageRole.USER, content="test")]))
    tool_events = [e for e in emitted if e.type == "tool_use"]
    assert len(tool_events) == 1
    assert tool_events[0].data["id"] == "t1"
    assert tool_events[0].data["name"] == "test_tool"
    assert tool_events[0].data["input"] == {"arg": "val"}
def test_encode_image(provider):
    """encode_image returns a non-empty base64 string for RGB input."""
    encoded = provider.encode_image(Image.new("RGB", (10, 10), color="blue"))
    assert isinstance(encoded, str)
    assert len(encoded) > 0


def test_encode_image_rgba(provider):
    """RGBA images should be converted to RGB before encoding."""
    encoded = provider.encode_image(
        Image.new("RGBA", (10, 10), color=(255, 0, 0, 128))
    )
    assert isinstance(encoded, str)
    assert len(encoded) > 0

# View File (diff-view artifact: the user-creation CLI script begins below)
# @ -2,21 +2,41 @@
"""CLI for creating a user account — either a human (password login) or a bot
(API-key login)."""

import argparse

from memory.common.db.connection import make_session
from memory.common.db.models.users import HumanUser, BotUser

if __name__ == "__main__":
    # NOTE: the previous version registered --password twice (argparse raises
    # ArgumentError on duplicate option strings), created the user twice, and
    # printed the success message twice. Each argument is now declared once
    # and the user is created by exactly one branch.
    parser = argparse.ArgumentParser(description="Create a human or bot user")
    parser.add_argument("--email", type=str, required=True)
    parser.add_argument("--name", type=str, required=True)
    parser.add_argument(
        "--password",
        type=str,
        required=False,
        help="Password (required unless --bot is given)",
    )
    parser.add_argument("--bot", action="store_true", help="Create a bot user")
    parser.add_argument(
        "--api-key",
        type=str,
        required=False,
        help="API key for bot user (auto-generated if not provided)",
    )
    args = parser.parse_args()

    with make_session() as session:
        if args.bot:
            # Bots authenticate via API key; one is generated when not supplied.
            user = BotUser.create_with_api_key(
                name=args.name, email=args.email, api_key=args.api_key
            )
        else:
            if not args.password:
                raise ValueError("Password required for human users")
            user = HumanUser.create_with_password(
                email=args.email, password=args.password, name=args.name
            )
        session.add(user)
        session.commit()

        # Report only after the commit succeeds, so a printed API key is
        # guaranteed to be persisted.
        if args.bot:
            print(f"Bot user {args.email} created with API key: {user.api_key}")
        else:
            print(f"Human user {args.email} created")