Mirror of https://github.com/mruwnik/memory.git (synced 2025-10-22 22:56:38 +02:00)

Commit 1606348d8b: discord integration
Parent: e68671deb4
@@ -1,8 +1,8 @@
 """add_discord_models
 
-Revision ID: a8c8e8b17179
+Revision ID: 7c6169fba146
 Revises: c86079073c1d
-Create Date: 2025-10-12 22:28:27.856164
+Create Date: 2025-10-13 14:21:01.080948
 
 """
@@ -13,7 +13,7 @@ import sqlalchemy as sa
 
 
 # revision identifiers, used by Alembic.
-revision: str = "a8c8e8b17179"
+revision: str = "7c6169fba146"
 down_revision: Union[str, None] = "c86079073c1d"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
@@ -26,12 +26,6 @@ def upgrade() -> None:
         sa.Column("name", sa.Text(), nullable=False),
         sa.Column("description", sa.Text(), nullable=True),
         sa.Column("member_count", sa.Integer(), nullable=True),
-        sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
-        sa.Column(
-            "ignore_messages", sa.Boolean(), server_default="false", nullable=True
-        ),
-        sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
-        sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
         sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True),
         sa.Column(
             "created_at",
@@ -45,6 +39,17 @@ def upgrade() -> None:
             server_default=sa.text("now()"),
             nullable=True,
         ),
+        sa.Column(
+            "track_messages", sa.Boolean(), server_default="true", nullable=False
+        ),
+        sa.Column("ignore_messages", sa.Boolean(), nullable=True),
+        sa.Column(
+            "allowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
+        ),
+        sa.Column(
+            "disallowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
+        ),
+        sa.Column("summary", sa.Text(), nullable=True),
         sa.PrimaryKeyConstraint("id"),
     )
     op.create_index(
@@ -59,12 +64,6 @@ def upgrade() -> None:
         sa.Column("server_id", sa.BigInteger(), nullable=True),
         sa.Column("name", sa.Text(), nullable=False),
         sa.Column("channel_type", sa.Text(), nullable=False),
-        sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
-        sa.Column(
-            "ignore_messages", sa.Boolean(), server_default="false", nullable=True
-        ),
-        sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
-        sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
         sa.Column(
             "created_at",
             sa.DateTime(timezone=True),
@@ -77,6 +76,17 @@ def upgrade() -> None:
             server_default=sa.text("now()"),
             nullable=True,
         ),
+        sa.Column(
+            "track_messages", sa.Boolean(), server_default="true", nullable=False
+        ),
+        sa.Column("ignore_messages", sa.Boolean(), nullable=True),
+        sa.Column(
+            "allowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
+        ),
+        sa.Column(
+            "disallowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
+        ),
+        sa.Column("summary", sa.Text(), nullable=True),
         sa.ForeignKeyConstraint(
             ["server_id"],
             ["discord_servers.id"],
@@ -92,12 +102,6 @@ def upgrade() -> None:
         sa.Column("username", sa.Text(), nullable=False),
         sa.Column("display_name", sa.Text(), nullable=True),
         sa.Column("system_user_id", sa.Integer(), nullable=True),
-        sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
-        sa.Column(
-            "ignore_messages", sa.Boolean(), server_default="false", nullable=True
-        ),
-        sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
-        sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
         sa.Column(
             "created_at",
             sa.DateTime(timezone=True),
@@ -110,6 +114,17 @@ def upgrade() -> None:
             server_default=sa.text("now()"),
             nullable=True,
         ),
+        sa.Column(
+            "track_messages", sa.Boolean(), server_default="true", nullable=False
+        ),
+        sa.Column("ignore_messages", sa.Boolean(), nullable=True),
+        sa.Column(
+            "allowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
+        ),
+        sa.Column(
+            "disallowed_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
+        ),
+        sa.Column("summary", sa.Text(), nullable=True),
         sa.ForeignKeyConstraint(
             ["system_user_id"],
             ["users.id"],
db/migrations/versions/20251020_014858_seperate_user__models.py (new file, 145 lines)
@@ -0,0 +1,145 @@
+"""seperate_user__models
+
+Revision ID: 35a2c1b610b6
+Revises: 7c6169fba146
+Create Date: 2025-10-20 01:48:58.537881
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "35a2c1b610b6"
+down_revision: Union[str, None] = "7c6169fba146"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    op.add_column(
+        "discord_message", sa.Column("from_id", sa.BigInteger(), nullable=False)
+    )
+    op.add_column(
+        "discord_message", sa.Column("recipient_id", sa.BigInteger(), nullable=False)
+    )
+    op.drop_index("discord_message_user_idx", table_name="discord_message")
+    op.create_index(
+        "discord_message_from_idx", "discord_message", ["from_id"], unique=False
+    )
+    op.create_index(
+        "discord_message_recipient_idx",
+        "discord_message",
+        ["recipient_id"],
+        unique=False,
+    )
+    op.drop_constraint(
+        "discord_message_discord_user_id_fkey", "discord_message", type_="foreignkey"
+    )
+    op.create_foreign_key(
+        "discord_message_from_id_fkey",
+        "discord_message",
+        "discord_users",
+        ["from_id"],
+        ["id"],
+    )
+    op.create_foreign_key(
+        "discord_message_recipient_id_fkey",
+        "discord_message",
+        "discord_users",
+        ["recipient_id"],
+        ["id"],
+    )
+    op.drop_column("discord_message", "discord_user_id")
+    op.add_column(
+        "scheduled_llm_calls",
+        sa.Column("discord_channel_id", sa.BigInteger(), nullable=True),
+    )
+    op.add_column(
+        "scheduled_llm_calls",
+        sa.Column("discord_user_id", sa.BigInteger(), nullable=True),
+    )
+    op.create_foreign_key(
+        "scheduled_llm_calls_discord_user_id_fkey",
+        "scheduled_llm_calls",
+        "discord_users",
+        ["discord_user_id"],
+        ["id"],
+    )
+    op.create_foreign_key(
+        "scheduled_llm_calls_discord_channel_id_fkey",
+        "scheduled_llm_calls",
+        "discord_channels",
+        ["discord_channel_id"],
+        ["id"],
+    )
+    op.drop_column("scheduled_llm_calls", "discord_user")
+    op.drop_column("scheduled_llm_calls", "discord_channel")
+    op.add_column(
+        "users",
+        sa.Column("user_type", sa.String(), nullable=False, server_default="human"),
+    )
+    op.add_column("users", sa.Column("api_key", sa.String(), nullable=True))
+    op.alter_column("users", "password_hash", existing_type=sa.VARCHAR(), nullable=True)
+    op.create_unique_constraint("users_api_key_key", "users", ["api_key"])
+    op.drop_column("users", "discord_user_id")
+
+
+def downgrade() -> None:
+    op.add_column(
+        "users",
+        sa.Column("discord_user_id", sa.VARCHAR(), autoincrement=False, nullable=True),
+    )
+    op.drop_constraint("users_api_key_key", "users", type_="unique")
+    op.alter_column(
+        "users", "password_hash", existing_type=sa.VARCHAR(), nullable=False
+    )
+    op.drop_column("users", "api_key")
+    op.drop_column("users", "user_type")
+    op.add_column(
+        "scheduled_llm_calls",
+        sa.Column("discord_channel", sa.VARCHAR(), autoincrement=False, nullable=True),
+    )
+    op.add_column(
+        "scheduled_llm_calls",
+        sa.Column("discord_user", sa.VARCHAR(), autoincrement=False, nullable=True),
+    )
+    op.drop_constraint(
+        "scheduled_llm_calls_discord_user_id_fkey",
+        "scheduled_llm_calls",
+        type_="foreignkey",
+    )
+    op.drop_constraint(
+        "scheduled_llm_calls_discord_channel_id_fkey",
+        "scheduled_llm_calls",
+        type_="foreignkey",
+    )
+    op.drop_column("scheduled_llm_calls", "discord_user_id")
+    op.drop_column("scheduled_llm_calls", "discord_channel_id")
+    op.add_column(
+        "discord_message",
+        sa.Column("discord_user_id", sa.BIGINT(), autoincrement=False, nullable=False),
+    )
+    op.drop_constraint(
+        "discord_message_from_id_fkey", "discord_message", type_="foreignkey"
+    )
+    op.drop_constraint(
+        "discord_message_recipient_id_fkey", "discord_message", type_="foreignkey"
+    )
+    op.create_foreign_key(
+        "discord_message_discord_user_id_fkey",
+        "discord_message",
+        "discord_users",
+        ["discord_user_id"],
+        ["id"],
+    )
+    op.drop_index("discord_message_recipient_idx", table_name="discord_message")
+    op.drop_index("discord_message_from_idx", table_name="discord_message")
+    op.create_index(
+        "discord_message_user_idx", "discord_message", ["discord_user_id"], unique=False
+    )
+    op.drop_column("discord_message", "recipient_id")
+    op.drop_column("discord_message", "from_id")
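Note (not part of the commit): a minimal sketch of applying these revisions through Alembic's Python API instead of the CLI (`alembic upgrade head`). It assumes an alembic.ini at the repo root whose sqlalchemy.url points at the memory database.

# Sketch only: apply all pending migrations, including 7c6169fba146 and 35a2c1b610b6.
from alembic import command
from alembic.config import Config

def upgrade_to_head(ini_path: str = "alembic.ini") -> None:
    # Assumes alembic.ini configures the script location and database URL.
    cfg = Config(ini_path)
    command.upgrade(cfg, "head")

if __name__ == "__main__":
    upgrade_to_head()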
@@ -50,6 +50,8 @@ x-worker-base: &worker-base
     OPENAI_API_KEY_FILE: /run/secrets/openai_key
     ANTHROPIC_API_KEY_FILE: /run/secrets/anthropic_key
     VOYAGE_API_KEY: ${VOYAGE_API_KEY}
+    DISCORD_COLLECTOR_SERVER_URL: ingest-hub
+    DISCORD_COLLECTOR_PORT: 8003
   secrets: [ postgres_password, openai_key, anthropic_key, ssh_private_key, ssh_public_key, ssh_known_hosts ]
   read_only: true
   tmpfs:
@@ -183,7 +185,6 @@ services:
       dockerfile: docker/ingest_hub/Dockerfile
     environment:
       <<: *worker-env
-      DISCORD_API_PORT: 8000
       DISCORD_BOT_TOKEN: ${DISCORD_BOT_TOKEN}
       DISCORD_NOTIFICATIONS_ENABLED: true
       DISCORD_COLLECTOR_ENABLED: true
@@ -25,7 +25,7 @@ from memory.api.MCP.oauth_provider import (
 from memory.common import settings
 from memory.common.db.connection import make_session
 from memory.common.db.models import OAuthState, UserSession
-from memory.common.db.models.users import User
+from memory.common.db.models.users import HumanUser
 
 logger = logging.getLogger(__name__)
 
@@ -126,7 +126,11 @@ async def handle_login(request: Request):
         key: value for key, value in form.items() if key not in ["email", "password"]
     }
     with make_session() as session:
-        user = session.query(User).filter(User.email == form.get("email")).first()
+        user = (
+            session.query(HumanUser)
+            .filter(HumanUser.email == form.get("email"))
+            .first()
+        )
         if not user or not user.is_valid_password(str(form.get("password", ""))):
             logger.warning("Login failed - invalid credentials")
             return login_form(request, oauth_params, "Invalid email or password")
@@ -144,11 +148,7 @@ def get_current_user() -> dict:
         return {"authenticated": False}
 
     with make_session() as session:
-        user_session = (
-            session.query(UserSession)
-            .filter(UserSession.id == access_token.token)
-            .first()
-        )
+        user_session = session.query(UserSession).get(access_token.token)
 
         if user_session and user_session.user:
             user_info = user_session.user.serialize()
@@ -21,6 +21,7 @@ from memory.common.db.models.users import (
     OAuthRefreshToken,
     OAuthState,
     User,
+    BotUser,
     UserSession,
 )
 from memory.common.db.models.users import (
@@ -92,7 +93,7 @@ def create_oauth_token(
     """Create an OAuth token response."""
     return OAuthToken(
         access_token=access_token,
-        token_type="bearer",
+        token_type="Bearer",
         expires_in=ACCESS_TOKEN_LIFETIME,
         refresh_token=refresh_token,
         scope=" ".join(scopes),
@@ -310,26 +311,37 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
         return token
 
     async def load_access_token(self, token: str) -> Optional[AccessToken]:
-        """Load and validate an access token."""
+        """Load and validate an access token (or bot API key)."""
         with make_session() as session:
-            now = datetime.now(timezone.utc).replace(
-                tzinfo=None
-            )  # Make naive for DB comparison
-
-            # Query for active (non-expired) session
+            # Try as OAuth access token first
             user_session = session.query(UserSession).get(token)
-            if not user_session:
-                return None
-
-            if user_session.expires_at < now:
-                return None
-
-            return AccessToken(
-                token=token,
-                client_id=user_session.oauth_state.client_id,
-                scopes=user_session.oauth_state.scopes,
-                expires_at=int(user_session.expires_at.timestamp()),
-            )
+            if user_session:
+                now = datetime.now(timezone.utc).replace(
+                    tzinfo=None
+                )  # Make naive for DB comparison
+
+                if user_session.expires_at < now:
+                    return None
+
+                return AccessToken(
+                    token=token,
+                    client_id=user_session.oauth_state.client_id,
+                    scopes=user_session.oauth_state.scopes,
+                    expires_at=int(user_session.expires_at.timestamp()),
+                )
+
+            # Try as bot API key
+            bot = session.query(User).filter(User.api_key == token).first()
+            if bot:
+                logger.info(f"Bot {bot.name} (id={bot.id}) authenticated via API key")
+                return AccessToken(
+                    token=token,
+                    client_id=cast(str, bot.name or bot.email),
+                    scopes=["read", "write"],  # Bots get full access
+                    expires_at=2147483647,  # Far future (2038)
+                )
+
+            return None
 
     async def load_refresh_token(
         self, client: OAuthClientInformationFull, refresh_token: str
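Note (not part of the commit): a rough usage sketch of the new dual lookup in load_access_token, assuming a SimpleOAuthProvider instance is available and a BotUser row exists whose api_key matches the supplied token.

import asyncio

async def describe_token(provider, token: str) -> None:
    # provider: a SimpleOAuthProvider instance (assumed to be constructed elsewhere).
    access = await provider.load_access_token(token)
    if access is None:
        print("neither a live session token nor a known bot API key")
    else:
        # Bot API keys come back with scopes ["read", "write"] and a far-future expiry.
        print(access.client_id, access.scopes, access.expires_at)

# asyncio.run(describe_token(provider, "bot_<hex>"))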
@@ -9,7 +9,9 @@ from typing import Any
 from memory.api.MCP.base import get_current_user
 from memory.common.db.connection import make_session
 from memory.common.db.models import ScheduledLLMCall
+from memory.common.db.models.discord import DiscordChannel, DiscordUser
 from memory.api.MCP.base import mcp
+from memory.discord.schedule import schedule_discord_message
 
 logger = logging.getLogger(__name__)
 
@@ -17,7 +19,7 @@ logger = logging.getLogger(__name__)
 @mcp.tool()
 async def schedule_message(
     scheduled_time: str,
-    message: str | None = None,
+    message: str,
     model: str | None = None,
     topic: str | None = None,
     discord_channel: str | None = None,
@@ -56,7 +58,8 @@ async def schedule_message(
     if not user_id:
         raise ValueError("User not found")
 
-    discord_user = current_user.get("user", {}).get("discord_user_id")
+    discord_users = current_user.get("user", {}).get("discord_users")
+    discord_user = discord_users and next(iter(discord_users.keys()), None)
     if not discord_user and not discord_channel:
         raise ValueError("Either discord_user or discord_channel must be provided")
 
@@ -69,27 +72,20 @@ async def schedule_message(
     except ValueError:
         raise ValueError("Invalid datetime format")
 
-    # Validate that the scheduled time is in the future
-    # Compare with naive datetime since we store naive in the database
-    current_time_naive = datetime.now(timezone.utc).replace(tzinfo=None)
-    if scheduled_dt <= current_time_naive:
-        raise ValueError("Scheduled time must be in the future")
-
     with make_session() as session:
-        # Create the scheduled call
-        scheduled_call = ScheduledLLMCall(
-            user_id=user_id,
+        scheduled_call = schedule_discord_message(
+            session=session,
             scheduled_time=scheduled_dt,
             message=message,
-            topic=topic,
+            user_id=current_user.get("user", {}).get("user_id"),
             model=model,
-            system_prompt=system_prompt,
+            topic=topic,
             discord_channel=discord_channel,
             discord_user=discord_user,
-            data=metadata or {},
+            system_prompt=system_prompt,
+            metadata=metadata,
         )
 
-        session.add(scheduled_call)
         session.commit()
 
         return {
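Note (not part of the commit): a sketch of calling schedule_discord_message directly with the same keyword arguments the rewritten tool forwards. The function itself lives in memory.discord.schedule and its body is not shown in this diff, so the argument names are taken from the call above and the return value is assumed to be a ScheduledLLMCall.

from datetime import datetime, timezone

from memory.common.db.connection import make_session
from memory.discord.schedule import schedule_discord_message

def schedule_reminder(user_id: int, discord_user: str, text: str) -> None:
    # Store a naive UTC timestamp, matching how the tool handles scheduled_time.
    when = datetime(2025, 10, 21, 9, 0, tzinfo=timezone.utc).replace(tzinfo=None)
    with make_session() as session:
        call = schedule_discord_message(
            session=session,
            scheduled_time=when,
            message=text,
            user_id=user_id,
            model=None,
            topic="reminder",
            discord_channel=None,
            discord_user=discord_user,
            system_prompt=None,
            metadata=None,
        )
        session.commit()
        print(call.serialize())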
@@ -7,7 +7,7 @@ from memory.common import settings
 from sqlalchemy.orm import Session as DBSession, scoped_session
 
 from memory.common.db.connection import get_session, make_session
-from memory.common.db.models.users import User, UserSession
+from memory.common.db.models.users import User, HumanUser, BotUser, UserSession
 
 logger = logging.getLogger(__name__)
 
@@ -91,14 +91,14 @@ def get_current_user(request: Request, db: DBSession = Depends(get_session)) ->
     return user
 
 
-def create_user(email: str, password: str, name: str, db: DBSession) -> User:
-    """Create a new user"""
+def create_user(email: str, password: str, name: str, db: DBSession) -> HumanUser:
+    """Create a new human user"""
     # Check if user already exists
     existing_user = db.query(User).filter(User.email == email).first()
     if existing_user:
         raise HTTPException(status_code=400, detail="User already exists")
 
-    user = User.create_with_password(email, name, password)
+    user = HumanUser.create_with_password(email, name, password)
     db.add(user)
     db.commit()
     db.refresh(user)
@@ -106,14 +106,19 @@ def create_user(email: str, password: str, name: str, db: DBSession) -> User:
     return user
 
 
-def authenticate_user(email: str, password: str, db: DBSession) -> User | None:
-    """Authenticate a user by email and password"""
-    user = db.query(User).filter(User.email == email).first()
+def authenticate_user(email: str, password: str, db: DBSession) -> HumanUser | None:
+    """Authenticate a human user by email and password"""
+    user = db.query(HumanUser).filter(HumanUser.email == email).first()
     if user and user.is_valid_password(password):
         return user
     return None
 
 
+def authenticate_bot(api_key: str, db: DBSession) -> BotUser | None:
+    """Authenticate a bot by API key"""
+    return db.query(BotUser).filter(BotUser.api_key == api_key).first()
+
+
 @router.api_route("/logout", methods=["GET", "POST"])
 def logout(request: Request, db: DBSession = Depends(get_session)):
     """Logout and clear session"""
@@ -30,6 +30,11 @@ from memory.common.db.models.source_items import (
     NotePayload,
     ForumPostPayload,
 )
+from memory.common.db.models.discord import (
+    DiscordServer,
+    DiscordChannel,
+    DiscordUser,
+)
 from memory.common.db.models.observations import (
     ObservationContradiction,
     ReactionPattern,
@@ -41,12 +46,12 @@ from memory.common.db.models.sources import (
     Book,
     ArticleFeed,
     EmailAccount,
-    DiscordServer,
-    DiscordChannel,
-    DiscordUser,
 )
 from memory.common.db.models.users import (
     User,
+    HumanUser,
+    BotUser,
+    DiscordBotUser,
     UserSession,
     OAuthClientInformation,
     OAuthState,
@@ -103,6 +108,9 @@ __all__ = [
     "DiscordUser",
     # Users
     "User",
+    "HumanUser",
+    "BotUser",
+    "DiscordBotUser",
     "UserSession",
     "OAuthClientInformation",
     "OAuthState",
src/memory/common/db/models/discord.py (new file, 121 lines)
@@ -0,0 +1,121 @@
+"""
+Database models for the Discord system.
+"""
+
+import textwrap
+
+from sqlalchemy import (
+    ARRAY,
+    BigInteger,
+    Boolean,
+    Column,
+    DateTime,
+    ForeignKey,
+    Index,
+    Integer,
+    Text,
+    func,
+)
+from sqlalchemy.orm import relationship
+
+from memory.common.db.models.base import Base
+
+
+class MessageProcessor:
+    track_messages = Column(Boolean, nullable=False, server_default="true")
+    ignore_messages = Column(Boolean, nullable=True, default=False)
+
+    allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
+    disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
+
+    summary = Column(
+        Text,
+        nullable=True,
+        doc=textwrap.dedent(
+            """
+            A summary of this processor, made by and for AI systems.
+
+            The idea here is that AI systems can use this summary to keep notes on the given processor.
+            These should automatically be injected into the context of the messages that are processed by this processor.
+            """
+        ),
+    )
+
+    def as_xml(self) -> str:
+        return (
+            textwrap.dedent("""
+                <{type}>
+                  <name>{name}</name>
+                  <summary>{summary}</summary>
+                </{type}>
+            """)
+            .format(
+                type=self.__class__.__tablename__[8:],  # type: ignore
+                name=getattr(self, "name", None) or getattr(self, "username", None),
+                summary=self.summary,
+            )
+            .strip()
+        )
+
+
+class DiscordServer(Base, MessageProcessor):
+    """Discord server configuration and metadata"""
+
+    __tablename__ = "discord_servers"
+
+    id = Column(BigInteger, primary_key=True)  # Discord guild snowflake ID
+    name = Column(Text, nullable=False)
+    description = Column(Text)
+    member_count = Column(Integer)
+
+    # Collection settings
+    last_sync_at = Column(DateTime(timezone=True))
+    created_at = Column(DateTime(timezone=True), server_default=func.now())
+    updated_at = Column(DateTime(timezone=True), server_default=func.now())
+
+    channels = relationship(
+        "DiscordChannel", back_populates="server", cascade="all, delete-orphan"
+    )
+
+    __table_args__ = (
+        Index("discord_servers_active_idx", "track_messages", "last_sync_at"),
+    )
+
+
+class DiscordChannel(Base, MessageProcessor):
+    """Discord channel metadata and configuration"""
+
+    __tablename__ = "discord_channels"
+
+    id = Column(BigInteger, primary_key=True)  # Discord channel snowflake ID
+    server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
+    name = Column(Text, nullable=False)
+    channel_type = Column(Text, nullable=False)  # "text", "voice", "dm", "group_dm"
+
+    # Collection settings (null = inherit from server)
+    created_at = Column(DateTime(timezone=True), server_default=func.now())
+    updated_at = Column(DateTime(timezone=True), server_default=func.now())
+
+    server = relationship("DiscordServer", back_populates="channels")
+    __table_args__ = (Index("discord_channels_server_idx", "server_id"),)
+
+
+class DiscordUser(Base, MessageProcessor):
+    """Discord user metadata and preferences"""
+
+    __tablename__ = "discord_users"
+
+    id = Column(BigInteger, primary_key=True)  # Discord user snowflake ID
+    username = Column(Text, nullable=False)
+    display_name = Column(Text)
+
+    # Link to system user if registered
+    system_user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
+
+    # Basic DM settings
+    created_at = Column(DateTime(timezone=True), server_default=func.now())
+    updated_at = Column(DateTime(timezone=True), server_default=func.now())
+
+    system_user = relationship("User", back_populates="discord_users")
+
+    __table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
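Note (not part of the commit): an illustration of the snippet MessageProcessor.as_xml() injects into context, using a made-up channel. The [8:] slice strips the "discord_" prefix from the table name, so the tag becomes "channels"; the exact whitespace depends on the dedented template.

from memory.common.db.models.discord import DiscordChannel

channel = DiscordChannel(id=1, name="general", channel_type="text")
channel.summary = "Team coordination channel; prefers short answers."
print(channel.as_xml())
# Roughly:
# <channels>
#   <name>general</name>
#   <summary>Team coordination channel; prefers short answers.</summary>
# </channels>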
@@ -7,6 +7,7 @@ from sqlalchemy import (
     String,
     DateTime,
     ForeignKey,
+    BigInteger,
     JSON,
     Text,
 )
@@ -37,8 +38,10 @@ class ScheduledLLMCall(Base):
     allowed_tools = Column(JSON, nullable=True)  # List of allowed tool names
 
     # Discord configuration
-    discord_channel = Column(String, nullable=True)
-    discord_user = Column(String, nullable=True)
+    discord_channel_id = Column(
+        BigInteger, ForeignKey("discord_channels.id"), nullable=True
+    )
+    discord_user_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=True)
 
     # Execution status and results
     status = Column(
@@ -55,6 +58,8 @@ class ScheduledLLMCall(Base):
 
     # Relationships
     user = relationship("User")
+    discord_channel = relationship("DiscordChannel", foreign_keys=[discord_channel_id])
+    discord_user = relationship("DiscordUser", foreign_keys=[discord_user_id])
 
     def serialize(self) -> Dict[str, Any]:
         def print_datetime(dt: datetime | None) -> str | None:
@@ -73,8 +78,8 @@ class ScheduledLLMCall(Base):
             "message": self.message,
             "system_prompt": self.system_prompt,
             "allowed_tools": self.allowed_tools,
-            "discord_channel": self.discord_channel,
-            "discord_user": self.discord_user,
+            "discord_channel": self.discord_channel and self.discord_channel.name,
+            "discord_user": self.discord_user and self.discord_user.username,
             "status": self.status,
             "response": self.response,
             "error_message": self.error_message,
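Note (not part of the commit): a sketch of how the Discord fields now come out of ScheduledLLMCall.serialize() once the new relationships are populated; field values are illustrative.

from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall

with make_session() as session:
    call = session.query(ScheduledLLMCall).first()
    if call is not None:
        data = call.serialize()
        # "discord_channel" is now the related channel's name (or None),
        # "discord_user" the related user's username (or None).
        print(data["discord_channel"], data["discord_user"], data["status"])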
@@ -286,7 +286,8 @@ class DiscordMessage(SourceItem):
     sent_at = Column(DateTime(timezone=True), nullable=False)
     server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
     channel_id = Column(BigInteger, ForeignKey("discord_channels.id"), nullable=False)
-    discord_user_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False)
+    from_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False)
+    recipient_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False)
     message_id = Column(BigInteger, nullable=False)  # Discord message snowflake ID
 
     # Discord-specific metadata
@@ -303,11 +304,33 @@ class DiscordMessage(SourceItem):
 
     channel = relationship("DiscordChannel", foreign_keys=[channel_id])
     server = relationship("DiscordServer", foreign_keys=[server_id])
-    discord_user = relationship("DiscordUser", foreign_keys=[discord_user_id])
+    from_user = relationship("DiscordUser", foreign_keys=[from_id])
+    recipient_user = relationship("DiscordUser", foreign_keys=[recipient_id])
+
+    @property
+    def allowed_tools(self) -> list[str]:
+        return (
+            (self.channel.allowed_tools if self.channel else [])
+            + (self.from_user.allowed_tools if self.from_user else [])
+            + (self.server.allowed_tools if self.server else [])
+        )
+
+    @property
+    def disallowed_tools(self) -> list[str]:
+        return (
+            (self.channel.disallowed_tools if self.channel else [])
+            + (self.from_user.disallowed_tools if self.from_user else [])
+            + (self.server.disallowed_tools if self.server else [])
+        )
+
+    def tool_allowed(self, tool: str) -> bool:
+        return not (self.disallowed_tools and tool in self.disallowed_tools) and (
+            not self.allowed_tools or tool in self.allowed_tools
+        )
 
     @property
     def title(self) -> str:
-        return f"{self.discord_user.username}: {self.content}"
+        return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
 
     __mapper_args__ = {
         "polymorphic_identity": "discord_message",
@@ -320,7 +343,8 @@ class DiscordMessage(SourceItem):
             "server_id",
             "channel_id",
         ),
-        Index("discord_message_user_idx", "discord_user_id"),
+        Index("discord_message_from_idx", "from_id"),
+        Index("discord_message_recipient_idx", "recipient_id"),
     )
 
     def _chunk_contents(self) -> Sequence[extract.DataChunk]:
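Note (not part of the commit): the gating above combines the channel, sender, and server tool lists. A plain-list restatement of the same rule, with illustrative values:

allowed = ["search", "schedule_message"]   # channel + from_user + server lists combined
disallowed = ["delete_notes"]

def tool_allowed(tool: str) -> bool:
    # A tool passes when it is not disallowed anywhere and either no allowed
    # list is configured or the tool appears in the combined allowed list.
    return not (disallowed and tool in disallowed) and (
        not allowed or tool in allowed
    )

assert tool_allowed("search")
assert not tool_allowed("delete_notes")
assert not tool_allowed("web_fetch")  # not in the non-empty allowed list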
@@ -125,74 +125,3 @@ class EmailAccount(Base):
         Index("email_accounts_active_idx", "active", "last_sync_at"),
         Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
     )
-
-
-class MessageProcessor:
-    track_messages = Column(Boolean, nullable=False, server_default="true")
-    ignore_messages = Column(Boolean, nullable=True, default=False)
-
-    allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
-    disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
-
-
-class DiscordServer(Base, MessageProcessor):
-    """Discord server configuration and metadata"""
-
-    __tablename__ = "discord_servers"
-
-    id = Column(BigInteger, primary_key=True)  # Discord guild snowflake ID
-    name = Column(Text, nullable=False)
-    description = Column(Text)
-    member_count = Column(Integer)
-
-    # Collection settings
-    last_sync_at = Column(DateTime(timezone=True))
-    created_at = Column(DateTime(timezone=True), server_default=func.now())
-    updated_at = Column(DateTime(timezone=True), server_default=func.now())
-
-    channels = relationship(
-        "DiscordChannel", back_populates="server", cascade="all, delete-orphan"
-    )
-
-    __table_args__ = (
-        Index("discord_servers_active_idx", "track_messages", "last_sync_at"),
-    )
-
-
-class DiscordChannel(Base, MessageProcessor):
-    """Discord channel metadata and configuration"""
-
-    __tablename__ = "discord_channels"
-
-    id = Column(BigInteger, primary_key=True)  # Discord channel snowflake ID
-    server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
-    name = Column(Text, nullable=False)
-    channel_type = Column(Text, nullable=False)  # "text", "voice", "dm", "group_dm"
-
-    # Collection settings (null = inherit from server)
-    created_at = Column(DateTime(timezone=True), server_default=func.now())
-    updated_at = Column(DateTime(timezone=True), server_default=func.now())
-
-    server = relationship("DiscordServer", back_populates="channels")
-    __table_args__ = (Index("discord_channels_server_idx", "server_id"),)
-
-
-class DiscordUser(Base, MessageProcessor):
-    """Discord user metadata and preferences"""
-
-    __tablename__ = "discord_users"
-
-    id = Column(BigInteger, primary_key=True)  # Discord user snowflake ID
-    username = Column(Text, nullable=False)
-    display_name = Column(Text)
-
-    # Link to system user if registered
-    system_user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
-
-    # Basic DM settings
-    created_at = Column(DateTime(timezone=True), server_default=func.now())
-    updated_at = Column(DateTime(timezone=True), server_default=func.now())
-
-    system_user = relationship("User", back_populates="discord_users")
-
-    __table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
@@ -2,7 +2,6 @@ import hashlib
 import secrets
 from typing import cast
 import uuid
-from datetime import datetime, timezone
 from sqlalchemy.orm import Session
 from memory.common.db.models.base import Base
 from sqlalchemy import (
@@ -14,6 +13,7 @@ from sqlalchemy import (
     Boolean,
     ARRAY,
     Numeric,
+    CheckConstraint,
 )
 from sqlalchemy.sql import func
 from sqlalchemy.orm import relationship
@@ -36,12 +36,21 @@ def verify_password(password: str, password_hash: str) -> bool:
 
 class User(Base):
     __tablename__ = "users"
+    __table_args__ = (
+        CheckConstraint(
+            "password_hash IS NOT NULL OR api_key IS NOT NULL",
+            name="user_has_auth_method",
+        ),
+    )
 
     id = Column(Integer, primary_key=True)
     name = Column(String, nullable=False)
     email = Column(String, nullable=False, unique=True)
-    password_hash = Column(String, nullable=False)
-    discord_user_id = Column(String, nullable=True)
+    user_type = Column(String, nullable=False)  # Discriminator column
+
+    # Make these nullable since subclasses will use them selectively
+    password_hash = Column(String, nullable=True)
+    api_key = Column(String, nullable=True, unique=True)
 
     # Relationship to sessions
     sessions = relationship(
@@ -52,22 +61,86 @@ class User(Base):
     )
     discord_users = relationship("DiscordUser", back_populates="system_user")
 
+    __mapper_args__ = {
+        "polymorphic_on": user_type,
+        "polymorphic_identity": "user",
+    }
+
     def serialize(self) -> dict:
         return {
             "user_id": self.id,
             "name": self.name,
             "email": self.email,
-            "discord_user_id": self.discord_user_id,
+            "user_type": self.user_type,
+            "discord_users": {
+                discord_user.id: discord_user.username
+                for discord_user in self.discord_users
+            },
         }
 
 
+class HumanUser(User):
+    """Human user with password authentication"""
+
+    __mapper_args__ = {
+        "polymorphic_identity": "human",
+    }
+
     def is_valid_password(self, password: str) -> bool:
         """Check if the provided password is valid for this user"""
         return verify_password(password, cast(str, self.password_hash))
 
     @classmethod
-    def create_with_password(cls, email: str, name: str, password: str) -> "User":
-        """Create a new user with a hashed password"""
-        return cls(email=email, name=name, password_hash=hash_password(password))
+    def create_with_password(cls, email: str, name: str, password: str) -> "HumanUser":
+        """Create a new human user with a hashed password"""
+        return cls(
+            email=email,
+            name=name,
+            password_hash=hash_password(password),
+            user_type="human",
+        )
+
+
+class BotUser(User):
+    """Bot user with API key authentication"""
+
+    __mapper_args__ = {
+        "polymorphic_identity": "bot",
+    }
+
+    @classmethod
+    def create_with_api_key(
+        cls, name: str, email: str, api_key: str | None = None
+    ) -> "BotUser":
+        """Create a new bot user with an API key"""
+        if api_key is None:
+            api_key = f"bot_{secrets.token_hex(32)}"
+        return cls(
+            name=name,
+            email=email,
+            api_key=api_key,
+            user_type=cls.__mapper_args__["polymorphic_identity"],
+        )
+
+
+class DiscordBotUser(BotUser):
+    """Bot user with API key authentication"""
+
+    __mapper_args__ = {
+        "polymorphic_identity": "discord_bot",
+    }
+
+    @classmethod
+    def create_with_api_key(
+        cls,
+        discord_users: list,
+        name: str,
+        email: str,
+        api_key: str | None = None,
+    ) -> "DiscordBotUser":
+        bot = super().create_with_api_key(name, email, api_key)
+        bot.discord_users = discord_users
+        return bot
+
+
 class UserSession(Base):
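Note (not part of the commit): a usage sketch for the new polymorphic user classes; names, emails, and the password are made up.

from memory.common.db.connection import make_session
from memory.common.db.models.users import HumanUser, BotUser

with make_session() as session:
    human = HumanUser.create_with_password("ada@example.com", "Ada", "s3cret")
    bot = BotUser.create_with_api_key(name="ingest-bot", email="bot@example.com")
    session.add_all([human, bot])
    session.commit()

    # HumanUser rows authenticate with a password hash...
    assert human.is_valid_password("s3cret")
    # ...while BotUser rows get a generated "bot_..." API key and user_type="bot".
    print(bot.api_key, bot.user_type)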
@@ -25,7 +25,7 @@ def send_dm(user_identifier: str, message: str) -> bool:
     try:
         response = requests.post(
             f"{get_api_url()}/send_dm",
-            json={"user_identifier": user_identifier, "message": message},
+            json={"user": user_identifier, "message": message},
             timeout=10,
         )
         response.raise_for_status()
@@ -37,6 +37,24 @@ def send_dm(user_identifier: str, message: str) -> bool:
         return False
 
 
+def send_to_channel(channel_name: str, message: str) -> bool:
+    """Send a DM via the Discord collector API"""
+    try:
+        response = requests.post(
+            f"{get_api_url()}/send_channel",
+            json={"channel_name": channel_name, "message": message},
+            timeout=10,
+        )
+        response.raise_for_status()
+        result = response.json()
+        print("Result", result)
+        return result.get("success", False)
+
+    except requests.RequestException as e:
+        logger.error(f"Failed to send to channel {channel_name}: {e}")
+        return False
+
+
 def broadcast_message(channel_name: str, message: str) -> bool:
     """Send a message to a channel via the Discord collector API"""
     try:
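Note (not part of the commit): a sketch of calling the collector client helpers; the module path is assumed, since this hunk does not show the file name, and the channel and user identifiers are illustrative.

from memory.discord.comms import send_dm, send_to_channel  # assumed module path

if send_to_channel("general", "Ingest finished"):
    print("delivered")
else:
    # Fall back to a direct message if the channel send fails.
    send_dm("admin", "Channel send failed")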
@@ -81,3 +81,28 @@ def truncate(content: str, target_tokens: int) -> str:
     if len(content) > target_chars:
         return content[:target_chars].rsplit(" ", 1)[0] + "..."
     return content
+
+
+# bla = 1
+# from memory.common.llms import *
+# from memory.common.llms.tools.discord import make_discord_tools
+# from memory.common.db.connection import make_session
+# from memory.common.db.models import *
+
+# model = "anthropic/claude-sonnet-4-5"
+# provider = create_provider(model=model)
+# with make_session() as session:
+#     bot = session.query(DiscordBotUser).first()
+#     server = session.query(DiscordServer).first()
+#     channel = server.channels[0]
+#     tools = make_discord_tools(bot, None, channel, model)
+
+# def demo(msg: str):
+#     messages = [
+#         Message(
+#             role=MessageRole.USER,
+#             content=msg,
+#         )
+#     ]
+#     for m in provider.stream_with_tools(messages, tools):
+#         print(m)
@@ -333,7 +333,6 @@ class AnthropicProvider(BaseLLMProvider):
         settings = settings or LLMSettings()
         kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
 
-        print(kwargs)
         try:
             with self.client.messages.stream(**kwargs) as stream:
                 current_tool_use: dict[str, Any] | None = None
@@ -599,6 +599,9 @@ class BaseLLMProvider(ABC):
             tool_calls=tool_calls or None,
         )
 
+    def as_messages(self, messages) -> list[Message]:
+        return [Message.user(text=msg) for msg in messages]
+
 
 def create_provider(
     model: str | None = None,
@@ -150,7 +150,7 @@ class OpenAIProvider(BaseLLMProvider):
 
     def _convert_tools(
         self, tools: list[ToolDefinition] | None
-    ) -> Optional[list[dict[str, Any]]]:
+    ) -> list[dict[str, Any]] | None:
         """
         Convert our tool definitions to OpenAI format.
 
@@ -179,7 +179,7 @@ class OpenAIProvider(BaseLLMProvider):
         self,
         messages: list[Message],
         system_prompt: str | None,
-        tools: Optional[list[ToolDefinition]],
+        tools: list[ToolDefinition] | None,
         settings: LLMSettings,
         stream: bool = False,
     ) -> dict[str, Any]:
@@ -270,7 +270,7 @@ class OpenAIProvider(BaseLLMProvider):
         self,
         chunk: Any,
         current_tool_call: dict[str, Any] | None,
-    ) -> tuple[list[StreamEvent], Optional[dict[str, Any]]]:
+    ) -> tuple[list[StreamEvent], dict[str, Any] | None]:
         """
         Handle a single streaming chunk and return events and updated tool state.
 
@@ -325,9 +325,9 @@ class OpenAIProvider(BaseLLMProvider):
     def generate(
         self,
         messages: list[Message],
-        system_prompt: Optional[str] = None,
-        tools: Optional[list[ToolDefinition]] = None,
-        settings: Optional[LLMSettings] = None,
+        system_prompt: str | None = None,
+        tools: list[ToolDefinition] | None = None,
+        settings: LLMSettings | None = None,
     ) -> str:
         """Generate a non-streaming response."""
         settings = settings or LLMSettings()
@@ -374,9 +374,9 @@ class OpenAIProvider(BaseLLMProvider):
     async def agenerate(
         self,
         messages: list[Message],
-        system_prompt: Optional[str] = None,
-        tools: Optional[list[ToolDefinition]] = None,
-        settings: Optional[LLMSettings] = None,
+        system_prompt: str | None = None,
+        tools: list[ToolDefinition] | None = None,
+        settings: LLMSettings | None = None,
     ) -> str:
         """Generate a non-streaming response asynchronously."""
         settings = settings or LLMSettings()
@@ -394,9 +394,9 @@ class OpenAIProvider(BaseLLMProvider):
     async def astream(
         self,
         messages: list[Message],
-        system_prompt: Optional[str] = None,
-        tools: Optional[list[ToolDefinition]] = None,
-        settings: Optional[LLMSettings] = None,
+        system_prompt: str | None = None,
+        tools: list[ToolDefinition] | None = None,
+        settings: LLMSettings | None = None,
     ) -> AsyncIterator[StreamEvent]:
         """Generate a streaming response asynchronously."""
         settings = settings or LLMSettings()
@@ -406,7 +406,7 @@ class OpenAIProvider(BaseLLMProvider):
 
         try:
             stream = await self.async_client.chat.completions.create(**kwargs)
-            current_tool_call: Optional[dict[str, Any]] = None
+            current_tool_call: dict[str, Any] | None = None
 
             async for chunk in stream:
                 events, current_tool_call = self._handle_stream_chunk(
src/memory/common/llms/tools/discord.py (new file, 231 lines; listing truncated below)
@@ -0,0 +1,231 @@
+"""Discord tool for interacting with Discord."""
+
+import textwrap
+from datetime import datetime
+from typing import Literal, cast
+from memory.discord.messages import (
+    upsert_scheduled_message,
+    comm_channel_prompt,
+    previous_messages,
+)
+from sqlalchemy import BigInteger
+from memory.common.db.connection import make_session
+from memory.common.db.models import (
+    DiscordServer,
+    DiscordChannel,
+    DiscordUser,
+    BotUser,
+)
+from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler
+
+
+UpdateSummaryType = Literal["server", "channel", "user"]
+
+
+def handle_update_summary_call(
+    type: UpdateSummaryType, item_id: BigInteger
+) -> ToolHandler:
+    models = {
+        "server": DiscordServer,
+        "channel": DiscordChannel,
+        "user": DiscordUser,
+    }
+
+    def handler(input: ToolInput = None) -> str:
+        if isinstance(input, dict):
+            summary = input.get("summary") or str(input)
+        else:
+            summary = str(input)
+
+        try:
+            with make_session() as session:
+                model = models[type]
+                model = session.get(model, item_id)
+                model.summary = summary  # type: ignore
+                session.commit()
+        except Exception as e:
+            return f"Error updating summary: {e}"
+        return "Updated summary"
+
+    handler.__doc__ = textwrap.dedent("""
+        Handle a {type} summary update tool call.
+
+        Args:
+            summary: The new summary of the Discord {type}
+
+        Returns:
+            Response string
+    """).format(type=type)
+    return handler
+
+
+def make_summary_tool(type: UpdateSummaryType, item_id: BigInteger) -> ToolDefinition:
+    return ToolDefinition(
+        name=f"update_{type}_summary",
+        description=textwrap.dedent("""
+            Use this to update the summary of this Discord {type} that is added to your context.
+
+            This will overwrite the previous summary.
+        """).format(type=type),
+        input_schema={
+            "type": "object",
+            "properties": {
+                "summary": {
+                    "type": "string",
+                    "description": f"The new summary of the Discord {type}",
+                }
+            },
+            "required": [],
+        },
+        function=handle_update_summary_call(type, item_id),
+    )
+
+
+def schedule_message(
+    user_id: int,
+    user: int | None,
+    channel: int | None,
+    model: str,
+    message: str,
+    date_time: datetime,
+) -> str:
+    with make_session() as session:
+        call = upsert_scheduled_message(
+            session,
+            scheduled_time=date_time,
+            message=message,
+            user_id=user_id,
+            model=model,
+            discord_user=user,
+            discord_channel=channel,
+            system_prompt=comm_channel_prompt(session, user, channel),
+        )
+
+        session.commit()
+        return cast(str, call.id)
+
+
+def make_message_scheduler(
+    bot: BotUser, user: int | None, channel: int | None, model: str
+) -> ToolDefinition:
+    bot_id = cast(int, bot.id)
+    if user:
+        channel_type = "from your chat with this user"
+    elif channel:
+        channel_type = "in this channel"
+    else:
+        raise ValueError("Either user or channel must be provided")
+
+    def handler(input: ToolInput) -> str:
+        if not isinstance(input, dict):
+            raise ValueError("Input must be a dictionary")
+
+        try:
+            time = datetime.fromisoformat(input["date_time"])
+        except ValueError:
+            raise ValueError("Invalid date time format")
+        except KeyError:
+            raise ValueError("Date time is required")
+
+        return schedule_message(bot_id, user, channel, model, input["message"], time)
+
+    return ToolDefinition(
+        name="schedule_message",
+        description=textwrap.dedent("""
+            Use this to schedule a message to be sent to yourself.
+
+            At the specified date and time, your message will be sent to you, along with the most
+            recent messages {channel_type}.
+
+            Normally you will be called with any incoming messages. But sometimes you might want to be
+            able to trigger a call to yourself at a specific time, rather than waiting for the next call.
+            This tool allows you to do that.
+            So for example, if you were chatting with a Discord user, and you ask a question which needs to
+            be answered right away, you can use this tool to schedule a check in 5 minutes time, to remind
+            the user to answer the question.
+        """).format(channel_type=channel_type),
+        input_schema={
+            "type": "object",
+            "properties": {
+                "message": {
+                    "type": "string",
+                    "description": "The message to send",
+                },
+                "date_time": {
+                    "type": "string",
+                    "description": "The date and time to send the message in ISO format (e.g., 2025-01-01T00:00:00Z)",
+                },
+            },
+        },
+        function=handler,
+    )
+
+
+def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefinition:
+    if user:
+        channel_type = "from your chat with this user"
+    elif channel:
+        channel_type = "in this channel"
+    else:
+        raise ValueError("Either user or channel must be provided")
+
+    def handler(input: ToolInput) -> str:
+        if not isinstance(input, dict):
+            raise ValueError("Input must be a dictionary")
+        try:
+            max_messages = int(input.get("max_messages") or 10)
|
||||||
|
offset = int(input.get("offset") or 0)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("Max messages and offset must be integers")
|
||||||
|
|
||||||
|
if max_messages <= 0:
|
||||||
|
raise ValueError("Max messages must be greater than 0")
|
||||||
|
if offset < 0:
|
||||||
|
raise ValueError("Offset must be greater than or equal to 0")
|
||||||
|
|
||||||
|
with make_session() as session:
|
||||||
|
messages = previous_messages(session, user, channel, max_messages, offset)
|
||||||
|
return "\n\n".join([msg.title for msg in messages])
|
||||||
|
|
||||||
|
return ToolDefinition(
|
||||||
|
name="previous_messages",
|
||||||
|
description=f"Get the previous N messages {channel_type}.",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"max_messages": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The maximum number of messages to return",
|
||||||
|
"default": 10,
|
||||||
|
},
|
||||||
|
"offset": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The number of messages to offset the result by",
|
||||||
|
"default": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
function=handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_discord_tools(
|
||||||
|
bot: BotUser,
|
||||||
|
author: DiscordUser | None,
|
||||||
|
channel: DiscordChannel | None,
|
||||||
|
model: str,
|
||||||
|
) -> dict[str, ToolDefinition]:
|
||||||
|
author_id = author and author.id
|
||||||
|
channel_id = channel and channel.id
|
||||||
|
tools = [
|
||||||
|
make_message_scheduler(bot, author_id, channel_id, model),
|
||||||
|
make_prev_messages_tool(author_id, channel_id),
|
||||||
|
make_summary_tool("channel", channel_id),
|
||||||
|
]
|
||||||
|
if author:
|
||||||
|
tools += [make_summary_tool("user", author_id)]
|
||||||
|
if channel and channel.server:
|
||||||
|
tools += [
|
||||||
|
make_summary_tool("server", cast(BigInteger, channel.server_id)),
|
||||||
|
]
|
||||||
|
return {tool.name: tool for tool in tools}
|
@ -133,7 +133,7 @@ if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):
|
|||||||
ANTHROPIC_API_KEY = pathlib.Path(anthropic_key_file).read_text().strip()
|
ANTHROPIC_API_KEY = pathlib.Path(anthropic_key_file).read_text().strip()
|
||||||
else:
|
else:
|
||||||
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
|
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
|
||||||
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
|
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-haiku-4-5")
|
||||||
RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
|
RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
|
||||||
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
|
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
|
||||||
|
|
||||||
@ -173,11 +173,14 @@ DISCORD_NOTIFICATIONS_ENABLED = bool(
|
|||||||
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
|
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
|
||||||
)
|
)
|
||||||
DISCORD_PROCESS_MESSAGES = boolean_env("DISCORD_PROCESS_MESSAGES", True)
|
DISCORD_PROCESS_MESSAGES = boolean_env("DISCORD_PROCESS_MESSAGES", True)
|
||||||
|
DISCORD_MODEL = os.getenv("DISCORD_MODEL", "anthropic/claude-sonnet-4-5")
|
||||||
|
DISCORD_MAX_TOOL_CALLS = int(os.getenv("DISCORD_MAX_TOOL_CALLS", 10))
|
||||||
|
|
||||||
|
|
||||||
# Discord collector settings
|
# Discord collector settings
|
||||||
DISCORD_COLLECTOR_ENABLED = boolean_env("DISCORD_COLLECTOR_ENABLED", True)
|
DISCORD_COLLECTOR_ENABLED = boolean_env("DISCORD_COLLECTOR_ENABLED", True)
|
||||||
DISCORD_COLLECT_DMS = boolean_env("DISCORD_COLLECT_DMS", True)
|
DISCORD_COLLECT_DMS = boolean_env("DISCORD_COLLECT_DMS", True)
|
||||||
DISCORD_COLLECT_BOTS = boolean_env("DISCORD_COLLECT_BOTS", True)
|
DISCORD_COLLECT_BOTS = boolean_env("DISCORD_COLLECT_BOTS", True)
|
||||||
DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8000))
|
DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8003))
|
||||||
DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "127.0.0.1")
|
DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "0.0.0.0")
|
||||||
DISCORD_CONTEXT_WINDOW = int(os.getenv("DISCORD_CONTEXT_WINDOW", 10))
|
DISCORD_CONTEXT_WINDOW = int(os.getenv("DISCORD_CONTEXT_WINDOW", 10))
|
||||||
|
@ -13,7 +13,7 @@ from sqlalchemy.orm import Session, scoped_session
|
|||||||
|
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models.sources import (
|
from memory.common.db.models import (
|
||||||
DiscordServer,
|
DiscordServer,
|
||||||
DiscordChannel,
|
DiscordChannel,
|
||||||
DiscordUser,
|
DiscordUser,
|
||||||
@ -227,6 +227,7 @@ class MessageCollector(commands.Bot):
|
|||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
channel_id=message.channel.id,
|
channel_id=message.channel.id,
|
||||||
author_id=message.author.id,
|
author_id=message.author.id,
|
||||||
|
recipient_id=self.user and self.user.id,
|
||||||
server_id=message.guild.id if message.guild else None,
|
server_id=message.guild.id if message.guild else None,
|
||||||
content=message.content or "",
|
content=message.content or "",
|
||||||
sent_at=message.created_at.isoformat(),
|
sent_at=message.created_at.isoformat(),
|
||||||
|
205
src/memory/discord/messages.py
Normal file
205
src/memory/discord/messages.py
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
import logging
|
||||||
|
import textwrap
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
|
|
||||||
|
from memory.common.db.models import (
|
||||||
|
DiscordChannel,
|
||||||
|
DiscordUser,
|
||||||
|
ScheduledLLMCall,
|
||||||
|
DiscordMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DiscordEntity = DiscordChannel | DiscordUser | str | int | None
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_discord_user(
|
||||||
|
session: Session | scoped_session, entity: DiscordEntity
|
||||||
|
) -> DiscordUser | None:
|
||||||
|
if not entity:
|
||||||
|
return None
|
||||||
|
if isinstance(entity, DiscordUser):
|
||||||
|
return entity
|
||||||
|
if isinstance(entity, int):
|
||||||
|
return session.get(DiscordUser, entity)
|
||||||
|
|
||||||
|
entity = session.query(DiscordUser).filter(DiscordUser.username == entity).first()
|
||||||
|
if not entity:
|
||||||
|
entity = DiscordUser(id=entity, username=entity)
|
||||||
|
session.add(entity)
|
||||||
|
return entity
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_discord_channel(
|
||||||
|
session: Session | scoped_session, entity: DiscordEntity
|
||||||
|
) -> DiscordChannel | None:
|
||||||
|
if not entity:
|
||||||
|
return None
|
||||||
|
if isinstance(entity, DiscordChannel):
|
||||||
|
return entity
|
||||||
|
if isinstance(entity, int):
|
||||||
|
return session.get(DiscordChannel, entity)
|
||||||
|
|
||||||
|
return session.query(DiscordChannel).filter(DiscordChannel.name == entity).first()
|
||||||
|
|
||||||
|
|
||||||
|
def schedule_discord_message(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
scheduled_time: datetime,
|
||||||
|
message: str,
|
||||||
|
user_id: int,
|
||||||
|
model: str | None = None,
|
||||||
|
topic: str | None = None,
|
||||||
|
discord_user: DiscordEntity = None,
|
||||||
|
discord_channel: DiscordEntity = None,
|
||||||
|
system_prompt: str | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> ScheduledLLMCall:
|
||||||
|
discord_user = resolve_discord_user(session, discord_user)
|
||||||
|
discord_channel = resolve_discord_channel(session, discord_channel)
|
||||||
|
if not discord_user and not discord_channel:
|
||||||
|
raise ValueError("Either discord_user or discord_channel must be provided")
|
||||||
|
|
||||||
|
# Validate that the scheduled time is in the future
|
||||||
|
# Compare with naive datetime since we store naive in the database
|
||||||
|
current_time_naive = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
|
if scheduled_time.replace(tzinfo=None) <= current_time_naive:
|
||||||
|
raise ValueError("Scheduled time must be in the future")
|
||||||
|
|
||||||
|
# Create the scheduled call
|
||||||
|
scheduled_call = ScheduledLLMCall(
|
||||||
|
user_id=user_id,
|
||||||
|
scheduled_time=scheduled_time,
|
||||||
|
message=message,
|
||||||
|
topic=topic,
|
||||||
|
model=model,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
discord_channel=resolve_discord_channel(session, discord_channel),
|
||||||
|
discord_user=resolve_discord_user(session, discord_user),
|
||||||
|
data=metadata or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add(scheduled_call)
|
||||||
|
return scheduled_call
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_scheduled_message(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
scheduled_time: datetime,
|
||||||
|
message: str,
|
||||||
|
user_id: int,
|
||||||
|
model: str | None = None,
|
||||||
|
topic: str | None = None,
|
||||||
|
discord_user: DiscordEntity = None,
|
||||||
|
discord_channel: DiscordEntity = None,
|
||||||
|
system_prompt: str | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> ScheduledLLMCall:
|
||||||
|
discord_user = resolve_discord_user(session, discord_user)
|
||||||
|
discord_channel = resolve_discord_channel(session, discord_channel)
|
||||||
|
prev_call = (
|
||||||
|
session.query(ScheduledLLMCall)
|
||||||
|
.filter(
|
||||||
|
ScheduledLLMCall.user_id == user_id,
|
||||||
|
ScheduledLLMCall.model == model,
|
||||||
|
ScheduledLLMCall.discord_user_id == (discord_user and discord_user.id),
|
||||||
|
ScheduledLLMCall.discord_channel_id
|
||||||
|
== (discord_channel and discord_channel.id),
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
naive_scheduled_time = scheduled_time.replace(tzinfo=None)
|
||||||
|
print(f"naive_scheduled_time: {naive_scheduled_time}")
|
||||||
|
print(f"prev_call.scheduled_time: {prev_call and prev_call.scheduled_time}")
|
||||||
|
if prev_call and cast(datetime, prev_call.scheduled_time) > naive_scheduled_time:
|
||||||
|
prev_call.status = "cancelled" # type: ignore
|
||||||
|
|
||||||
|
return schedule_discord_message(
|
||||||
|
session,
|
||||||
|
scheduled_time,
|
||||||
|
message,
|
||||||
|
user_id=user_id,
|
||||||
|
model=model,
|
||||||
|
topic=topic,
|
||||||
|
discord_user=discord_user,
|
||||||
|
discord_channel=discord_channel,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def previous_messages(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
user_id: int | None,
|
||||||
|
channel_id: int | None,
|
||||||
|
max_messages: int = 10,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[DiscordMessage]:
|
||||||
|
messages = session.query(DiscordMessage)
|
||||||
|
if user_id:
|
||||||
|
messages = messages.filter(DiscordMessage.recipient_id == user_id)
|
||||||
|
if channel_id:
|
||||||
|
messages = messages.filter(DiscordMessage.channel_id == channel_id)
|
||||||
|
return list(
|
||||||
|
reversed(
|
||||||
|
messages.order_by(DiscordMessage.sent_at.desc())
|
||||||
|
.offset(offset)
|
||||||
|
.limit(max_messages)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def comm_channel_prompt(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
user: DiscordEntity,
|
||||||
|
channel: DiscordEntity,
|
||||||
|
max_messages: int = 10,
|
||||||
|
) -> str:
|
||||||
|
user = resolve_discord_user(session, user)
|
||||||
|
channel = resolve_discord_channel(session, channel)
|
||||||
|
|
||||||
|
messages = previous_messages(
|
||||||
|
session, user and user.id, channel and channel.id, max_messages
|
||||||
|
)
|
||||||
|
|
||||||
|
server_context = ""
|
||||||
|
if channel and channel.server:
|
||||||
|
server_context = textwrap.dedent("""
|
||||||
|
Here are your previous notes on the server:
|
||||||
|
<server_context>
|
||||||
|
{summary}
|
||||||
|
</server_context>
|
||||||
|
""").format(summary=channel.server.summary)
|
||||||
|
if channel:
|
||||||
|
server_context += textwrap.dedent("""
|
||||||
|
Here are your previous notes on the channel:
|
||||||
|
<channel_context>
|
||||||
|
{summary}
|
||||||
|
</channel_context>
|
||||||
|
""").format(summary=channel.summary)
|
||||||
|
if messages:
|
||||||
|
server_context += textwrap.dedent("""
|
||||||
|
Here are your previous notes on the users:
|
||||||
|
<user_notes>
|
||||||
|
{users}
|
||||||
|
</user_notes>
|
||||||
|
""").format(
|
||||||
|
users="\n".join({msg.from_user.as_xml() for msg in messages}),
|
||||||
|
)
|
||||||
|
|
||||||
|
return textwrap.dedent("""
|
||||||
|
You are a bot communicating on Discord.
|
||||||
|
|
||||||
|
{server_context}
|
||||||
|
|
||||||
|
Whenever something worth remembering is said, you should add a note to the appropriate context - use
|
||||||
|
this to track your understanding of the conversation and those taking part in it.
|
||||||
|
|
||||||
|
You will be given the last {max_messages} messages in the conversation.
|
||||||
|
Please react to them appropriately. You can return an empty response if you don't have anything to say.
|
||||||
|
""").format(server_context=server_context, max_messages=max_messages)
|
@ -4,25 +4,31 @@ Celery tasks for Discord message processing.
|
|||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
|
import textwrap
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from memory.common.celery_app import app
|
from sqlalchemy import exc as sqlalchemy_exc
|
||||||
from memory.common.db.connection import make_session
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
from memory.common.db.models import DiscordMessage, DiscordUser
|
|
||||||
from memory.workers.tasks.content_processing import (
|
from memory.common import discord, settings
|
||||||
safe_task_execution,
|
|
||||||
check_content_exists,
|
|
||||||
create_task_result,
|
|
||||||
process_content_item,
|
|
||||||
)
|
|
||||||
from memory.common.celery_app import (
|
from memory.common.celery_app import (
|
||||||
ADD_DISCORD_MESSAGE,
|
ADD_DISCORD_MESSAGE,
|
||||||
EDIT_DISCORD_MESSAGE,
|
EDIT_DISCORD_MESSAGE,
|
||||||
PROCESS_DISCORD_MESSAGE,
|
PROCESS_DISCORD_MESSAGE,
|
||||||
|
app,
|
||||||
|
)
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models import DiscordMessage, DiscordUser
|
||||||
|
from memory.common.llms.base import create_provider
|
||||||
|
from memory.common.llms.tools.discord import make_discord_tools
|
||||||
|
from memory.discord.messages import comm_channel_prompt, previous_messages
|
||||||
|
from memory.workers.tasks.content_processing import (
|
||||||
|
check_content_exists,
|
||||||
|
create_task_result,
|
||||||
|
process_content_item,
|
||||||
|
safe_task_execution,
|
||||||
)
|
)
|
||||||
from memory.common import settings
|
|
||||||
from sqlalchemy.orm import Session, scoped_session
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -32,7 +38,7 @@ def get_prev(
|
|||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
prev = (
|
prev = (
|
||||||
session.query(DiscordUser.username, DiscordMessage.content)
|
session.query(DiscordUser.username, DiscordMessage.content)
|
||||||
.join(DiscordUser, DiscordMessage.discord_user_id == DiscordUser.id)
|
.join(DiscordUser, DiscordMessage.from_id == DiscordUser.id)
|
||||||
.filter(
|
.filter(
|
||||||
DiscordMessage.channel_id == channel_id,
|
DiscordMessage.channel_id == channel_id,
|
||||||
DiscordMessage.sent_at < sent_at,
|
DiscordMessage.sent_at < sent_at,
|
||||||
@ -45,20 +51,54 @@ def get_prev(
|
|||||||
|
|
||||||
|
|
||||||
def should_process(message: DiscordMessage) -> bool:
|
def should_process(message: DiscordMessage) -> bool:
|
||||||
return (
|
if not (
|
||||||
settings.DISCORD_PROCESS_MESSAGES
|
settings.DISCORD_PROCESS_MESSAGES
|
||||||
and settings.DISCORD_NOTIFICATIONS_ENABLED
|
and settings.DISCORD_NOTIFICATIONS_ENABLED
|
||||||
and not (
|
and not (
|
||||||
(message.server and message.server.ignore_messages)
|
(message.server and message.server.ignore_messages)
|
||||||
or (message.channel and message.channel.ignore_messages)
|
or (message.channel and message.channel.ignore_messages)
|
||||||
or (message.discord_user and message.discord_user.ignore_messages)
|
or (message.from_user and message.from_user.ignore_messages)
|
||||||
)
|
)
|
||||||
)
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
|
provider = create_provider(model=settings.SUMMARIZER_MODEL)
|
||||||
|
with make_session() as session:
|
||||||
|
system_prompt = comm_channel_prompt(
|
||||||
|
session, message.recipient_user, message.channel
|
||||||
|
)
|
||||||
|
messages = previous_messages(
|
||||||
|
session,
|
||||||
|
message.recipient_user and message.recipient_user.id,
|
||||||
|
message.channel and message.channel.id,
|
||||||
|
max_messages=10,
|
||||||
|
)
|
||||||
|
msg = textwrap.dedent("""
|
||||||
|
Should you continue the conversation with the user?
|
||||||
|
Please return "yes" or "no" as:
|
||||||
|
|
||||||
|
<response>yes</response>
|
||||||
|
|
||||||
|
or
|
||||||
|
|
||||||
|
<response>no</response>
|
||||||
|
|
||||||
|
""")
|
||||||
|
response = provider.generate(
|
||||||
|
messages=provider.as_messages([m.title for m in messages] + [msg]),
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
)
|
||||||
|
return "<response>yes</response>" in "".join(response.lower().split())
|
||||||
|
|
||||||
|
|
||||||
@app.task(name=PROCESS_DISCORD_MESSAGE)
|
@app.task(name=PROCESS_DISCORD_MESSAGE)
|
||||||
@safe_task_execution
|
@safe_task_execution
|
||||||
def process_discord_message(message_id: int) -> dict[str, Any]:
|
def process_discord_message(message_id: int) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Process a Discord message.
|
||||||
|
|
||||||
|
This task is queued by the Discord collector when messages are received.
|
||||||
|
"""
|
||||||
logger.info(f"Processing Discord message {message_id}")
|
logger.info(f"Processing Discord message {message_id}")
|
||||||
|
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
@ -71,7 +111,39 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
|||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
print("Processing message", discord_message.id, discord_message.content)
|
tools = make_discord_tools(
|
||||||
|
discord_message.recipient_user,
|
||||||
|
discord_message.from_user,
|
||||||
|
discord_message.channel,
|
||||||
|
model=settings.DISCORD_MODEL,
|
||||||
|
)
|
||||||
|
tools = {
|
||||||
|
name: tool
|
||||||
|
for name, tool in tools.items()
|
||||||
|
if discord_message.tool_allowed(name)
|
||||||
|
}
|
||||||
|
system_prompt = comm_channel_prompt(
|
||||||
|
session, discord_message.recipient_user, discord_message.channel
|
||||||
|
)
|
||||||
|
messages = previous_messages(
|
||||||
|
session,
|
||||||
|
discord_message.recipient_user and discord_message.recipient_user.id,
|
||||||
|
discord_message.channel and discord_message.channel.id,
|
||||||
|
max_messages=10,
|
||||||
|
)
|
||||||
|
provider = create_provider(model=settings.DISCORD_MODEL)
|
||||||
|
turn = provider.run_with_tools(
|
||||||
|
messages=provider.as_messages([m.title for m in messages]),
|
||||||
|
tools=tools,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
||||||
|
)
|
||||||
|
if not turn.response:
|
||||||
|
pass
|
||||||
|
elif discord_message.channel.server:
|
||||||
|
discord.send_to_channel(discord_message.channel.name, turn.response)
|
||||||
|
else:
|
||||||
|
discord.send_dm(discord_message.from_user.username, turn.response)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "processed",
|
"status": "processed",
|
||||||
@ -88,6 +160,7 @@ def add_discord_message(
|
|||||||
content: str,
|
content: str,
|
||||||
sent_at: str,
|
sent_at: str,
|
||||||
server_id: int | None = None,
|
server_id: int | None = None,
|
||||||
|
recipient_id: int | None = None,
|
||||||
message_reference_id: int | None = None,
|
message_reference_id: int | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@ -108,7 +181,8 @@ def add_discord_message(
|
|||||||
channel_id=channel_id,
|
channel_id=channel_id,
|
||||||
sent_at=sent_at_dt,
|
sent_at=sent_at_dt,
|
||||||
server_id=server_id,
|
server_id=server_id,
|
||||||
discord_user_id=author_id,
|
from_id=author_id,
|
||||||
|
recipient_id=recipient_id,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
message_type="reply" if message_reference_id else "default",
|
message_type="reply" if message_reference_id else "default",
|
||||||
reply_to_message_id=message_reference_id,
|
reply_to_message_id=message_reference_id,
|
||||||
@ -125,7 +199,15 @@ def add_discord_message(
|
|||||||
if channel_id:
|
if channel_id:
|
||||||
discord_message.messages_before = get_prev(session, channel_id, sent_at_dt)
|
discord_message.messages_before = get_prev(session, channel_id, sent_at_dt)
|
||||||
|
|
||||||
result = process_content_item(discord_message, session)
|
try:
|
||||||
|
result = process_content_item(discord_message, session)
|
||||||
|
except sqlalchemy_exc.IntegrityError as e:
|
||||||
|
logger.error(f"Integrity error adding Discord message {message_id}: {e}")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"error": "Integrity error",
|
||||||
|
"message_id": message_id,
|
||||||
|
}
|
||||||
if should_process(discord_message):
|
if should_process(discord_message):
|
||||||
process_discord_message.delay(discord_message.id)
|
process_discord_message.delay(discord_message.id)
|
||||||
|
|
||||||
|
@ -37,12 +37,12 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
|
|||||||
if len(message) > 1900: # Leave some buffer
|
if len(message) > 1900: # Leave some buffer
|
||||||
message = message[:1900] + "\n\n... (response truncated)"
|
message = message[:1900] + "\n\n... (response truncated)"
|
||||||
|
|
||||||
if discord_user := cast(str, scheduled_call.discord_user):
|
if discord_user := scheduled_call.discord_user:
|
||||||
logger.info(f"Sending DM to {discord_user}: {message}")
|
logger.info(f"Sending DM to {discord_user.username}: {message}")
|
||||||
discord.send_dm(discord_user, message)
|
discord.send_dm(discord_user.username, message)
|
||||||
elif discord_channel := cast(str, scheduled_call.discord_channel):
|
elif discord_channel := scheduled_call.discord_channel:
|
||||||
logger.info(f"Broadcasting message to {discord_channel}: {message}")
|
logger.info(f"Broadcasting message to {discord_channel.name}: {message}")
|
||||||
discord.broadcast_message(discord_channel, message)
|
discord.broadcast_message(discord_channel.name, message)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"No Discord user or channel found for scheduled call {scheduled_call.id}"
|
f"No Discord user or channel found for scheduled call {scheduled_call.id}"
|
||||||
@ -62,11 +62,7 @@ def execute_scheduled_call(self, scheduled_call_id: str):
|
|||||||
|
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
# Fetch the scheduled call
|
# Fetch the scheduled call
|
||||||
scheduled_call = (
|
scheduled_call = session.query(ScheduledLLMCall).get(scheduled_call_id)
|
||||||
session.query(ScheduledLLMCall)
|
|
||||||
.filter(ScheduledLLMCall.id == scheduled_call_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not scheduled_call:
|
if not scheduled_call:
|
||||||
logger.error(f"Scheduled call {scheduled_call_id} not found")
|
logger.error(f"Scheduled call {scheduled_call_id} not found")
|
||||||
|
@ -254,17 +254,59 @@ def mock_openai_client():
|
|||||||
with patch.object(openai, "OpenAI", autospec=True) as mock_client:
|
with patch.object(openai, "OpenAI", autospec=True) as mock_client:
|
||||||
client = mock_client()
|
client = mock_client()
|
||||||
client.chat = Mock()
|
client.chat = Mock()
|
||||||
|
|
||||||
|
# Mock non-streaming response
|
||||||
client.chat.completions.create = Mock(
|
client.chat.completions.create = Mock(
|
||||||
return_value=Mock(
|
return_value=Mock(
|
||||||
choices=[
|
choices=[
|
||||||
Mock(
|
Mock(
|
||||||
message=Mock(
|
message=Mock(
|
||||||
content="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>"
|
content="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>"
|
||||||
)
|
),
|
||||||
|
finish_reason=None,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Store original side_effect for potential override
|
||||||
|
def streaming_response(*args, **kwargs):
|
||||||
|
if kwargs.get("stream"):
|
||||||
|
# Return mock streaming chunks
|
||||||
|
return iter(
|
||||||
|
[
|
||||||
|
Mock(
|
||||||
|
choices=[
|
||||||
|
Mock(
|
||||||
|
delta=Mock(content="test", tool_calls=None),
|
||||||
|
finish_reason=None,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
Mock(
|
||||||
|
choices=[
|
||||||
|
Mock(
|
||||||
|
delta=Mock(content=" response", tool_calls=None),
|
||||||
|
finish_reason="stop",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Return non-streaming response
|
||||||
|
return Mock(
|
||||||
|
choices=[
|
||||||
|
Mock(
|
||||||
|
message=Mock(
|
||||||
|
content="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>"
|
||||||
|
),
|
||||||
|
finish_reason=None,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
client.chat.completions.create.side_effect = streaming_response
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
0
tests/memory/common/llms/__init__.py
Normal file
0
tests/memory/common/llms/__init__.py
Normal file
552
tests/memory/common/llms/test_anthropic_event_parsing.py
Normal file
552
tests/memory/common/llms/test_anthropic_event_parsing.py
Normal file
@ -0,0 +1,552 @@
|
|||||||
|
"""Comprehensive tests for Anthropic stream event parsing."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from memory.common.llms.anthropic_provider import AnthropicProvider
|
||||||
|
from memory.common.llms.base import StreamEvent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def provider():
|
||||||
|
return AnthropicProvider(api_key="test-key", model="claude-3-opus-20240229")
|
||||||
|
|
||||||
|
|
||||||
|
# Content Block Start Tests
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"block_type,block_attrs,expected_tool_use",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"tool_use",
|
||||||
|
{"id": "tool-1", "name": "search", "input": {}},
|
||||||
|
{
|
||||||
|
"id": "tool-1",
|
||||||
|
"name": "search",
|
||||||
|
"input": {},
|
||||||
|
"server_name": None,
|
||||||
|
"is_server_call": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"mcp_tool_use",
|
||||||
|
{
|
||||||
|
"id": "mcp-1",
|
||||||
|
"name": "mcp_search",
|
||||||
|
"input": {},
|
||||||
|
"server_name": "mcp-server",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "mcp-1",
|
||||||
|
"name": "mcp_search",
|
||||||
|
"input": {},
|
||||||
|
"server_name": "mcp-server",
|
||||||
|
"is_server_call": True,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"server_tool_use",
|
||||||
|
{
|
||||||
|
"id": "srv-1",
|
||||||
|
"name": "server_action",
|
||||||
|
"input": {},
|
||||||
|
"server_name": "custom-server",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "srv-1",
|
||||||
|
"name": "server_action",
|
||||||
|
"input": {},
|
||||||
|
"server_name": "custom-server",
|
||||||
|
"is_server_call": True,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_content_block_start_tool_types(
|
||||||
|
provider, block_type, block_attrs, expected_tool_use
|
||||||
|
):
|
||||||
|
"""Different tool types should be tracked correctly."""
|
||||||
|
block = Mock(spec=["type"] + list(block_attrs.keys()))
|
||||||
|
block.type = block_type
|
||||||
|
for key, value in block_attrs.items():
|
||||||
|
setattr(block, key, value)
|
||||||
|
|
||||||
|
event = Mock(spec=["type", "content_block"])
|
||||||
|
event.type = "content_block_start"
|
||||||
|
event.content_block = block
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, None)
|
||||||
|
|
||||||
|
assert stream_event is None
|
||||||
|
assert tool_use == expected_tool_use
|
||||||
|
|
||||||
|
|
||||||
|
def test_content_block_start_tool_without_input(provider):
|
||||||
|
"""Tool use without input field should initialize as empty string."""
|
||||||
|
block = Mock(spec=["type", "id", "name"])
|
||||||
|
block.type = "tool_use"
|
||||||
|
block.id = "tool-2"
|
||||||
|
block.name = "calculate"
|
||||||
|
|
||||||
|
event = Mock(spec=["type", "content_block"])
|
||||||
|
event.type = "content_block_start"
|
||||||
|
event.content_block = block
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, None)
|
||||||
|
|
||||||
|
assert tool_use["input"] == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_content_block_start_tool_result(provider):
|
||||||
|
"""Tool result blocks should emit tool_result event."""
|
||||||
|
block = Mock(spec=["tool_use_id", "content"])
|
||||||
|
block.tool_use_id = "tool-1"
|
||||||
|
block.content = "Result content"
|
||||||
|
|
||||||
|
event = Mock(spec=["type", "content_block"])
|
||||||
|
event.type = "content_block_start"
|
||||||
|
event.content_block = block
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, None)
|
||||||
|
|
||||||
|
assert stream_event is not None
|
||||||
|
assert stream_event.type == "tool_result"
|
||||||
|
assert stream_event.data == {"id": "tool-1", "result": "Result content"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"has_content_block,block_type",
|
||||||
|
[
|
||||||
|
(False, None),
|
||||||
|
(True, "unknown_type"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_content_block_start_ignored_cases(provider, has_content_block, block_type):
|
||||||
|
"""Events without content_block or with unknown types should be ignored."""
|
||||||
|
event = Mock(spec=["type", "content_block"] if has_content_block else ["type"])
|
||||||
|
event.type = "content_block_start"
|
||||||
|
|
||||||
|
if has_content_block:
|
||||||
|
block = Mock(spec=["type"])
|
||||||
|
block.type = block_type
|
||||||
|
event.content_block = block
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, None)
|
||||||
|
|
||||||
|
assert stream_event is None
|
||||||
|
assert tool_use is None
|
||||||
|
|
||||||
|
|
||||||
|
# Content Block Delta Tests
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"delta_type,delta_attr,attr_value,expected_type,expected_data",
|
||||||
|
[
|
||||||
|
("text_delta", "text", "Hello world", "text", "Hello world"),
|
||||||
|
("text_delta", "text", "", "text", ""),
|
||||||
|
(
|
||||||
|
"thinking_delta",
|
||||||
|
"thinking",
|
||||||
|
"Let me think...",
|
||||||
|
"thinking",
|
||||||
|
"Let me think...",
|
||||||
|
),
|
||||||
|
("signature_delta", "signature", "sig-12345", "thinking", None),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_content_block_delta_types(
|
||||||
|
provider, delta_type, delta_attr, attr_value, expected_type, expected_data
|
||||||
|
):
|
||||||
|
"""Different delta types should emit appropriate events."""
|
||||||
|
delta = Mock(spec=["type", delta_attr])
|
||||||
|
delta.type = delta_type
|
||||||
|
setattr(delta, delta_attr, attr_value)
|
||||||
|
|
||||||
|
event = Mock(spec=["type", "delta"])
|
||||||
|
event.type = "content_block_delta"
|
||||||
|
event.delta = delta
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, None)
|
||||||
|
|
||||||
|
assert stream_event.type == expected_type
|
||||||
|
if expected_type == "thinking" and delta_type == "signature_delta":
|
||||||
|
assert stream_event.signature == attr_value
|
||||||
|
else:
|
||||||
|
assert stream_event.data == expected_data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"current_tool,partial_json,expected_input",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
{"id": "t1", "name": "search", "input": '{"query": "'},
|
||||||
|
'test"}',
|
||||||
|
'{"query": "test"}',
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"id": "t1", "name": "search", "input": '{"'},
|
||||||
|
'key": "value"}',
|
||||||
|
'{"key": "value"}',
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"id": "t1", "name": "search", "input": ""},
|
||||||
|
'{"query": "test"}',
|
||||||
|
'{"query": "test"}',
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_content_block_delta_input_json_accumulation(
|
||||||
|
provider, current_tool, partial_json, expected_input
|
||||||
|
):
|
||||||
|
"""JSON delta should accumulate to tool input."""
|
||||||
|
delta = Mock(spec=["type", "partial_json"])
|
||||||
|
delta.type = "input_json_delta"
|
||||||
|
delta.partial_json = partial_json
|
||||||
|
|
||||||
|
event = Mock(spec=["type", "delta"])
|
||||||
|
event.type = "content_block_delta"
|
||||||
|
event.delta = delta
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, current_tool)
|
||||||
|
|
||||||
|
assert stream_event is None
|
||||||
|
assert tool_use["input"] == expected_input
|
||||||
|
|
||||||
|
|
||||||
|
def test_content_block_delta_input_json_without_tool(provider):
|
||||||
|
"""JSON delta without tool context should return None."""
|
||||||
|
delta = Mock(spec=["type", "partial_json"])
|
||||||
|
delta.type = "input_json_delta"
|
||||||
|
delta.partial_json = '{"key": "value"}'
|
||||||
|
|
||||||
|
event = Mock(spec=["type", "delta"])
|
||||||
|
event.type = "content_block_delta"
|
||||||
|
event.delta = delta
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, None)
|
||||||
|
|
||||||
|
assert stream_event is None
|
||||||
|
assert tool_use is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_content_block_delta_input_json_with_dict_input(provider):
|
||||||
|
"""JSON delta shouldn't modify if input is already a dict."""
|
||||||
|
current_tool = {"id": "t1", "name": "search", "input": {"query": "test"}}
|
||||||
|
|
||||||
|
delta = Mock(spec=["type", "partial_json"])
|
||||||
|
delta.type = "input_json_delta"
|
||||||
|
delta.partial_json = ', "extra": "data"'
|
||||||
|
|
||||||
|
event = Mock(spec=["type", "delta"])
|
||||||
|
event.type = "content_block_delta"
|
||||||
|
event.delta = delta
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, current_tool)
|
||||||
|
|
||||||
|
assert tool_use["input"] == {"query": "test"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"has_delta,delta_type",
|
||||||
|
[
|
||||||
|
(False, None),
|
||||||
|
(True, "unknown_delta"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_content_block_delta_ignored_cases(provider, has_delta, delta_type):
|
||||||
|
"""Events without delta or with unknown types should be ignored."""
|
||||||
|
event = Mock(spec=["type", "delta"] if has_delta else ["type"])
|
||||||
|
event.type = "content_block_delta"
|
||||||
|
|
||||||
|
if has_delta:
|
||||||
|
delta = Mock(spec=["type"])
|
||||||
|
delta.type = delta_type
|
||||||
|
event.delta = delta
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, None)
|
||||||
|
|
||||||
|
assert stream_event is None
|
||||||
|
|
||||||
|
|
||||||
|
# Content Block Stop Tests
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_value,has_content_block,expected_input",
|
||||||
|
[
|
||||||
|
("", False, {}),
|
||||||
|
(" \n\t ", False, {}),
|
||||||
|
('{"invalid": json}', False, {}),
|
||||||
|
('{"query": "test", "limit": 10}', False, {"query": "test", "limit": 10}),
|
||||||
|
(
|
||||||
|
'{"filters": {"type": "user", "status": ["active", "pending"]}, "limit": 100}',
|
||||||
|
False,
|
||||||
|
{
|
||||||
|
"filters": {"type": "user", "status": ["active", "pending"]},
|
||||||
|
"limit": 100,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
("", True, {"query": "test"}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_content_block_stop_tool_finalization(
|
||||||
|
provider, input_value, has_content_block, expected_input
|
||||||
|
):
|
||||||
|
"""Tool stop should parse or use provided input correctly."""
|
||||||
|
current_tool = {"id": "t1", "name": "search", "input": input_value}
|
||||||
|
|
||||||
|
event = Mock(spec=["type", "content_block"] if has_content_block else ["type"])
|
||||||
|
event.type = "content_block_stop"
|
||||||
|
|
||||||
|
if has_content_block:
|
||||||
|
block = Mock(spec=["input"])
|
||||||
|
block.input = {"query": "test"}
|
||||||
|
event.content_block = block
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, current_tool)
|
||||||
|
|
||||||
|
assert stream_event.type == "tool_use"
|
||||||
|
assert stream_event.data["input"] == expected_input
|
||||||
|
assert tool_use is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_content_block_stop_with_server_info(provider):
|
||||||
|
"""Server tool info should be included in final event."""
|
||||||
|
current_tool = {
|
||||||
|
"id": "t1",
|
||||||
|
"name": "mcp_search",
|
||||||
|
"input": '{"q": "test"}',
|
||||||
|
"server_name": "mcp-server",
|
||||||
|
"is_server_call": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
event = Mock(spec=["type"])
|
||||||
|
event.type = "content_block_stop"
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, current_tool)
|
||||||
|
|
||||||
|
assert stream_event.data["server_name"] == "mcp-server"
|
||||||
|
assert stream_event.data["is_server_call"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_content_block_stop_without_tool(provider):
|
||||||
|
"""Stop without current tool should return None."""
|
||||||
|
event = Mock(spec=["type"])
|
||||||
|
event.type = "content_block_stop"
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, None)
|
||||||
|
|
||||||
|
assert stream_event is None
|
||||||
|
assert tool_use is None
|
||||||
|
|
||||||
|
|
||||||
|
# Message Delta Tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_delta_max_tokens(provider):
|
||||||
|
"""Max tokens stop reason should emit error."""
|
||||||
|
delta = Mock(spec=["stop_reason"])
|
||||||
|
delta.stop_reason = "max_tokens"
|
||||||
|
|
||||||
|
event = Mock(spec=["type", "delta"])
|
||||||
|
event.type = "message_delta"
|
||||||
|
event.delta = delta
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, None)
|
||||||
|
|
||||||
|
assert stream_event.type == "error"
|
||||||
|
assert "Max tokens" in stream_event.data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("stop_reason", ["end_turn", "stop_sequence", None])
|
||||||
|
def test_message_delta_other_stop_reasons(provider, stop_reason):
|
||||||
|
"""Other stop reasons should not emit error."""
|
||||||
|
delta = Mock(spec=["stop_reason"])
|
||||||
|
delta.stop_reason = stop_reason
|
||||||
|
|
||||||
|
event = Mock(spec=["type", "delta"])
|
||||||
|
event.type = "message_delta"
|
||||||
|
event.delta = delta
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, None)
|
||||||
|
|
||||||
|
assert stream_event is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_delta_token_usage(provider):
|
||||||
|
"""Token usage should be logged but not emitted."""
|
||||||
|
usage = Mock(
|
||||||
|
spec=[
|
||||||
|
"input_tokens",
|
||||||
|
"output_tokens",
|
||||||
|
"cache_creation_input_tokens",
|
||||||
|
"cache_read_input_tokens",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
usage.input_tokens = 100
|
||||||
|
usage.output_tokens = 50
|
||||||
|
usage.cache_creation_input_tokens = 10
|
||||||
|
usage.cache_read_input_tokens = 20
|
||||||
|
|
||||||
|
event = Mock(spec=["type", "usage"])
|
||||||
|
event.type = "message_delta"
|
||||||
|
event.usage = usage
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, None)
|
||||||
|
|
||||||
|
assert stream_event is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_delta_empty(provider):
|
||||||
|
"""Message delta without delta or usage should return None."""
|
||||||
|
event = Mock(spec=["type"])
|
||||||
|
event.type = "message_delta"
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, None)
|
||||||
|
|
||||||
|
assert stream_event is None
|
||||||
|
|
||||||
|
|
||||||
|
# Message Stop Tests
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"current_tool",
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
{"id": "t1", "name": "search", "input": '{"incomplete'},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_message_stop(provider, current_tool):
|
||||||
|
"""Message stop should emit done regardless of incomplete tools."""
|
||||||
|
event = Mock(spec=["type"])
|
||||||
|
event.type = "message_stop"
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, current_tool)
|
||||||
|
|
||||||
|
assert stream_event.type == "done"
|
||||||
|
assert tool_use is None
|
||||||
|
|
||||||
|
|
||||||
|
# Error Handling Tests
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"has_error,error_value,expected_message",
|
||||||
|
[
|
||||||
|
(True, "API rate limit exceeded", "rate limit"),
|
||||||
|
(False, None, "Unknown error"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_error_events(provider, has_error, error_value, expected_message):
|
||||||
|
"""Error events should emit error StreamEvent."""
|
||||||
|
event = Mock(spec=["type", "error"] if has_error else ["type"])
|
||||||
|
event.type = "error"
|
||||||
|
if has_error:
|
||||||
|
event.error = error_value
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, None)
|
||||||
|
|
||||||
|
assert stream_event.type == "error"
|
||||||
|
assert expected_message in stream_event.data
|
||||||
|
|
||||||
|
|
||||||
|
# Unknown Event Tests
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"event_type",
|
||||||
|
["message_start", "future_event_type", None],
|
||||||
|
)
|
||||||
|
def test_unknown_or_ignored_events(provider, event_type):
|
||||||
|
"""Unknown event types should be logged but not fail."""
|
||||||
|
if event_type is None:
|
||||||
|
event = Mock(spec=[])
|
||||||
|
else:
|
||||||
|
event = Mock(spec=["type"])
|
||||||
|
event.type = event_type
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event, None)
|
||||||
|
|
||||||
|
assert stream_event is None
|
||||||
|
|
||||||
|
|
||||||
|
# State Transition Tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_complete_tool_call_sequence(provider):
|
||||||
|
"""Simulate a complete tool call from start to finish."""
|
||||||
|
# Start
|
||||||
|
block = Mock(spec=["type", "id", "name", "input"])
|
||||||
|
block.type = "tool_use"
|
||||||
|
block.id = "tool-1"
|
||||||
|
block.name = "search"
|
||||||
|
block.input = None
|
||||||
|
|
||||||
|
event1 = Mock(spec=["type", "content_block"])
|
||||||
|
event1.type = "content_block_start"
|
||||||
|
event1.content_block = block
|
||||||
|
|
||||||
|
_, tool_use = provider._handle_stream_event(event1, None)
|
||||||
|
assert tool_use["input"] == ""
|
||||||
|
|
||||||
|
# Delta 1
|
||||||
|
delta1 = Mock(spec=["type", "partial_json"])
|
||||||
|
delta1.type = "input_json_delta"
|
||||||
|
delta1.partial_json = '{"query":'
|
||||||
|
|
||||||
|
event2 = Mock(spec=["type", "delta"])
|
||||||
|
event2.type = "content_block_delta"
|
||||||
|
event2.delta = delta1
|
||||||
|
|
||||||
|
_, tool_use = provider._handle_stream_event(event2, tool_use)
|
||||||
|
assert tool_use["input"] == '{"query":'
|
||||||
|
|
||||||
|
# Delta 2
|
||||||
|
delta2 = Mock(spec=["type", "partial_json"])
|
||||||
|
delta2.type = "input_json_delta"
|
||||||
|
delta2.partial_json = ' "test"}'
|
||||||
|
|
||||||
|
event3 = Mock(spec=["type", "delta"])
|
||||||
|
event3.type = "content_block_delta"
|
||||||
|
event3.delta = delta2
|
||||||
|
|
||||||
|
_, tool_use = provider._handle_stream_event(event3, tool_use)
|
||||||
|
assert tool_use["input"] == '{"query": "test"}'
|
||||||
|
|
||||||
|
# Stop
|
||||||
|
event4 = Mock(spec=["type"])
|
||||||
|
event4.type = "content_block_stop"
|
||||||
|
|
||||||
|
stream_event, tool_use = provider._handle_stream_event(event4, tool_use)
|
||||||
|
|
||||||
|
assert stream_event.type == "tool_use"
|
||||||
|
assert stream_event.data["input"] == {"query": "test"}
|
||||||
|
assert tool_use is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_and_thinking_mixed(provider):
|
||||||
|
"""Text and thinking deltas should be handled independently."""
|
||||||
|
delta1 = Mock(spec=["type", "text"])
|
||||||
|
delta1.type = "text_delta"
|
||||||
|
delta1.text = "Answer: "
|
||||||
|
|
||||||
|
event1 = Mock(spec=["type", "delta"])
|
||||||
|
event1.type = "content_block_delta"
|
||||||
|
event1.delta = delta1
|
||||||
|
|
||||||
|
event1_result, _ = provider._handle_stream_event(event1, None)
|
||||||
|
assert event1_result.type == "text"
|
||||||
|
|
||||||
|
delta2 = Mock(spec=["type", "thinking"])
|
||||||
|
delta2.type = "thinking_delta"
|
||||||
|
delta2.thinking = "reasoning..."
|
||||||
|
|
||||||
|
event2 = Mock(spec=["type", "delta"])
|
||||||
|
event2.type = "content_block_delta"
|
||||||
|
event2.delta = delta2
|
||||||
|
|
||||||
|
event2_result, _ = provider._handle_stream_event(event2, None)
|
||||||
|
assert event2_result.type == "thinking"
|
440
tests/memory/common/llms/test_anthropic_provider.py
Normal file
440
tests/memory/common/llms/test_anthropic_provider.py
Normal file
@ -0,0 +1,440 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch, MagicMock
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from memory.common.llms.anthropic_provider import AnthropicProvider
|
||||||
|
from memory.common.llms.base import (
|
||||||
|
Message,
|
||||||
|
MessageRole,
|
||||||
|
TextContent,
|
||||||
|
ImageContent,
|
||||||
|
ToolUseContent,
|
||||||
|
ToolResultContent,
|
||||||
|
ThinkingContent,
|
||||||
|
LLMSettings,
|
||||||
|
StreamEvent,
|
||||||
|
)
|
||||||
|
from memory.common.llms.tools import ToolDefinition
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def provider():
|
||||||
|
return AnthropicProvider(api_key="test-key", model="claude-3-opus-20240229")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def thinking_provider():
|
||||||
|
return AnthropicProvider(
|
||||||
|
api_key="test-key", model="claude-opus-4", enable_thinking=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialization(provider):
|
||||||
|
assert provider.api_key == "test-key"
|
||||||
|
assert provider.model == "claude-3-opus-20240229"
|
||||||
|
assert provider.enable_thinking is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_lazy_loading(provider):
|
||||||
|
assert provider._client is None
|
||||||
|
client = provider.client
|
||||||
|
assert client is not None
|
||||||
|
assert provider._client is not None
|
||||||
|
# Second call should return same instance
|
||||||
|
assert provider.client is client
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_client_lazy_loading(provider):
|
||||||
|
assert provider._async_client is None
|
||||||
|
client = provider.async_client
|
||||||
|
assert client is not None
|
||||||
|
assert provider._async_client is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model, expected",
|
||||||
|
[
|
||||||
|
("claude-opus-4", True),
|
||||||
|
("claude-opus-4-1", True),
|
||||||
|
("claude-sonnet-4-0", True),
|
||||||
|
("claude-sonnet-3-7", True),
|
||||||
|
("claude-sonnet-4-5", True),
|
||||||
|
("claude-3-opus-20240229", False),
|
||||||
|
("claude-3-sonnet-20240229", False),
|
||||||
|
("gpt-4", False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_supports_thinking(model, expected):
|
||||||
|
provider = AnthropicProvider(api_key="test-key", model=model)
|
||||||
|
assert provider._supports_thinking() == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_text_content(provider):
|
||||||
|
content = TextContent(text="hello world")
|
||||||
|
result = provider._convert_text_content(content)
|
||||||
|
assert result == {"type": "text", "text": "hello world"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_image_content(provider):
|
||||||
|
image = Image.new("RGB", (100, 100), color="red")
|
||||||
|
content = ImageContent(image=image)
|
||||||
|
result = provider._convert_image_content(content)
|
||||||
|
|
||||||
|
assert result["type"] == "image"
|
||||||
|
assert result["source"]["type"] == "base64"
|
||||||
|
assert result["source"]["media_type"] == "image/jpeg"
|
||||||
|
assert isinstance(result["source"]["data"], str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_include_message_filters_system(provider):
|
||||||
|
system_msg = Message(role=MessageRole.SYSTEM, content="system prompt")
|
||||||
|
user_msg = Message(role=MessageRole.USER, content="user message")
|
||||||
|
|
||||||
|
    assert provider._should_include_message(system_msg) is False
    assert provider._should_include_message(user_msg) is True


@pytest.mark.parametrize(
    "messages, expected_count",
    [
        ([Message(role=MessageRole.USER, content="test")], 1),
        ([Message(role=MessageRole.SYSTEM, content="test")], 0),
        (
            [
                Message(role=MessageRole.SYSTEM, content="system"),
                Message(role=MessageRole.USER, content="user"),
            ],
            1,
        ),
    ],
)
def test_convert_messages(provider, messages, expected_count):
    result = provider._convert_messages(messages)
    assert len(result) == expected_count


def test_convert_tool(provider):
    tool = ToolDefinition(
        name="test_tool",
        description="A test tool",
        input_schema={"type": "object", "properties": {}},
        function=lambda x: "result",
    )
    result = provider._convert_tool(tool)

    assert result["name"] == "test_tool"
    assert result["description"] == "A test tool"
    assert result["input_schema"] == {"type": "object", "properties": {}}


def test_build_request_kwargs_basic(provider):
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings(temperature=0.5, max_tokens=1000)

    kwargs = provider._build_request_kwargs(messages, None, None, settings)

    assert kwargs["model"] == "claude-3-opus-20240229"
    assert kwargs["temperature"] == 0.5
    assert kwargs["max_tokens"] == 1000
    assert len(kwargs["messages"]) == 1


def test_build_request_kwargs_with_system_prompt(provider):
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings()

    kwargs = provider._build_request_kwargs(messages, "system prompt", None, settings)

    assert kwargs["system"] == "system prompt"


def test_build_request_kwargs_with_tools(provider):
    messages = [Message(role=MessageRole.USER, content="test")]
    tools = [
        ToolDefinition(
            name="test",
            description="test",
            input_schema={},
            function=lambda x: "result",
        )
    ]
    settings = LLMSettings()

    kwargs = provider._build_request_kwargs(messages, None, tools, settings)

    assert "tools" in kwargs
    assert len(kwargs["tools"]) == 1


def test_build_request_kwargs_with_thinking(thinking_provider):
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings(max_tokens=5000)

    kwargs = thinking_provider._build_request_kwargs(messages, None, None, settings)

    assert "thinking" in kwargs
    assert kwargs["thinking"]["type"] == "enabled"
    assert kwargs["thinking"]["budget_tokens"] == 3976
    assert kwargs["temperature"] == 1.0
    assert "top_p" not in kwargs


def test_build_request_kwargs_thinking_insufficient_tokens(thinking_provider):
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings(max_tokens=1000)

    kwargs = thinking_provider._build_request_kwargs(messages, None, None, settings)

    # Shouldn't enable thinking if not enough tokens
    assert "thinking" not in kwargs


def test_handle_stream_event_text_delta(provider):
    event = Mock(
        type="content_block_delta",
        delta=Mock(type="text_delta", text="hello"),
    )

    stream_event, tool_use = provider._handle_stream_event(event, None)

    assert stream_event is not None
    assert stream_event.type == "text"
    assert stream_event.data == "hello"
    assert tool_use is None


def test_handle_stream_event_thinking_delta(provider):
    event = Mock(
        type="content_block_delta",
        delta=Mock(type="thinking_delta", thinking="reasoning..."),
    )

    stream_event, tool_use = provider._handle_stream_event(event, None)

    assert stream_event is not None
    assert stream_event.type == "thinking"
    assert stream_event.data == "reasoning..."


def test_handle_stream_event_tool_use_start(provider):
    block = Mock(spec=["type", "id", "name", "input"])
    block.type = "tool_use"
    block.id = "tool-1"
    block.name = "test_tool"
    block.input = {}

    event = Mock(spec=["type", "content_block"])
    event.type = "content_block_start"
    event.content_block = block

    stream_event, tool_use = provider._handle_stream_event(event, None)

    assert stream_event is None
    assert tool_use is not None
    assert tool_use["id"] == "tool-1"
    assert tool_use["name"] == "test_tool"
    assert tool_use["input"] == {}


def test_handle_stream_event_tool_input_delta(provider):
    current_tool = {"id": "tool-1", "name": "test", "input": '{"ke'}
    event = Mock(
        type="content_block_delta",
        delta=Mock(type="input_json_delta", partial_json='y": "val'),
    )

    stream_event, tool_use = provider._handle_stream_event(event, current_tool)

    assert stream_event is None
    assert tool_use["input"] == '{"key": "val'


def test_handle_stream_event_tool_use_complete(provider):
    current_tool = {
        "id": "tool-1",
        "name": "test_tool",
        "input": '{"key": "value"}',
    }
    event = Mock(
        type="content_block_stop",
        content_block=Mock(input={"key": "value"}),
    )

    stream_event, tool_use = provider._handle_stream_event(event, current_tool)

    assert stream_event is not None
    assert stream_event.type == "tool_use"
    assert stream_event.data["id"] == "tool-1"
    assert stream_event.data["name"] == "test_tool"
    assert stream_event.data["input"] == {"key": "value"}
    assert tool_use is None


def test_handle_stream_event_message_stop(provider):
    event = Mock(type="message_stop")

    stream_event, tool_use = provider._handle_stream_event(event, None)

    assert stream_event is not None
    assert stream_event.type == "done"
    assert tool_use is None


def test_handle_stream_event_error(provider):
    event = Mock(type="error", error="API error")

    stream_event, tool_use = provider._handle_stream_event(event, None)

    assert stream_event is not None
    assert stream_event.type == "error"
    assert "API error" in stream_event.data


def test_generate_basic(provider, mock_anthropic_client):
    messages = [Message(role=MessageRole.USER, content="test")]

    # Mock the response properly
    mock_block = Mock(spec=["type", "text"])
    mock_block.type = "text"
    mock_block.text = "test summary"

    mock_response = Mock(spec=["content"])
    mock_response.content = [mock_block]

    provider.client.messages.create.return_value = mock_response

    result = provider.generate(messages)

    assert result == "test summary"
    provider.client.messages.create.assert_called_once()


def test_stream_basic(provider, mock_anthropic_client):
    messages = [Message(role=MessageRole.USER, content="test")]

    events = list(provider.stream(messages))

    # Should get text event and done event
    assert len(events) > 0
    assert any(e.type == "text" for e in events)
    provider.client.messages.stream.assert_called_once()


@pytest.mark.asyncio
async def test_agenerate_basic(provider, mock_anthropic_client):
    messages = [Message(role=MessageRole.USER, content="test")]

    result = await provider.agenerate(messages)

    assert result == "test summary"
    provider.async_client.messages.create.assert_called_once()


@pytest.mark.asyncio
async def test_astream_basic(provider, mock_anthropic_client):
    messages = [Message(role=MessageRole.USER, content="test")]

    events = []
    async for event in provider.astream(messages):
        events.append(event)

    assert len(events) > 0
    assert any(e.type == "text" for e in events)


def test_convert_message_sorts_thinking_content(provider):
    """Thinking content should be sorted so non-thinking comes before thinking."""
    message = Message.assistant(
        ThinkingContent(thinking="reasoning", signature="sig"),
        TextContent(text="response"),
    )

    result = provider._convert_message(message)

    assert result["role"] == "assistant"
    # The sort key (x["type"] != "thinking") sorts thinking type to beginning
    # because "thinking" != "thinking" is False, which sorts before True
    content_types = [c["type"] for c in result["content"]]
    assert "text" in content_types
    assert "thinking" in content_types
    # Verify thinking comes before non-thinking (sorted by key)
    thinking_idx = content_types.index("thinking")
    text_idx = content_types.index("text")
    assert thinking_idx < text_idx


def test_execute_tool_success(provider):
    tool_call = {"id": "t1", "name": "test", "input": {"arg": "value"}}
    tools = {
        "test": ToolDefinition(
            name="test",
            description="test",
            input_schema={},
            function=lambda x: f"result: {x['arg']}",
        )
    }

    result = provider.execute_tool(tool_call, tools)

    assert result.tool_use_id == "t1"
    assert result.content == "result: value"
    assert result.is_error is False


def test_execute_tool_missing_name(provider):
    tool_call = {"id": "t1", "input": {}}
    tools = {}

    result = provider.execute_tool(tool_call, tools)

    assert result.tool_use_id == "t1"
    assert "missing" in result.content.lower()
    assert result.is_error is True


def test_execute_tool_not_found(provider):
    tool_call = {"id": "t1", "name": "nonexistent", "input": {}}
    tools = {}

    result = provider.execute_tool(tool_call, tools)

    assert result.tool_use_id == "t1"
    assert "not found" in result.content.lower()
    assert result.is_error is True


def test_execute_tool_exception(provider):
    tool_call = {"id": "t1", "name": "test", "input": {}}
    tools = {
        "test": ToolDefinition(
            name="test",
            description="test",
            input_schema={},
            function=lambda x: 1 / 0,  # Raises ZeroDivisionError
        )
    }

    result = provider.execute_tool(tool_call, tools)

    assert result.tool_use_id == "t1"
    assert result.is_error is True
    assert "division" in result.content.lower()


def test_encode_image(provider):
    image = Image.new("RGB", (10, 10), color="blue")

    encoded = provider.encode_image(image)

    assert isinstance(encoded, str)
    assert len(encoded) > 0


def test_encode_image_rgba(provider):
    """RGBA images should be converted to RGB."""
    image = Image.new("RGBA", (10, 10), color=(255, 0, 0, 128))

    encoded = provider.encode_image(image)

    assert isinstance(encoded, str)
    assert len(encoded) > 0
tests/memory/common/llms/test_base.py (new file, 270 lines)
@ -0,0 +1,270 @@
import pytest
from PIL import Image

from memory.common.llms.base import (
    Message,
    MessageRole,
    TextContent,
    ImageContent,
    ToolUseContent,
    ToolResultContent,
    ThinkingContent,
    LLMSettings,
    StreamEvent,
    create_provider,
)
from memory.common.llms.anthropic_provider import AnthropicProvider
from memory.common.llms.openai_provider import OpenAIProvider
from memory.common import settings


def test_message_role_enum():
    assert MessageRole.SYSTEM == "system"
    assert MessageRole.USER == "user"
    assert MessageRole.ASSISTANT == "assistant"
    assert MessageRole.TOOL == "tool"


def test_text_content_creation():
    content = TextContent(text="hello")
    assert content.type == "text"
    assert content.text == "hello"
    assert content.valid


def test_text_content_to_dict():
    content = TextContent(text="hello")
    result = content.to_dict()
    assert result == {"type": "text", "text": "hello"}


def test_text_content_empty_invalid():
    content = TextContent(text="")
    assert not content.valid


def test_image_content_creation():
    image = Image.new("RGB", (10, 10))
    content = ImageContent(image=image)
    assert content.type == "image"
    assert content.image == image
    assert content.valid


def test_image_content_with_detail():
    image = Image.new("RGB", (10, 10))
    content = ImageContent(image=image, detail="high")
    assert content.detail == "high"


def test_tool_use_content_creation():
    content = ToolUseContent(id="t1", name="test_tool", input={"arg": "value"})
    assert content.type == "tool_use"
    assert content.id == "t1"
    assert content.name == "test_tool"
    assert content.input == {"arg": "value"}
    assert content.valid


def test_tool_use_content_to_dict():
    content = ToolUseContent(id="t1", name="test", input={"key": "val"})
    result = content.to_dict()
    assert result == {
        "type": "tool_use",
        "id": "t1",
        "name": "test",
        "input": {"key": "val"},
    }


def test_tool_result_content_creation():
    content = ToolResultContent(
        tool_use_id="t1",
        content="result",
        is_error=False,
    )
    assert content.type == "tool_result"
    assert content.tool_use_id == "t1"
    assert content.content == "result"
    assert not content.is_error
    assert content.valid


def test_tool_result_content_with_error():
    content = ToolResultContent(
        tool_use_id="t1",
        content="error message",
        is_error=True,
    )
    assert content.is_error


def test_thinking_content_creation():
    content = ThinkingContent(thinking="reasoning...", signature="sig")
    assert content.type == "thinking"
    assert content.thinking == "reasoning..."
    assert content.signature == "sig"
    assert content.valid


def test_thinking_content_invalid_without_signature():
    content = ThinkingContent(thinking="reasoning...")
    assert not content.valid


def test_message_simple_string_content():
    msg = Message(role=MessageRole.USER, content="hello")
    assert msg.role == MessageRole.USER
    assert msg.content == "hello"


def test_message_list_content():
    content_list = [TextContent(text="hello"), TextContent(text="world")]
    msg = Message(role=MessageRole.USER, content=content_list)
    assert msg.role == MessageRole.USER
    assert len(msg.content) == 2


def test_message_to_dict_string():
    msg = Message(role=MessageRole.USER, content="hello")
    result = msg.to_dict()
    assert result == {"role": "user", "content": "hello"}


def test_message_to_dict_list():
    msg = Message(
        role=MessageRole.USER,
        content=[TextContent(text="hello"), TextContent(text="world")],
    )
    result = msg.to_dict()
    assert result["role"] == "user"
    assert len(result["content"]) == 2
    assert result["content"][0] == {"type": "text", "text": "hello"}


def test_message_assistant_factory():
    msg = Message.assistant(
        TextContent(text="response"),
        ToolUseContent(id="t1", name="tool", input={}),
    )
    assert msg.role == MessageRole.ASSISTANT
    assert len(msg.content) == 2


def test_message_assistant_filters_invalid_content():
    msg = Message.assistant(
        TextContent(text="valid"),
        TextContent(text=""),  # Invalid - empty
    )
    assert len(msg.content) == 1
    assert msg.content[0].text == "valid"


def test_message_user_factory():
    msg = Message.user(text="hello")
    assert msg.role == MessageRole.USER
    assert len(msg.content) == 1
    assert isinstance(msg.content[0], TextContent)


def test_message_user_with_tool_result():
    tool_result = ToolResultContent(tool_use_id="t1", content="result")
    msg = Message.user(text="hello", tool_result=tool_result)
    assert len(msg.content) == 2


def test_stream_event_creation():
    event = StreamEvent(type="text", data="hello")
    assert event.type == "text"
    assert event.data == "hello"


def test_stream_event_with_signature():
    event = StreamEvent(type="thinking", signature="sig123")
    assert event.signature == "sig123"


def test_llm_settings_defaults():
    settings = LLMSettings()
    assert settings.temperature == 0.7
    assert settings.max_tokens == 2048
    assert settings.top_p is None
    assert settings.stop_sequences is None
    assert settings.stream is False


def test_llm_settings_custom():
    settings = LLMSettings(
        temperature=0.5,
        max_tokens=1000,
        top_p=0.9,
        stop_sequences=["STOP"],
        stream=True,
    )
    assert settings.temperature == 0.5
    assert settings.max_tokens == 1000
    assert settings.top_p == 0.9
    assert settings.stop_sequences == ["STOP"]
    assert settings.stream is True


def test_create_provider_anthropic():
    provider = create_provider(
        model="anthropic/claude-3-opus-20240229",
        api_key="test-key",
    )
    assert isinstance(provider, AnthropicProvider)
    assert provider.model == "claude-3-opus-20240229"


def test_create_provider_openai():
    provider = create_provider(
        model="openai/gpt-4o",
        api_key="test-key",
    )
    assert isinstance(provider, OpenAIProvider)
    assert provider.model == "gpt-4o"


def test_create_provider_unknown_raises():
    with pytest.raises(ValueError, match="Unknown provider"):
        create_provider(model="unknown/model", api_key="test-key")


def test_create_provider_uses_default_model():
    """If no model provided, should use SUMMARIZER_MODEL from settings."""
    provider = create_provider(api_key="test-key")
    # Should create a provider (type depends on settings.SUMMARIZER_MODEL)
    assert provider is not None


def test_create_provider_anthropic_with_thinking():
    provider = create_provider(
        model="anthropic/claude-opus-4",
        api_key="test-key",
        enable_thinking=True,
    )
    assert isinstance(provider, AnthropicProvider)
    assert provider.enable_thinking is True


def test_create_provider_missing_anthropic_key():
    # Temporarily clear the API key from settings
    original_key = settings.ANTHROPIC_API_KEY
    try:
        settings.ANTHROPIC_API_KEY = ""
        with pytest.raises(ValueError, match="ANTHROPIC_API_KEY"):
            create_provider(model="anthropic/claude-3-opus-20240229")
    finally:
        settings.ANTHROPIC_API_KEY = original_key


def test_create_provider_missing_openai_key():
    # Temporarily clear the API key from settings
    original_key = settings.OPENAI_API_KEY
    try:
        settings.OPENAI_API_KEY = ""
        with pytest.raises(ValueError, match="OPENAI_API_KEY"):
            create_provider(model="openai/gpt-4o")
    finally:
        settings.OPENAI_API_KEY = original_key
tests/memory/common/llms/test_openai_event_parsing.py (new file, 478 lines)
@ -0,0 +1,478 @@
"""Comprehensive tests for OpenAI stream chunk parsing."""

import pytest
from unittest.mock import Mock

from memory.common.llms.openai_provider import OpenAIProvider
from memory.common.llms.base import StreamEvent


@pytest.fixture
def provider():
    return OpenAIProvider(api_key="test-key", model="gpt-4o")


# Text Content Tests


@pytest.mark.parametrize(
    "content,expected_events",
    [
        ("Hello", 1),
        ("", 0),  # Empty string is falsy
        (None, 0),
        ("Line 1\nLine 2\nLine 3", 1),
        ("Hello 世界 🌍", 1),
    ],
)
def test_text_content(provider, content, expected_events):
    """Text content should emit text events appropriately."""
    delta = Mock(spec=["content", "tool_calls"])
    delta.content = content
    delta.tool_calls = None

    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta = delta
    choice.finish_reason = None

    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, tool_call = provider._handle_stream_chunk(chunk, None)

    assert len(events) == expected_events
    if expected_events > 0:
        assert events[0].type == "text"
        assert events[0].data == content
    assert tool_call is None


# Tool Call Start Tests


def test_new_tool_call_basic(provider):
    """New tool call should initialize state."""
    function = Mock(spec=["name", "arguments"])
    function.name = "search"
    function.arguments = ""

    tool = Mock(spec=["id", "function"])
    tool.id = "call_123"
    tool.function = function

    delta = Mock(spec=["content", "tool_calls"])
    delta.content = None
    delta.tool_calls = [tool]

    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta = delta
    choice.finish_reason = None

    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, tool_call = provider._handle_stream_chunk(chunk, None)

    assert len(events) == 0
    assert tool_call == {"id": "call_123", "name": "search", "arguments": ""}


@pytest.mark.parametrize(
    "name,arguments,expected_name,expected_args",
    [
        ("calculate", '{"operation":', "calculate", '{"operation":'),
        (None, "", "", ""),
        ("test", None, "test", ""),
    ],
)
def test_new_tool_call_variations(
    provider, name, arguments, expected_name, expected_args
):
    """Tool calls with various name/argument combinations."""
    function = Mock(spec=["name", "arguments"])
    function.name = name
    function.arguments = arguments

    tool = Mock(spec=["id", "function"])
    tool.id = "call_123"
    tool.function = function

    delta = Mock(spec=["content", "tool_calls"])
    delta.content = None
    delta.tool_calls = [tool]

    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta = delta
    choice.finish_reason = None

    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, tool_call = provider._handle_stream_chunk(chunk, None)

    assert tool_call["name"] == expected_name
    assert tool_call["arguments"] == expected_args


def test_new_tool_call_replaces_previous(provider):
    """New tool call should finalize and replace previous."""
    current = {"id": "call_old", "name": "old_tool", "arguments": '{"arg": "value"}'}

    function = Mock(spec=["name", "arguments"])
    function.name = "new_tool"
    function.arguments = ""

    tool = Mock(spec=["id", "function"])
    tool.id = "call_new"
    tool.function = function

    delta = Mock(spec=["content", "tool_calls"])
    delta.content = None
    delta.tool_calls = [tool]

    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta = delta
    choice.finish_reason = None

    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, tool_call = provider._handle_stream_chunk(chunk, current)

    assert len(events) == 1
    assert events[0].type == "tool_use"
    assert events[0].data["id"] == "call_old"
    assert events[0].data["input"] == {"arg": "value"}
    assert tool_call["id"] == "call_new"


# Tool Call Continuation Tests


@pytest.mark.parametrize(
    "initial_args,new_args,expected_args",
    [
        ('{"query": "', 'test query"}', '{"query": "test query"}'),
        ('{"query"', ': "value"}', '{"query": "value"}'),
        ("", '{"full": "json"}', '{"full": "json"}'),
        ('{"partial"', "", '{"partial"'),  # Empty doesn't accumulate
    ],
)
def test_tool_call_argument_accumulation(
    provider, initial_args, new_args, expected_args
):
    """Arguments should accumulate correctly."""
    current = {"id": "call_123", "name": "search", "arguments": initial_args}

    function = Mock(spec=["name", "arguments"])
    function.name = None
    function.arguments = new_args

    tool = Mock(spec=["id", "function"])
    tool.id = None
    tool.function = function

    delta = Mock(spec=["content", "tool_calls"])
    delta.content = None
    delta.tool_calls = [tool]

    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta = delta
    choice.finish_reason = None

    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, tool_call = provider._handle_stream_chunk(chunk, current)

    assert len(events) == 0
    assert tool_call["arguments"] == expected_args


def test_tool_call_accumulation_without_current_tool(provider):
    """Arguments without current tool should be ignored."""
    function = Mock(spec=["name", "arguments"])
    function.name = None
    function.arguments = '{"arg": "value"}'

    tool = Mock(spec=["id", "function"])
    tool.id = None
    tool.function = function

    delta = Mock(spec=["content", "tool_calls"])
    delta.content = None
    delta.tool_calls = [tool]

    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta = delta
    choice.finish_reason = None

    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, tool_call = provider._handle_stream_chunk(chunk, None)

    assert len(events) == 0
    assert tool_call is None


def test_incremental_json_building(provider):
    """Test realistic incremental JSON building across multiple chunks."""
    current = {"id": "c1", "name": "search", "arguments": ""}

    increments = ['{"', 'query":', ' "test"}']
    expected_states = ['{"', '{"query":', '{"query": "test"}']

    for increment, expected in zip(increments, expected_states):
        function = Mock(spec=["name", "arguments"])
        function.name = None
        function.arguments = increment

        tool = Mock(spec=["id", "function"])
        tool.id = None
        tool.function = function

        delta = Mock(spec=["content", "tool_calls"])
        delta.content = None
        delta.tool_calls = [tool]

        choice = Mock(spec=["delta", "finish_reason"])
        choice.delta = delta
        choice.finish_reason = None

        chunk = Mock(spec=["choices"])
        chunk.choices = [choice]

        _, current = provider._handle_stream_chunk(chunk, current)
        assert current["arguments"] == expected


# Finish Reason Tests


def test_finish_reason_without_tool(provider):
    """Stop finish without tool should not emit events."""
    delta = Mock(spec=["content", "tool_calls"])
    delta.content = None
    delta.tool_calls = None

    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta = delta
    choice.finish_reason = "stop"

    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, tool_call = provider._handle_stream_chunk(chunk, None)

    assert len(events) == 0
    assert tool_call is None


@pytest.mark.parametrize(
    "arguments,expected_input",
    [
        ('{"query": "test"}', {"query": "test"}),
        ('{"invalid": json}', {}),
        ("", {}),
    ],
)
def test_finish_reason_with_tool(provider, arguments, expected_input):
    """Finish with tool call should finalize and emit."""
    current = {"id": "call_123", "name": "search", "arguments": arguments}

    delta = Mock(spec=["content", "tool_calls"])
    delta.content = None
    delta.tool_calls = None

    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta = delta
    choice.finish_reason = "tool_calls"

    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, tool_call = provider._handle_stream_chunk(chunk, current)

    assert len(events) == 1
    assert events[0].type == "tool_use"
    assert events[0].data["id"] == "call_123"
    assert events[0].data["input"] == expected_input
    assert tool_call is None


@pytest.mark.parametrize("reason", ["stop", "length", "content_filter", "tool_calls"])
def test_various_finish_reasons(provider, reason):
    """Various finish reasons with active tool should finalize."""
    current = {"id": "call_123", "name": "test", "arguments": '{"a": 1}'}

    delta = Mock(spec=["content", "tool_calls"])
    delta.content = None
    delta.tool_calls = None

    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta = delta
    choice.finish_reason = reason

    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, tool_call = provider._handle_stream_chunk(chunk, current)

    assert len(events) == 1
    assert tool_call is None


# Edge Cases Tests


def test_empty_choices(provider):
    """Empty choices list should return empty events."""
    chunk = Mock(spec=["choices"])
    chunk.choices = []

    events, tool_call = provider._handle_stream_chunk(chunk, None)

    assert len(events) == 0
    assert tool_call is None


def test_none_choices(provider):
    """None choices should be handled gracefully."""
    chunk = Mock(spec=["choices"])
    chunk.choices = None

    try:
        events, tool_call = provider._handle_stream_chunk(chunk, None)
        assert len(events) == 0
    except (TypeError, AttributeError):
        pass  # Also acceptable for malformed input


def test_multiple_chunks_in_sequence(provider):
    """Test processing multiple chunks sequentially."""
    # Chunk 1: Start
    function1 = Mock(spec=["name", "arguments"])
    function1.name = "search"
    function1.arguments = ""

    tool1 = Mock(spec=["id", "function"])
    tool1.id = "call_1"
    tool1.function = function1

    delta1 = Mock(spec=["content", "tool_calls"])
    delta1.content = None
    delta1.tool_calls = [tool1]

    choice1 = Mock(spec=["delta", "finish_reason"])
    choice1.delta = delta1
    choice1.finish_reason = None

    chunk1 = Mock(spec=["choices"])
    chunk1.choices = [choice1]

    events1, state = provider._handle_stream_chunk(chunk1, None)
    assert len(events1) == 0
    assert state is not None

    # Chunk 2: Args
    function2 = Mock(spec=["name", "arguments"])
    function2.name = None
    function2.arguments = '{"q": "test"}'

    tool2 = Mock(spec=["id", "function"])
    tool2.id = None
    tool2.function = function2

    delta2 = Mock(spec=["content", "tool_calls"])
    delta2.content = None
    delta2.tool_calls = [tool2]

    choice2 = Mock(spec=["delta", "finish_reason"])
    choice2.delta = delta2
    choice2.finish_reason = None

    chunk2 = Mock(spec=["choices"])
    chunk2.choices = [choice2]

    events2, state = provider._handle_stream_chunk(chunk2, state)
    assert len(events2) == 0
    assert state["arguments"] == '{"q": "test"}'

    # Chunk 3: Finish
    delta3 = Mock(spec=["content", "tool_calls"])
    delta3.content = None
    delta3.tool_calls = None

    choice3 = Mock(spec=["delta", "finish_reason"])
    choice3.delta = delta3
    choice3.finish_reason = "stop"

    chunk3 = Mock(spec=["choices"])
    chunk3.choices = [choice3]

    events3, state = provider._handle_stream_chunk(chunk3, state)
    assert len(events3) == 1
    assert events3[0].type == "tool_use"
    assert state is None


def test_text_and_tool_calls_mixed(provider):
    """Text content should be emitted before tool initialization."""
    function = Mock(spec=["name", "arguments"])
    function.name = "search"
    function.arguments = ""

    tool = Mock(spec=["id", "function"])
    tool.id = "call_1"
    tool.function = function

    delta = Mock(spec=["content", "tool_calls"])
    delta.content = "Let me search for that."
    delta.tool_calls = [tool]

    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta = delta
    choice.finish_reason = None

    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, tool_call = provider._handle_stream_chunk(chunk, None)

    assert len(events) == 1
    assert events[0].type == "text"
    assert events[0].data == "Let me search for that."
    assert tool_call is not None


# JSON Parsing Tests


@pytest.mark.parametrize(
    "arguments,expected_input",
    [
        ('{"key": "value", "num": 42}', {"key": "value", "num": 42}),
        ("{}", {}),
        (
            '{"user": {"name": "John", "tags": ["a", "b"]}, "count": 10}',
            {"user": {"name": "John", "tags": ["a", "b"]}, "count": 10},
        ),
        ('{"invalid": json}', {}),
        ('{"key": "val', {}),
        ("", {}),
        ('{"text": "Hello 世界 🌍"}', {"text": "Hello 世界 🌍"}),
        (
            '{"text": "Line 1\\nLine 2\\t\\tTabbed"}',
            {"text": "Line 1\nLine 2\t\tTabbed"},
        ),
    ],
)
def test_json_parsing(provider, arguments, expected_input):
    """Various JSON inputs should be parsed correctly."""
    tool_call = {"id": "c1", "name": "test", "arguments": arguments}

    result = provider._parse_and_finalize_tool_call(tool_call)

    assert result["input"] == expected_input
    assert "arguments" not in result
tests/memory/common/llms/test_openai_provider.py (new file, 561 lines)
@ -0,0 +1,561 @@
import pytest
from unittest.mock import Mock
from PIL import Image

from memory.common.llms.openai_provider import OpenAIProvider
from memory.common.llms.base import (
    Message,
    MessageRole,
    TextContent,
    ImageContent,
    ToolUseContent,
    ToolResultContent,
    LLMSettings,
    StreamEvent,
)
from memory.common.llms.tools import ToolDefinition


@pytest.fixture
def provider():
    return OpenAIProvider(api_key="test-key", model="gpt-4o")


@pytest.fixture
def reasoning_provider():
    return OpenAIProvider(api_key="test-key", model="o1-preview")


def test_initialization(provider):
    assert provider.api_key == "test-key"
    assert provider.model == "gpt-4o"


def test_client_lazy_loading(provider):
    assert provider._client is None
    client = provider.client
    assert client is not None
    assert provider._client is not None


def test_async_client_lazy_loading(provider):
    assert provider._async_client is None
    client = provider.async_client
    assert client is not None
    assert provider._async_client is not None


@pytest.mark.parametrize(
    "model, expected",
    [
        ("gpt-4o", False),
        ("o1-preview", True),
        ("o1-mini", True),
        ("gpt-4-turbo", True),
        ("gpt-3.5-turbo", True),
    ],
)
def test_is_reasoning_model(model, expected):
    provider = OpenAIProvider(api_key="test-key", model=model)
    assert provider._is_reasoning_model() == expected


def test_convert_text_content(provider):
    content = TextContent(text="hello world")
    result = provider._convert_text_content(content)
    assert result == {"type": "text", "text": "hello world"}


def test_convert_image_content(provider):
    image = Image.new("RGB", (100, 100), color="red")
    content = ImageContent(image=image)
    result = provider._convert_image_content(content)

    assert result["type"] == "image_url"
    assert "image_url" in result
    assert result["image_url"]["url"].startswith("data:image/jpeg;base64,")


def test_convert_image_content_with_detail(provider):
    image = Image.new("RGB", (100, 100), color="red")
    content = ImageContent(image=image, detail="high")
    result = provider._convert_image_content(content)

    assert result["image_url"]["detail"] == "high"


def test_convert_tool_use_content(provider):
    content = ToolUseContent(
        id="t1",
        name="test_tool",
        input={"arg": "value"},
    )
    result = provider._convert_tool_use_content(content)

    assert result["id"] == "t1"
    assert result["type"] == "function"
    assert result["function"]["name"] == "test_tool"
    assert '{"arg": "value"}' in result["function"]["arguments"]


def test_convert_tool_result_content(provider):
    content = ToolResultContent(
        tool_use_id="t1",
        content="result content",
        is_error=False,
    )
    result = provider._convert_tool_result_content(content)

    assert result["role"] == "tool"
    assert result["tool_call_id"] == "t1"
    assert result["content"] == "result content"


def test_convert_messages_simple(provider):
    messages = [Message(role=MessageRole.USER, content="test")]
    result = provider._convert_messages(messages)

    assert len(result) == 1
    assert result[0]["role"] == "user"
    assert result[0]["content"] == "test"


def test_convert_messages_with_tool_result(provider):
    """Tool results should become separate messages with 'tool' role."""
    messages = [
        Message(
            role=MessageRole.USER,
            content=[ToolResultContent(tool_use_id="t1", content="result")],
        )
    ]
    result = provider._convert_messages(messages)

    assert len(result) == 1
    assert result[0]["role"] == "tool"
    assert result[0]["tool_call_id"] == "t1"


def test_convert_messages_with_tool_use(provider):
    """Tool use content should become tool_calls field."""
    messages = [
        Message.assistant(
            TextContent(text="thinking..."),
            ToolUseContent(id="t1", name="test", input={}),
        )
    ]
    result = provider._convert_messages(messages)

    assert len(result) == 1
    assert result[0]["role"] == "assistant"
    assert "tool_calls" in result[0]
    assert len(result[0]["tool_calls"]) == 1


def test_convert_messages_mixed_content(provider):
    """Messages with both text and tool results should be split."""
    messages = [
        Message(
            role=MessageRole.USER,
            content=[
                TextContent(text="user text"),
                ToolResultContent(tool_use_id="t1", content="result"),
            ],
        )
    ]
    result = provider._convert_messages(messages)

    # Should create two messages: one user message and one tool message
    assert len(result) == 2
    assert result[0]["role"] == "tool"
    assert result[1]["role"] == "user"


def test_convert_tools(provider):
    tools = [
        ToolDefinition(
            name="test_tool",
            description="A test tool",
            input_schema={"type": "object", "properties": {"arg": {"type": "string"}}},
            function=lambda x: "result",
        )
    ]
    result = provider._convert_tools(tools)

    assert len(result) == 1
    assert result[0]["type"] == "function"
    assert result[0]["function"]["name"] == "test_tool"
    assert result[0]["function"]["description"] == "A test tool"
    assert result[0]["function"]["parameters"] == tools[0].input_schema


def test_build_request_kwargs_basic(provider):
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings(temperature=0.5, max_tokens=1000)

    kwargs = provider._build_request_kwargs(messages, None, None, settings)

    assert kwargs["model"] == "gpt-4o"
    assert kwargs["temperature"] == 0.5
    assert kwargs["max_tokens"] == 1000
    assert len(kwargs["messages"]) == 1


def test_build_request_kwargs_with_system_prompt_standard_model(provider):
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings()

    kwargs = provider._build_request_kwargs(messages, "system prompt", None, settings)

    # For gpt-4o, system prompt becomes system message
    assert kwargs["messages"][0]["role"] == "system"
    assert kwargs["messages"][0]["content"] == "system prompt"


def test_build_request_kwargs_with_system_prompt_reasoning_model(
    reasoning_provider,
):
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings()

    kwargs = reasoning_provider._build_request_kwargs(
        messages, "system prompt", None, settings
    )

    # For o1 models, system prompt becomes developer message
    assert kwargs["messages"][0]["role"] == "developer"
    assert kwargs["messages"][0]["content"] == "system prompt"


def test_build_request_kwargs_reasoning_model_uses_max_completion_tokens(
    reasoning_provider,
):
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings(max_tokens=2000)

    kwargs = reasoning_provider._build_request_kwargs(messages, None, None, settings)

    # Reasoning models use max_completion_tokens
    assert "max_completion_tokens" in kwargs
    assert kwargs["max_completion_tokens"] == 2000
    assert "max_tokens" not in kwargs


def test_build_request_kwargs_reasoning_model_no_temperature(reasoning_provider):
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings(temperature=0.7)

    kwargs = reasoning_provider._build_request_kwargs(messages, None, None, settings)

    # Reasoning models don't support temperature
    assert "temperature" not in kwargs
    assert "top_p" not in kwargs


def test_build_request_kwargs_with_tools(provider):
    messages = [Message(role=MessageRole.USER, content="test")]
    tools = [
        ToolDefinition(
            name="test",
            description="test",
            input_schema={},
            function=lambda x: "result",
        )
    ]
    settings = LLMSettings()

    kwargs = provider._build_request_kwargs(messages, None, tools, settings)

    assert "tools" in kwargs
    assert len(kwargs["tools"]) == 1
    assert kwargs["tool_choice"] == "auto"


def test_build_request_kwargs_with_stream(provider):
    messages = [Message(role=MessageRole.USER, content="test")]
    settings = LLMSettings()

    kwargs = provider._build_request_kwargs(messages, None, None, settings, stream=True)

    assert kwargs["stream"] is True


def test_parse_and_finalize_tool_call(provider):
    tool_call = {
        "id": "t1",
        "name": "test",
        "arguments": '{"key": "value"}',
    }

    result = provider._parse_and_finalize_tool_call(tool_call)

    assert result["id"] == "t1"
    assert result["name"] == "test"
    assert result["input"] == {"key": "value"}
    assert "arguments" not in result


def test_parse_and_finalize_tool_call_invalid_json(provider):
    tool_call = {
        "id": "t1",
        "name": "test",
        "arguments": '{"invalid json',
    }

    result = provider._parse_and_finalize_tool_call(tool_call)

    # Should default to empty dict on parse error
    assert result["input"] == {}


def test_handle_stream_chunk_text_content(provider):
    chunk = Mock(
        choices=[
            Mock(
                delta=Mock(content="hello", tool_calls=None),
                finish_reason=None,
            )
        ]
    )

    events, tool_call = provider._handle_stream_chunk(chunk, None)

    assert len(events) == 1
    assert events[0].type == "text"
    assert events[0].data == "hello"
    assert tool_call is None


def test_handle_stream_chunk_tool_call_start(provider):
    function = Mock(spec=["name", "arguments"])
    function.name = "test_tool"
    function.arguments = ""

    tool_call_mock = Mock(spec=["id", "function"])
    tool_call_mock.id = "t1"
    tool_call_mock.function = function

    delta = Mock(spec=["content", "tool_calls"])
    delta.content = None
    delta.tool_calls = [tool_call_mock]

    choice = Mock(spec=["delta", "finish_reason"])
    choice.delta = delta
    choice.finish_reason = None

    chunk = Mock(spec=["choices"])
    chunk.choices = [choice]

    events, tool_call = provider._handle_stream_chunk(chunk, None)

    assert len(events) == 0
    assert tool_call is not None
    assert tool_call["id"] == "t1"
    assert tool_call["name"] == "test_tool"


def test_handle_stream_chunk_tool_call_arguments(provider):
    current_tool = {"id": "t1", "name": "test", "arguments": '{"ke'}
    chunk = Mock(
        choices=[
            Mock(
                delta=Mock(
                    content=None,
                    tool_calls=[
                        Mock(
                            id=None,
                            function=Mock(name=None, arguments='y": "val"}'),
                        )
                    ],
                ),
                finish_reason=None,
            )
        ]
    )

    events, tool_call = provider._handle_stream_chunk(chunk, current_tool)

    assert len(events) == 0
    assert tool_call["arguments"] == '{"key": "val"}'


def test_handle_stream_chunk_finish_with_tool_call(provider):
    current_tool = {"id": "t1", "name": "test", "arguments": '{"key": "value"}'}
    chunk = Mock(
        choices=[
            Mock(
                delta=Mock(content=None, tool_calls=None),
                finish_reason="tool_calls",
            )
        ]
    )

    events, tool_call = provider._handle_stream_chunk(chunk, current_tool)

    assert len(events) == 1
    assert events[0].type == "tool_use"
    assert events[0].data["id"] == "t1"
    assert events[0].data["input"] == {"key": "value"}
    assert tool_call is None


def test_handle_stream_chunk_empty_choices(provider):
    chunk = Mock(choices=[])

    events, tool_call = provider._handle_stream_chunk(chunk, None)

    assert len(events) == 0
    assert tool_call is None


def test_generate_basic(provider, mock_openai_client):
    messages = [Message(role=MessageRole.USER, content="test")]

    # The conftest fixture already sets up the mock response
    result = provider.generate(messages)

    assert isinstance(result, str)
    assert len(result) > 0
    provider.client.chat.completions.create.assert_called_once()


def test_stream_basic(provider, mock_openai_client):
    messages = [Message(role=MessageRole.USER, content="test")]

    events = list(provider.stream(messages))

    # Should get text events and done event
    assert len(events) > 0
    text_events = [e for e in events if e.type == "text"]
    assert len(text_events) > 0
    assert events[-1].type == "done"


@pytest.mark.asyncio
async def test_agenerate_basic(provider, mock_openai_client):
    messages = [Message(role=MessageRole.USER, content="test")]

    # Mock the async client
    mock_response = Mock(choices=[Mock(message=Mock(content="async response"))])
    provider.async_client.chat.completions.create = Mock(return_value=mock_response)

    result = await provider.agenerate(messages)

    assert result == "async response"


@pytest.mark.asyncio
async def test_astream_basic(provider, mock_openai_client):
    messages = [Message(role=MessageRole.USER, content="test")]

    # Mock async streaming
    async def async_stream():
        yield Mock(
            choices=[
                Mock(delta=Mock(content="async", tool_calls=None), finish_reason=None)
            ]
        )
        yield Mock(
            choices=[
                Mock(delta=Mock(content=" test", tool_calls=None), finish_reason="stop")
            ]
        )

    provider.async_client.chat.completions.create = Mock(return_value=async_stream())

    events = []
    async for event in provider.astream(messages):
        events.append(event)

    assert len(events) > 0
    text_events = [e for e in events if e.type == "text"]
    assert len(text_events) > 0


def test_stream_with_tool_call(provider, mock_openai_client):
    """Test streaming with a complete tool call."""

    def stream_with_tool(*args, **kwargs):
        if kwargs.get("stream"):
            # First chunk - tool call start
            function1 = Mock(spec=["name", "arguments"])
            function1.name = "test_tool"
            function1.arguments = ""

            tool_call1 = Mock(spec=["id", "function"])
            tool_call1.id = "t1"
            tool_call1.function = function1

            delta1 = Mock(spec=["content", "tool_calls"])
            delta1.content = None
            delta1.tool_calls = [tool_call1]

            choice1 = Mock(spec=["delta", "finish_reason"])
            choice1.delta = delta1
            choice1.finish_reason = None

            chunk1 = Mock(spec=["choices"])
            chunk1.choices = [choice1]

            # Second chunk - tool arguments
            function2 = Mock(spec=["name", "arguments"])
            function2.name = None
            function2.arguments = '{"arg": "val"}'

            tool_call2 = Mock(spec=["id", "function"])
            tool_call2.id = None
            tool_call2.function = function2

            delta2 = Mock(spec=["content", "tool_calls"])
            delta2.content = None
            delta2.tool_calls = [tool_call2]

            choice2 = Mock(spec=["delta", "finish_reason"])
            choice2.delta = delta2
            choice2.finish_reason = None

            chunk2 = Mock(spec=["choices"])
            chunk2.choices = [choice2]

            # Third chunk - finish
            delta3 = Mock(spec=["content", "tool_calls"])
            delta3.content = None
            delta3.tool_calls = None

            choice3 = Mock(spec=["delta", "finish_reason"])
            choice3.delta = delta3
            choice3.finish_reason = "tool_calls"

            chunk3 = Mock(spec=["choices"])
            chunk3.choices = [choice3]

            return iter([chunk1, chunk2, chunk3])

    provider.client.chat.completions.create.side_effect = stream_with_tool

    messages = [Message(role=MessageRole.USER, content="test")]
    events = list(provider.stream(messages))

    tool_events = [e for e in events if e.type == "tool_use"]
    assert len(tool_events) == 1
    assert tool_events[0].data["id"] == "t1"
    assert tool_events[0].data["name"] == "test_tool"
    assert tool_events[0].data["input"] == {"arg": "val"}


def test_encode_image(provider):
    image = Image.new("RGB", (10, 10), color="blue")

    encoded = provider.encode_image(image)

    assert isinstance(encoded, str)
    assert len(encoded) > 0


def test_encode_image_rgba(provider):
    """RGBA images should be converted to RGB."""
    image = Image.new("RGBA", (10, 10), color=(255, 0, 0, 128))

    encoded = provider.encode_image(image)

    assert isinstance(encoded, str)
    assert len(encoded) > 0
@ -2,21 +2,41 @@
 import argparse
 from memory.common.db.connection import make_session
-from memory.common.db.models.users import User
+from memory.common.db.models.users import HumanUser, BotUser


 if __name__ == "__main__":
     args = argparse.ArgumentParser()
     args.add_argument("--email", type=str, required=True)
-    args.add_argument("--password", type=str, required=True)
     args.add_argument("--name", type=str, required=True)
+    args.add_argument("--password", type=str, required=False)
+    args.add_argument("--bot", action="store_true", help="Create a bot user")
+    args.add_argument(
+        "--api-key",
+        type=str,
+        required=False,
+        help="API key for bot user (auto-generated if not provided)",
+    )
     args = args.parse_args()

     with make_session() as session:
-        user = User.create_with_password(
-            email=args.email, password=args.password, name=args.name
-        )
+        if args.bot:
+            user = BotUser.create_with_api_key(
+                name=args.name, email=args.email, api_key=args.api_key
+            )
+            print(f"Bot user {args.email} created with API key: {user.api_key}")
+        else:
+            if not args.password:
+                raise ValueError("Password required for human users")
+            user = HumanUser.create_with_password(
+                email=args.email, password=args.password, name=args.name
+            )
+            print(f"Human user {args.email} created")
+
         session.add(user)
         session.commit()

-    print(f"User {args.email} created")
+    if args.bot:
+        print(f"Bot user {args.email} created with API key: {user.api_key}")
+    else:
+        print(f"Human user {args.email} created")