Compare commits

..

4 Commits

Author SHA1 Message Date
Daniel O'Connell
8af07f0dac add slash commands for discord 2025-11-01 18:04:38 +01:00
Daniel O'Connell
c296f3b533 extract usage 2025-11-01 17:56:20 +01:00
Daniel O'Connell
07852f9ee7 Base usage tracker 2025-11-01 16:22:40 +01:00
Daniel O'Connell
bcb470db9b use redis for celery backend 2025-11-01 15:55:59 +01:00
14 changed files with 1141 additions and 43 deletions

5
AGENTS.md Normal file
View File

@ -0,0 +1,5 @@
# Agent Guidance
- Assume Python 3.10+ features are available; avoid `from __future__ import annotations` unless necessary.
- Treat LLM model identifiers as `<provider>/<model_name>` strings throughout the codebase.
- Prefer straightforward control flow (`if`/`else`) instead of nested ternaries when clarity is improved.

View File

@ -19,17 +19,17 @@ secrets:
volumes:
db_data: {} # Postgres
qdrant_data: {} # Qdrant
rabbitmq_data: {} # RabbitMQ
redis_data: {} # Redis
# ------------------------------ X-templates ----------------------------
x-common-env: &env
RABBITMQ_USER: kb
RABBITMQ_HOST: rabbitmq
REDIS_HOST: redis
REDIS_PORT: 6379
REDIS_DB: 0
CELERY_BROKER_PASSWORD: ${CELERY_BROKER_PASSWORD}
QDRANT_HOST: qdrant
DB_HOST: postgres
DB_PORT: 5432
RABBITMQ_PORT: 5672
FILE_STORAGE_DIR: /app/memory_files
TZ: "Etc/UTC"
@ -40,7 +40,7 @@ x-worker-base: &worker-base
restart: unless-stopped
networks: [ kbnet ]
security_opt: [ "no-new-privileges=true" ]
depends_on: [ postgres, rabbitmq, qdrant ]
depends_on: [ postgres, redis, qdrant ]
env_file: [ .env ]
environment: &worker-env
<<: *env
@ -103,22 +103,21 @@ services:
volumes:
- ./db:/app/db:ro
rabbitmq:
image: rabbitmq:3.13-management
redis:
image: redis:7.2-alpine
restart: unless-stopped
networks: [ kbnet ]
environment:
<<: *env
RABBITMQ_DEFAULT_USER: "kb"
RABBITMQ_DEFAULT_PASS: "${CELERY_BROKER_PASSWORD}"
command: ["redis-server", "--save", "", "--appendonly", "no", "--requirepass", "${CELERY_BROKER_PASSWORD}"]
volumes:
- rabbitmq_data:/var/lib/rabbitmq:rw
- redis_data:/data:rw
healthcheck:
test: [ "CMD", "rabbitmq-diagnostics", "ping" ]
test: [ "CMD", "redis-cli", "--pass", "${CELERY_BROKER_PASSWORD}", "ping" ]
interval: 15s
timeout: 5s
retries: 5
security_opt: [ "no-new-privileges=true" ]
cap_drop: [ ALL ]
user: redis
qdrant:
image: qdrant/qdrant:v1.14.0
@ -148,7 +147,7 @@ services:
SESSION_COOKIE_NAME: "${SESSION_COOKIE_NAME:-session_id}"
restart: unless-stopped
networks: [kbnet]
depends_on: [postgres, rabbitmq, qdrant]
depends_on: [postgres, redis, qdrant]
environment:
<<: *env
POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
@ -186,7 +185,7 @@ services:
environment:
<<: *worker-env
DISCORD_BOT_TOKEN: ${DISCORD_BOT_TOKEN}
DISCORD_NOTIFICATIONS_ENABLED: true
DISCORD_NOTIFICATIONS_ENABLED: ${DISCORD_NOTIFICATIONS_ENABLED:-true}
DISCORD_COLLECTOR_ENABLED: true
volumes:
- ./memory_files:/app/memory_files:rw

View File

@ -9,4 +9,4 @@ anthropic==0.69.0
openai==2.3.0
# Pin the httpx version, as newer versions break the anthropic client
httpx==0.27.0
celery[sqs]==5.3.6
celery[redis,sqs]==5.3.6

View File

@ -24,6 +24,13 @@ from memory.common.llms.base import (
)
from memory.common.llms.anthropic_provider import AnthropicProvider
from memory.common.llms.openai_provider import OpenAIProvider
from memory.common.llms.usage_tracker import (
InMemoryUsageTracker,
RateLimitConfig,
TokenAllowance,
UsageBreakdown,
UsageTracker,
)
from memory.common import tokens
__all__ = [
@ -42,6 +49,11 @@ __all__ = [
"StreamEvent",
"LLMSettings",
"create_provider",
"InMemoryUsageTracker",
"RateLimitConfig",
"TokenAllowance",
"UsageBreakdown",
"UsageTracker",
]
logger = logging.getLogger(__name__)

View File

@ -14,6 +14,7 @@ from memory.common.llms.base import (
MessageRole,
StreamEvent,
ToolDefinition,
Usage,
)
logger = logging.getLogger(__name__)
@ -255,6 +256,16 @@ class AnthropicProvider(BaseLLMProvider):
return StreamEvent(type="tool_use", data=tool_data), None
elif event_type == "message_delta":
# Handle token usage information
if usage := getattr(event, "usage", None):
self.log_usage(
Usage(
input_tokens=usage.input_tokens,
output_tokens=usage.output_tokens,
total_tokens=usage.total_tokens,
)
)
delta = getattr(event, "delta", None)
if delta:
stop_reason = getattr(delta, "stop_reason", None)
@ -263,22 +274,6 @@ class AnthropicProvider(BaseLLMProvider):
type="error", data="Max tokens reached"
), current_tool_use
# Handle token usage information
usage = getattr(event, "usage", None)
if usage:
usage_data = {
"input_tokens": getattr(usage, "input_tokens", 0),
"output_tokens": getattr(usage, "output_tokens", 0),
"cache_creation_input_tokens": getattr(
usage, "cache_creation_input_tokens", None
),
"cache_read_input_tokens": getattr(
usage, "cache_read_input_tokens", None
),
}
# Could emit this as a separate event type if needed
logger.debug(f"Token usage: {usage_data}")
return None, current_tool_use
elif event_type == "message_stop":

View File

@ -25,6 +25,15 @@ class MessageRole(str, Enum):
TOOL = "tool"
@dataclass
class Usage:
"""Usage data for an LLM call."""
input_tokens: int = 0
output_tokens: int = 0
total_tokens: int = 0
@dataclass
class TextContent:
"""Text content in a message."""
@ -219,6 +228,11 @@ class BaseLLMProvider(ABC):
self._client = self._initialize_client()
return self._client
def log_usage(self, usage: Usage):
"""Log usage data."""
logger.debug(f"Token usage: {usage.to_dict()}")
print(f"Token usage: {usage.to_dict()}")
def execute_tool(
self,
tool_call: ToolCall,

View File

@ -16,6 +16,7 @@ from memory.common.llms.base import (
ToolDefinition,
ToolResultContent,
ToolUseContent,
Usage,
)
logger = logging.getLogger(__name__)
@ -283,6 +284,17 @@ class OpenAIProvider(BaseLLMProvider):
"""
events: list[StreamEvent] = []
# Handle usage information (comes in final chunk with empty choices)
if hasattr(chunk, "usage") and chunk.usage:
usage = chunk.usage
self.log_usage(
Usage(
input_tokens=usage.prompt_tokens,
output_tokens=usage.completion_tokens,
total_tokens=usage.total_tokens,
)
)
if not chunk.choices:
return events, current_tool_call
@ -337,6 +349,14 @@ class OpenAIProvider(BaseLLMProvider):
try:
response = self.client.chat.completions.create(**kwargs)
usage = response.usage
self.log_usage(
Usage(
input_tokens=usage.prompt_tokens,
output_tokens=usage.completion_tokens,
total_tokens=usage.total_tokens,
)
)
return response.choices[0].message.content or ""
except Exception as e:
logger.error(f"OpenAI API error: {e}")
@ -355,6 +375,9 @@ class OpenAIProvider(BaseLLMProvider):
messages, system_prompt, tools, settings, stream=True
)
if kwargs.get("stream"):
kwargs["stream_options"] = {"include_usage": True}
try:
stream = self.client.chat.completions.create(**kwargs)
current_tool_call: dict[str, Any] | None = None

View File

@ -0,0 +1,316 @@
"""LLM usage tracking utilities."""
from collections import defaultdict, deque
from collections.abc import Iterable
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from threading import Lock
@dataclass(frozen=True)
class RateLimitConfig:
"""Configuration for a single rolling usage window."""
window: timedelta
max_input_tokens: int | None = None
max_output_tokens: int | None = None
max_total_tokens: int | None = None
def __post_init__(self) -> None:
if self.window <= timedelta(0):
raise ValueError("window must be positive")
if (
self.max_input_tokens is None
and self.max_output_tokens is None
and self.max_total_tokens is None
):
raise ValueError(
"At least one of max_input_tokens, max_output_tokens or "
"max_total_tokens must be provided"
)
@dataclass
class UsageEvent:
timestamp: datetime
input_tokens: int
output_tokens: int
@dataclass
class UsageState:
events: deque[UsageEvent] = field(default_factory=deque)
window_input_tokens: int = 0
window_output_tokens: int = 0
lifetime_input_tokens: int = 0
lifetime_output_tokens: int = 0
@dataclass
class TokenAllowance:
"""Represents the tokens that can be consumed right now."""
input_tokens: int | None
output_tokens: int | None
total_tokens: int | None
@dataclass
class UsageBreakdown:
"""Detailed usage statistics for a provider/model pair."""
window_input_tokens: int
window_output_tokens: int
window_total_tokens: int
lifetime_input_tokens: int
lifetime_output_tokens: int
@property
def window_total(self) -> int:
return self.window_total_tokens
@property
def lifetime_total_tokens(self) -> int:
return self.lifetime_input_tokens + self.lifetime_output_tokens
def split_model_key(model: str) -> tuple[str, str]:
if "/" not in model:
raise ValueError(
"model must be formatted as '<provider>/<model_name>'"
)
provider, model_name = model.split("/", maxsplit=1)
if not provider or not model_name:
raise ValueError(
"model must include both provider and model name separated by '/'"
)
return provider, model_name
class UsageTracker:
"""Base class for usage trackers that operate on provider/model keys."""
def __init__(
self,
configs: dict[str, RateLimitConfig],
default_config: RateLimitConfig | None = None,
) -> None:
self._configs = configs
self._default_config = default_config
self._lock = Lock()
# ------------------------------------------------------------------
# Storage hooks
# ------------------------------------------------------------------
def get_state(self, key: str) -> UsageState:
raise NotImplementedError
def iter_state_items(self) -> Iterable[tuple[str, UsageState]]:
raise NotImplementedError
def save_state(self, key: str, state: UsageState) -> None:
"""Persist the given state back to the underlying store."""
del key, state
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def record_usage(
self,
model: str,
input_tokens: int,
output_tokens: int,
timestamp: datetime | None = None,
) -> None:
"""Record token usage for the given provider/model pair."""
if input_tokens < 0 or output_tokens < 0:
raise ValueError("Token counts must be non-negative")
timestamp = timestamp or datetime.now(timezone.utc)
split_model_key(model)
key = model
with self._lock:
config = self._get_config(key)
state = self.get_state(key)
state.lifetime_input_tokens += input_tokens
state.lifetime_output_tokens += output_tokens
if config is None:
self.save_state(key, state)
return
state.events.append(UsageEvent(timestamp, input_tokens, output_tokens))
state.window_input_tokens += input_tokens
state.window_output_tokens += output_tokens
self._prune_expired_events(state, config, now=timestamp)
self.save_state(key, state)
def is_rate_limited(
self,
model: str,
timestamp: datetime | None = None,
) -> bool:
"""Return True if the pair currently exceeds its limits."""
allowance = self.get_available_tokens(model, timestamp=timestamp)
if allowance is None:
return False
limits = [
allowance.input_tokens,
allowance.output_tokens,
allowance.total_tokens,
]
return any(limit is not None and limit <= 0 for limit in limits)
def get_available_tokens(
self,
model: str,
timestamp: datetime | None = None,
) -> TokenAllowance | None:
"""Return the current token allowance for the provider/model pair.
If there is no configuration for the pair (or a default configuration),
``None`` is returned to indicate that no limits are enforced.
"""
split_model_key(model)
key = model
with self._lock:
config = self._get_config(key)
if config is None:
return None
state = self.get_state(key)
self._prune_expired_events(state, config, now=timestamp)
self.save_state(key, state)
if config.max_total_tokens is None:
total_remaining = None
else:
total_remaining = config.max_total_tokens - (
state.window_input_tokens + state.window_output_tokens
)
if config.max_input_tokens is None:
input_remaining = None
else:
input_remaining = config.max_input_tokens - state.window_input_tokens
if config.max_output_tokens is None:
output_remaining = None
else:
output_remaining = (
config.max_output_tokens - state.window_output_tokens
)
return TokenAllowance(
input_tokens=clamp_non_negative(input_remaining),
output_tokens=clamp_non_negative(output_remaining),
total_tokens=clamp_non_negative(total_remaining),
)
def get_usage_breakdown(
self, provider: str | None = None, model: str | None = None
) -> dict[str, dict[str, UsageBreakdown]]:
"""Return usage statistics grouped by provider and model."""
with self._lock:
providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
for key, state in self.iter_state_items():
prov, model_name = split_model_key(key)
if provider and provider != prov:
continue
if model and model != model_name:
continue
window_total = state.window_input_tokens + state.window_output_tokens
breakdown = UsageBreakdown(
window_input_tokens=state.window_input_tokens,
window_output_tokens=state.window_output_tokens,
window_total_tokens=window_total,
lifetime_input_tokens=state.lifetime_input_tokens,
lifetime_output_tokens=state.lifetime_output_tokens,
)
providers[prov][model_name] = breakdown
return providers
def iter_provider_totals(self) -> Iterable[tuple[str, UsageBreakdown]]:
"""Yield aggregated totals for each provider across its models."""
breakdowns = self.get_usage_breakdown()
for provider, models in breakdowns.items():
window_input = sum(b.window_input_tokens for b in models.values())
window_output = sum(b.window_output_tokens for b in models.values())
lifetime_input = sum(b.lifetime_input_tokens for b in models.values())
lifetime_output = sum(b.lifetime_output_tokens for b in models.values())
yield (
provider,
UsageBreakdown(
window_input_tokens=window_input,
window_output_tokens=window_output,
window_total_tokens=window_input + window_output,
lifetime_input_tokens=lifetime_input,
lifetime_output_tokens=lifetime_output,
),
)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _get_config(self, key: str) -> RateLimitConfig | None:
return self._configs.get(key) or self._default_config
def _prune_expired_events(
self,
state: UsageState,
config: RateLimitConfig,
now: datetime | None = None,
) -> None:
if not state.events:
return
now = now or datetime.now(timezone.utc)
cutoff = now - config.window
for event in tuple(state.events):
if event.timestamp > cutoff:
break
state.events.popleft()
state.window_input_tokens -= event.input_tokens
state.window_output_tokens -= event.output_tokens
state.window_input_tokens = max(state.window_input_tokens, 0)
state.window_output_tokens = max(state.window_output_tokens, 0)
class InMemoryUsageTracker(UsageTracker):
"""Tracks LLM usage for providers and models within a rolling window."""
def __init__(
self,
configs: dict[str, RateLimitConfig],
default_config: RateLimitConfig | None = None,
) -> None:
super().__init__(configs=configs, default_config=default_config)
self._states: dict[str, UsageState] = {}
def get_state(self, key: str) -> UsageState:
return self._states.setdefault(key, UsageState())
def iter_state_items(self) -> Iterable[tuple[str, UsageState]]:
return tuple(self._states.items())
def clamp_non_negative(value: int | None) -> int | None:
if value is None:
return None
return 0 if value < 0 else value

View File

@ -34,15 +34,26 @@ DB_URL = os.getenv("DATABASE_URL", make_db_url())
# Broker settings
CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory")
CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "amqp").lower() # amqp or sqs
CELERY_BROKER_USER = os.getenv("CELERY_BROKER_USER", "kb")
CELERY_BROKER_PASSWORD = os.getenv("CELERY_BROKER_PASSWORD", "kb")
CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower()
REDIS_HOST = os.getenv("REDIS_HOST", "redis")
REDIS_PORT = os.getenv("REDIS_PORT", "6379")
REDIS_DB = os.getenv("REDIS_DB", "0")
CELERY_BROKER_USER = os.getenv(
"CELERY_BROKER_USER", "kb" if CELERY_BROKER_TYPE == "amqp" else ""
)
CELERY_BROKER_PASSWORD = os.getenv(
"CELERY_BROKER_PASSWORD", "" if CELERY_BROKER_TYPE == "redis" else "kb"
)
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "")
if not CELERY_BROKER_HOST and CELERY_BROKER_TYPE == "amqp":
if not CELERY_BROKER_HOST:
if CELERY_BROKER_TYPE == "amqp":
RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672")
CELERY_BROKER_HOST = f"{RABBITMQ_HOST}:{RABBITMQ_PORT}//"
elif CELERY_BROKER_TYPE == "redis":
CELERY_BROKER_HOST = f"{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
@ -161,7 +172,7 @@ STATIC_DIR = pathlib.Path(
)
# Discord notification settings
DISCORD_BOT_TOKEN = os.getenv("DISCORD_BOT_TOKEN", "")
DISCORD_BOT_ID = int(os.getenv("DISCORD_BOT_ID", "0"))
DISCORD_ERROR_CHANNEL = os.getenv("DISCORD_ERROR_CHANNEL", "memory-errors")
DISCORD_ACTIVITY_CHANNEL = os.getenv("DISCORD_ACTIVITY_CHANNEL", "memory-activity")
DISCORD_DISCOVERY_CHANNEL = os.getenv("DISCORD_DISCOVERY_CHANNEL", "memory-discoveries")
@ -169,9 +180,7 @@ DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat")
# Enable Discord notifications if bot token is set
DISCORD_NOTIFICATIONS_ENABLED = bool(
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
)
DISCORD_NOTIFICATIONS_ENABLED = boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True)
DISCORD_PROCESS_MESSAGES = boolean_env("DISCORD_PROCESS_MESSAGES", True)
DISCORD_MODEL = os.getenv("DISCORD_MODEL", "anthropic/claude-sonnet-4-5")
DISCORD_MAX_TOOL_CALLS = int(os.getenv("DISCORD_MAX_TOOL_CALLS", 10))

View File

@ -18,6 +18,7 @@ from memory.common.db.models import (
DiscordChannel,
DiscordUser,
)
from memory.discord.commands import register_slash_commands
from memory.workers.tasks.discord import add_discord_message, edit_discord_message
logger = logging.getLogger(__name__)
@ -199,6 +200,11 @@ class MessageCollector(commands.Bot):
help_command=None, # Disable default help
)
async def setup_hook(self):
"""Register slash commands when the bot is ready."""
register_slash_commands(self)
async def on_ready(self):
"""Called when bot connects to Discord"""
logger.info(f"Discord collector connected as {self.user}")
@ -207,6 +213,11 @@ class MessageCollector(commands.Bot):
# Sync server and channel metadata
await self.sync_servers_and_channels()
try:
await self.tree.sync()
except Exception as exc: # pragma: no cover - defensive
logger.error("Failed to sync slash commands: %s", exc)
logger.info("Discord message collector ready")
async def on_message(self, message: discord.Message):

View File

@ -0,0 +1,393 @@
"""Lightweight slash-command helpers for the Discord collector."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Literal
import discord
from sqlalchemy.orm import Session
from memory.common.db.connection import make_session
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
ScopeLiteral = Literal["server", "channel", "user"]
class CommandError(Exception):
"""Raised when a user-facing error occurs while handling a command."""
@dataclass(slots=True)
class CommandResponse:
"""Value object returned by handlers."""
content: str
ephemeral: bool = True
@dataclass(slots=True)
class CommandContext:
"""All information a handler needs to fulfil a command."""
session: Session
interaction: discord.Interaction
actor: DiscordUser
scope: ScopeLiteral
target: DiscordServer | DiscordChannel | DiscordUser
display_name: str
CommandHandler = Callable[..., CommandResponse]
def register_slash_commands(bot: discord.Client) -> None:
"""Register the collector slash commands on the provided bot."""
if getattr(bot, "_memory_commands_registered", False):
return
setattr(bot, "_memory_commands_registered", True)
if not hasattr(bot, "tree"):
raise RuntimeError("Bot instance does not support app commands")
tree = bot.tree
@tree.command(name="memory_prompt", description="Show the current system prompt")
@discord.app_commands.describe(
scope="Which configuration to inspect",
user="Target user when the scope is 'user'",
)
async def prompt_command(
interaction: discord.Interaction,
scope: ScopeLiteral,
user: discord.User | None = None,
) -> None:
await _run_interaction_command(
interaction,
scope=scope,
handler=handle_prompt,
target_user=user,
)
@tree.command(
name="memory_chattiness",
description="Show or update the chattiness threshold for the target",
)
@discord.app_commands.describe(
scope="Which configuration to inspect",
value="Optional new threshold value between 0 and 100",
user="Target user when the scope is 'user'",
)
async def chattiness_command(
interaction: discord.Interaction,
scope: ScopeLiteral,
value: int | None = None,
user: discord.User | None = None,
) -> None:
await _run_interaction_command(
interaction,
scope=scope,
handler=handle_chattiness,
target_user=user,
value=value,
)
@tree.command(
name="memory_ignore",
description="Toggle whether the bot should ignore messages for the target",
)
@discord.app_commands.describe(
scope="Which configuration to modify",
enabled="Optional flag. Leave empty to enable ignoring.",
user="Target user when the scope is 'user'",
)
async def ignore_command(
interaction: discord.Interaction,
scope: ScopeLiteral,
enabled: bool | None = None,
user: discord.User | None = None,
) -> None:
await _run_interaction_command(
interaction,
scope=scope,
handler=handle_ignore,
target_user=user,
ignore_enabled=enabled,
)
@tree.command(name="memory_summary", description="Show the stored summary for the target")
@discord.app_commands.describe(
scope="Which configuration to inspect",
user="Target user when the scope is 'user'",
)
async def summary_command(
interaction: discord.Interaction,
scope: ScopeLiteral,
user: discord.User | None = None,
) -> None:
await _run_interaction_command(
interaction,
scope=scope,
handler=handle_summary,
target_user=user,
)
async def _run_interaction_command(
interaction: discord.Interaction,
*,
scope: ScopeLiteral,
handler: CommandHandler,
target_user: discord.User | None = None,
**handler_kwargs,
) -> None:
"""Shared coroutine used by the registered slash commands."""
try:
with make_session() as session:
response = run_command(
session,
interaction,
scope,
handler=handler,
target_user=target_user,
**handler_kwargs,
)
session.commit()
except CommandError as exc: # pragma: no cover - passthrough
await interaction.response.send_message(str(exc), ephemeral=True)
return
await interaction.response.send_message(
response.content,
ephemeral=response.ephemeral,
)
def run_command(
session: Session,
interaction: discord.Interaction,
scope: ScopeLiteral,
*,
handler: CommandHandler,
target_user: discord.User | None = None,
**handler_kwargs,
) -> CommandResponse:
"""Create a :class:`CommandContext` and execute the handler."""
context = _build_context(session, interaction, scope, target_user)
return handler(context, **handler_kwargs)
def _build_context(
session: Session,
interaction: discord.Interaction,
scope: ScopeLiteral,
target_user: discord.User | None,
) -> CommandContext:
actor = _ensure_user(session, interaction.user)
if scope == "server":
if interaction.guild is None:
raise CommandError("This command can only be used inside a server.")
target = _ensure_server(session, interaction.guild)
display_name = f"server **{target.name}**"
return CommandContext(
session=session,
interaction=interaction,
actor=actor,
scope=scope,
target=target,
display_name=display_name,
)
if scope == "channel":
channel_obj = interaction.channel
if channel_obj is None or not hasattr(channel_obj, "id"):
raise CommandError("Unable to determine channel for this interaction.")
target = _ensure_channel(session, channel_obj, interaction.guild_id)
display_name = f"channel **#{target.name}**"
return CommandContext(
session=session,
interaction=interaction,
actor=actor,
scope=scope,
target=target,
display_name=display_name,
)
if scope == "user":
discord_user = target_user or interaction.user
if discord_user is None:
raise CommandError("A target user is required for this command.")
target = _ensure_user(session, discord_user)
display_name = target.display_name or target.username
return CommandContext(
session=session,
interaction=interaction,
actor=actor,
scope=scope,
target=target,
display_name=f"user **{display_name}**",
)
raise CommandError(f"Unsupported scope '{scope}'.")
def _ensure_server(session: Session, guild: discord.Guild) -> DiscordServer:
server = session.get(DiscordServer, guild.id)
if server is None:
server = DiscordServer(
id=guild.id,
name=guild.name or f"Server {guild.id}",
description=getattr(guild, "description", None),
member_count=getattr(guild, "member_count", None),
)
session.add(server)
session.flush()
else:
if guild.name and server.name != guild.name:
server.name = guild.name
description = getattr(guild, "description", None)
if description and server.description != description:
server.description = description
member_count = getattr(guild, "member_count", None)
if member_count is not None:
server.member_count = member_count
return server
def _ensure_channel(
session: Session,
channel: discord.abc.Messageable,
guild_id: int | None,
) -> DiscordChannel:
channel_id = getattr(channel, "id", None)
if channel_id is None:
raise CommandError("Channel is missing an identifier.")
channel_model = session.get(DiscordChannel, channel_id)
if channel_model is None:
channel_model = DiscordChannel(
id=channel_id,
server_id=guild_id,
name=getattr(channel, "name", f"Channel {channel_id}"),
channel_type=_resolve_channel_type(channel),
)
session.add(channel_model)
session.flush()
else:
name = getattr(channel, "name", None)
if name and channel_model.name != name:
channel_model.name = name
return channel_model
def _ensure_user(session: Session, discord_user: discord.abc.User) -> DiscordUser:
user = session.get(DiscordUser, discord_user.id)
display_name = getattr(discord_user, "display_name", discord_user.name)
if user is None:
user = DiscordUser(
id=discord_user.id,
username=discord_user.name,
display_name=display_name,
)
session.add(user)
session.flush()
else:
if user.username != discord_user.name:
user.username = discord_user.name
if display_name and user.display_name != display_name:
user.display_name = display_name
return user
def _resolve_channel_type(channel: discord.abc.Messageable) -> str:
if isinstance(channel, discord.DMChannel):
return "dm"
if isinstance(channel, discord.GroupChannel):
return "group_dm"
if isinstance(channel, discord.Thread):
return "thread"
if isinstance(channel, discord.VoiceChannel):
return "voice"
if isinstance(channel, discord.TextChannel):
return "text"
return getattr(getattr(channel, "type", None), "name", "unknown")
def handle_prompt(context: CommandContext) -> CommandResponse:
prompt = getattr(context.target, "system_prompt", None)
if prompt:
return CommandResponse(
content=f"Current prompt for {context.display_name}:\n\n{prompt}",
)
return CommandResponse(
content=f"No prompt configured for {context.display_name}.",
)
def handle_chattiness(
context: CommandContext,
*,
value: int | None,
) -> CommandResponse:
model = context.target
if value is None:
return CommandResponse(
content=(
f"Chattiness threshold for {context.display_name}: "
f"{getattr(model, 'chattiness_threshold', 'not set')}"
)
)
if not 0 <= value <= 100:
raise CommandError("Chattiness threshold must be between 0 and 100.")
setattr(model, "chattiness_threshold", value)
return CommandResponse(
content=(
f"Updated chattiness threshold for {context.display_name} "
f"to {value}."
)
)
def handle_ignore(
context: CommandContext,
*,
ignore_enabled: bool | None,
) -> CommandResponse:
model = context.target
new_value = True if ignore_enabled is None else ignore_enabled
setattr(model, "ignore_messages", new_value)
verb = "now ignoring" if new_value else "no longer ignoring"
return CommandResponse(
content=f"The bot is {verb} messages for {context.display_name}.",
)
def handle_summary(context: CommandContext) -> CommandResponse:
summary = getattr(context.target, "summary", None)
if summary:
return CommandResponse(
content=f"Summary for {context.display_name}:\n\n{summary}",
)
return CommandResponse(
content=f"No summary stored for {context.display_name}.",
)

View File

@ -0,0 +1,147 @@
from datetime import datetime, timedelta, timezone
import pytest
from memory.common.llms.usage_tracker import (
InMemoryUsageTracker,
RateLimitConfig,
UsageTracker,
)
@pytest.fixture
def tracker() -> InMemoryUsageTracker:
config = RateLimitConfig(
window=timedelta(minutes=1),
max_input_tokens=1_000,
max_output_tokens=2_000,
max_total_tokens=2_500,
)
return InMemoryUsageTracker(
{
"anthropic/claude-3": config,
"anthropic/haiku": config,
}
)
@pytest.mark.parametrize(
"window, kwargs",
[
(timedelta(minutes=1), {}),
(timedelta(seconds=0), {"max_total_tokens": 1}),
],
)
def test_rate_limit_config_validation(window: timedelta, kwargs: dict[str, int]) -> None:
with pytest.raises(ValueError):
RateLimitConfig(window=window, **kwargs)
def test_allows_usage_within_limits(tracker: InMemoryUsageTracker) -> None:
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
allowance = tracker.get_available_tokens(
"anthropic/claude-3", timestamp=now
)
assert allowance is not None
assert allowance.input_tokens == 900
assert allowance.output_tokens == 1_800
assert allowance.total_tokens == 2_200
def test_rate_limited_when_over_budget(tracker: InMemoryUsageTracker) -> None:
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
tracker.record_usage("anthropic/claude-3", 800, 1_700, timestamp=now)
assert tracker.is_rate_limited("anthropic/claude-3", timestamp=now)
def test_recovers_after_window(tracker: InMemoryUsageTracker) -> None:
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
tracker.record_usage("anthropic/claude-3", 800, 1_700, timestamp=now)
later = now + timedelta(minutes=2)
allowance = tracker.get_available_tokens(
"anthropic/claude-3", timestamp=later
)
assert allowance is not None
assert allowance.input_tokens == 1_000
assert allowance.output_tokens == 2_000
assert allowance.total_tokens == 2_500
assert not tracker.is_rate_limited("anthropic/claude-3", timestamp=later)
def test_usage_breakdown_and_provider_totals(tracker: InMemoryUsageTracker) -> None:
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)
breakdown = tracker.get_usage_breakdown()
assert "anthropic" in breakdown
assert "claude-3" in breakdown["anthropic"]
claude_usage = breakdown["anthropic"]["claude-3"]
assert claude_usage.window_input_tokens == 100
assert claude_usage.window_output_tokens == 200
provider_totals = dict(tracker.iter_provider_totals())
anthropic_totals = provider_totals["anthropic"]
assert anthropic_totals.window_input_tokens == 150
assert anthropic_totals.window_output_tokens == 275
def test_get_usage_breakdown_filters(tracker: InMemoryUsageTracker) -> None:
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
tracker.record_usage("anthropic/claude-3", 10, 20, timestamp=now)
tracker.record_usage("openai/gpt-4o", 5, 5, timestamp=now)
filtered = tracker.get_usage_breakdown(provider="anthropic")
assert set(filtered.keys()) == {"anthropic"}
assert set(filtered["anthropic"].keys()) == {"claude-3"}
filtered_model = tracker.get_usage_breakdown(model="gpt-4o")
assert set(filtered_model.keys()) == {"openai"}
assert set(filtered_model["openai"].keys()) == {"gpt-4o"}
def test_missing_configuration_records_lifetime_only() -> None:
tracker = InMemoryUsageTracker(configs={})
tracker.record_usage("openai/gpt-4o", 10, 20)
assert tracker.get_available_tokens("openai/gpt-4o") is None
breakdown = tracker.get_usage_breakdown()
usage = breakdown["openai"]["gpt-4o"]
assert usage.window_input_tokens == 0
assert usage.lifetime_input_tokens == 10
def test_default_configuration_is_used() -> None:
default = RateLimitConfig(window=timedelta(minutes=1), max_total_tokens=100)
tracker = InMemoryUsageTracker(configs={}, default_config=default)
tracker.record_usage("anthropic/claude-3", 10, 10)
allowance = tracker.get_available_tokens("anthropic/claude-3")
assert allowance is not None
assert allowance.total_tokens == 80
def test_record_usage_rejects_negative_values(tracker: InMemoryUsageTracker) -> None:
with pytest.raises(ValueError):
tracker.record_usage("anthropic/claude-3", -1, 0)
def test_is_rate_limited_when_only_output_exceeds_limit() -> None:
config = RateLimitConfig(window=timedelta(minutes=1), max_output_tokens=50)
tracker = InMemoryUsageTracker({"openai/gpt-4o": config})
tracker.record_usage("openai/gpt-4o", 0, 50)
assert tracker.is_rate_limited("openai/gpt-4o")
def test_usage_tracker_base_not_instantiable() -> None:
class DummyTracker(UsageTracker):
pass
with pytest.raises(NotImplementedError):
DummyTracker({}).record_usage("provider/model", 1, 1)

View File

@ -493,10 +493,12 @@ async def test_on_ready():
collector.user.name = "TestBot"
collector.guilds = [Mock(), Mock()]
collector.sync_servers_and_channels = AsyncMock()
collector.tree.sync = AsyncMock()
await collector.on_ready()
collector.sync_servers_and_channels.assert_called_once()
collector.tree.sync.assert_awaited()
@pytest.mark.asyncio

View File

@ -0,0 +1,172 @@
import pytest
from unittest.mock import MagicMock
import discord
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
from memory.discord.commands import (
CommandError,
CommandResponse,
run_command,
handle_prompt,
handle_chattiness,
handle_ignore,
handle_summary,
)
class DummyInteraction:
"""Lightweight stand-in for :class:`discord.Interaction` used in tests."""
def __init__(
self,
*,
guild: discord.Guild | None,
channel: discord.abc.Messageable | None,
user: discord.abc.User,
) -> None:
self.guild = guild
self.channel = channel
self.user = user
self.guild_id = getattr(guild, "id", None)
self.channel_id = getattr(channel, "id", None)
@pytest.fixture
def guild() -> discord.Guild:
guild = MagicMock(spec=discord.Guild)
guild.id = 123
guild.name = "Test Guild"
guild.description = "Guild description"
guild.member_count = 42
return guild
@pytest.fixture
def text_channel(guild: discord.Guild) -> discord.TextChannel:
channel = MagicMock(spec=discord.TextChannel)
channel.id = 456
channel.name = "general"
channel.guild = guild
channel.type = discord.ChannelType.text
return channel
@pytest.fixture
def discord_user() -> discord.User:
user = MagicMock(spec=discord.User)
user.id = 789
user.name = "command-user"
user.display_name = "Commander"
return user
@pytest.fixture
def interaction(guild, text_channel, discord_user) -> DummyInteraction:
return DummyInteraction(guild=guild, channel=text_channel, user=discord_user)
def test_handle_command_prompt_server(db_session, guild, interaction):
server = DiscordServer(id=guild.id, name="Test Guild", system_prompt="Be helpful")
db_session.add(server)
db_session.commit()
response = run_command(
db_session,
interaction,
scope="server",
handler=handle_prompt,
)
assert isinstance(response, CommandResponse)
assert "Be helpful" in response.content
def test_handle_command_prompt_channel_creates_channel(db_session, interaction, text_channel):
response = run_command(
db_session,
interaction,
scope="channel",
handler=handle_prompt,
)
assert "No prompt" in response.content
channel = db_session.get(DiscordChannel, text_channel.id)
assert channel is not None
assert channel.name == text_channel.name
def test_handle_command_chattiness_show(db_session, interaction, guild):
server = DiscordServer(id=guild.id, name="Guild", chattiness_threshold=73)
db_session.add(server)
db_session.commit()
response = run_command(
db_session,
interaction,
scope="server",
handler=handle_chattiness,
)
assert str(server.chattiness_threshold) in response.content
def test_handle_command_chattiness_update(db_session, interaction):
user_model = DiscordUser(id=interaction.user.id, username="command-user", chattiness_threshold=15)
db_session.add(user_model)
db_session.commit()
response = run_command(
db_session,
interaction,
scope="user",
handler=handle_chattiness,
value=80,
)
db_session.flush()
assert "Updated" in response.content
assert user_model.chattiness_threshold == 80
def test_handle_command_chattiness_invalid_value(db_session, interaction):
with pytest.raises(CommandError):
run_command(
db_session,
interaction,
scope="user",
handler=handle_chattiness,
value=150,
)
def test_handle_command_ignore_toggle(db_session, interaction, guild):
channel = DiscordChannel(id=interaction.channel.id, name="general", channel_type="text", server_id=guild.id)
db_session.add(channel)
db_session.commit()
response = run_command(
db_session,
interaction,
scope="channel",
handler=handle_ignore,
ignore_enabled=True,
)
db_session.flush()
assert "no longer" not in response.content
assert channel.ignore_messages is True
def test_handle_command_summary_missing(db_session, interaction):
response = run_command(
db_session,
interaction,
scope="user",
handler=handle_summary,
)
assert "No summary" in response.content