mirror of
https://github.com/mruwnik/memory.git
synced 2025-12-16 17:11:19 +01:00
Compare commits
4 Commits
4fedd8fe04
...
8af07f0dac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8af07f0dac | ||
|
|
c296f3b533 | ||
|
|
07852f9ee7 | ||
|
|
bcb470db9b |
5
AGENTS.md
Normal file
5
AGENTS.md
Normal file
@ -0,0 +1,5 @@
|
||||
# Agent Guidance
|
||||
|
||||
- Assume Python 3.10+ features are available; avoid `from __future__ import annotations` unless necessary.
|
||||
- Treat LLM model identifiers as `<provider>/<model_name>` strings throughout the codebase.
|
||||
- Prefer straightforward control flow (`if`/`else`) instead of nested ternaries when clarity is improved.
|
||||
@ -19,17 +19,17 @@ secrets:
|
||||
volumes:
|
||||
db_data: {} # Postgres
|
||||
qdrant_data: {} # Qdrant
|
||||
rabbitmq_data: {} # RabbitMQ
|
||||
redis_data: {} # Redis
|
||||
|
||||
# ------------------------------ X-templates ----------------------------
|
||||
x-common-env: &env
|
||||
RABBITMQ_USER: kb
|
||||
RABBITMQ_HOST: rabbitmq
|
||||
REDIS_HOST: redis
|
||||
REDIS_PORT: 6379
|
||||
REDIS_DB: 0
|
||||
CELERY_BROKER_PASSWORD: ${CELERY_BROKER_PASSWORD}
|
||||
QDRANT_HOST: qdrant
|
||||
DB_HOST: postgres
|
||||
DB_PORT: 5432
|
||||
RABBITMQ_PORT: 5672
|
||||
FILE_STORAGE_DIR: /app/memory_files
|
||||
TZ: "Etc/UTC"
|
||||
|
||||
@ -40,7 +40,7 @@ x-worker-base: &worker-base
|
||||
restart: unless-stopped
|
||||
networks: [ kbnet ]
|
||||
security_opt: [ "no-new-privileges=true" ]
|
||||
depends_on: [ postgres, rabbitmq, qdrant ]
|
||||
depends_on: [ postgres, redis, qdrant ]
|
||||
env_file: [ .env ]
|
||||
environment: &worker-env
|
||||
<<: *env
|
||||
@ -103,22 +103,21 @@ services:
|
||||
volumes:
|
||||
- ./db:/app/db:ro
|
||||
|
||||
rabbitmq:
|
||||
image: rabbitmq:3.13-management
|
||||
redis:
|
||||
image: redis:7.2-alpine
|
||||
restart: unless-stopped
|
||||
networks: [ kbnet ]
|
||||
environment:
|
||||
<<: *env
|
||||
RABBITMQ_DEFAULT_USER: "kb"
|
||||
RABBITMQ_DEFAULT_PASS: "${CELERY_BROKER_PASSWORD}"
|
||||
command: ["redis-server", "--save", "", "--appendonly", "no", "--requirepass", "${CELERY_BROKER_PASSWORD}"]
|
||||
volumes:
|
||||
- rabbitmq_data:/var/lib/rabbitmq:rw
|
||||
- redis_data:/data:rw
|
||||
healthcheck:
|
||||
test: [ "CMD", "rabbitmq-diagnostics", "ping" ]
|
||||
test: [ "CMD", "redis-cli", "--pass", "${CELERY_BROKER_PASSWORD}", "ping" ]
|
||||
interval: 15s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
security_opt: [ "no-new-privileges=true" ]
|
||||
cap_drop: [ ALL ]
|
||||
user: redis
|
||||
|
||||
qdrant:
|
||||
image: qdrant/qdrant:v1.14.0
|
||||
@ -148,7 +147,7 @@ services:
|
||||
SESSION_COOKIE_NAME: "${SESSION_COOKIE_NAME:-session_id}"
|
||||
restart: unless-stopped
|
||||
networks: [kbnet]
|
||||
depends_on: [postgres, rabbitmq, qdrant]
|
||||
depends_on: [postgres, redis, qdrant]
|
||||
environment:
|
||||
<<: *env
|
||||
POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
|
||||
@ -186,7 +185,7 @@ services:
|
||||
environment:
|
||||
<<: *worker-env
|
||||
DISCORD_BOT_TOKEN: ${DISCORD_BOT_TOKEN}
|
||||
DISCORD_NOTIFICATIONS_ENABLED: true
|
||||
DISCORD_NOTIFICATIONS_ENABLED: ${DISCORD_NOTIFICATIONS_ENABLED:-true}
|
||||
DISCORD_COLLECTOR_ENABLED: true
|
||||
volumes:
|
||||
- ./memory_files:/app/memory_files:rw
|
||||
|
||||
@ -9,4 +9,4 @@ anthropic==0.69.0
|
||||
openai==2.3.0
|
||||
# Pin the httpx version, as newer versions break the anthropic client
|
||||
httpx==0.27.0
|
||||
celery[sqs]==5.3.6
|
||||
celery[redis,sqs]==5.3.6
|
||||
|
||||
@ -24,6 +24,13 @@ from memory.common.llms.base import (
|
||||
)
|
||||
from memory.common.llms.anthropic_provider import AnthropicProvider
|
||||
from memory.common.llms.openai_provider import OpenAIProvider
|
||||
from memory.common.llms.usage_tracker import (
|
||||
InMemoryUsageTracker,
|
||||
RateLimitConfig,
|
||||
TokenAllowance,
|
||||
UsageBreakdown,
|
||||
UsageTracker,
|
||||
)
|
||||
from memory.common import tokens
|
||||
|
||||
__all__ = [
|
||||
@ -42,6 +49,11 @@ __all__ = [
|
||||
"StreamEvent",
|
||||
"LLMSettings",
|
||||
"create_provider",
|
||||
"InMemoryUsageTracker",
|
||||
"RateLimitConfig",
|
||||
"TokenAllowance",
|
||||
"UsageBreakdown",
|
||||
"UsageTracker",
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -14,6 +14,7 @@ from memory.common.llms.base import (
|
||||
MessageRole,
|
||||
StreamEvent,
|
||||
ToolDefinition,
|
||||
Usage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -255,6 +256,16 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
return StreamEvent(type="tool_use", data=tool_data), None
|
||||
|
||||
elif event_type == "message_delta":
|
||||
# Handle token usage information
|
||||
if usage := getattr(event, "usage", None):
|
||||
self.log_usage(
|
||||
Usage(
|
||||
input_tokens=usage.input_tokens,
|
||||
output_tokens=usage.output_tokens,
|
||||
total_tokens=usage.total_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
delta = getattr(event, "delta", None)
|
||||
if delta:
|
||||
stop_reason = getattr(delta, "stop_reason", None)
|
||||
@ -263,22 +274,6 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
type="error", data="Max tokens reached"
|
||||
), current_tool_use
|
||||
|
||||
# Handle token usage information
|
||||
usage = getattr(event, "usage", None)
|
||||
if usage:
|
||||
usage_data = {
|
||||
"input_tokens": getattr(usage, "input_tokens", 0),
|
||||
"output_tokens": getattr(usage, "output_tokens", 0),
|
||||
"cache_creation_input_tokens": getattr(
|
||||
usage, "cache_creation_input_tokens", None
|
||||
),
|
||||
"cache_read_input_tokens": getattr(
|
||||
usage, "cache_read_input_tokens", None
|
||||
),
|
||||
}
|
||||
# Could emit this as a separate event type if needed
|
||||
logger.debug(f"Token usage: {usage_data}")
|
||||
|
||||
return None, current_tool_use
|
||||
|
||||
elif event_type == "message_stop":
|
||||
|
||||
@ -25,6 +25,15 @@ class MessageRole(str, Enum):
|
||||
TOOL = "tool"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Usage:
|
||||
"""Usage data for an LLM call."""
|
||||
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextContent:
|
||||
"""Text content in a message."""
|
||||
@ -219,6 +228,11 @@ class BaseLLMProvider(ABC):
|
||||
self._client = self._initialize_client()
|
||||
return self._client
|
||||
|
||||
def log_usage(self, usage: Usage):
|
||||
"""Log usage data."""
|
||||
logger.debug(f"Token usage: {usage.to_dict()}")
|
||||
print(f"Token usage: {usage.to_dict()}")
|
||||
|
||||
def execute_tool(
|
||||
self,
|
||||
tool_call: ToolCall,
|
||||
|
||||
@ -16,6 +16,7 @@ from memory.common.llms.base import (
|
||||
ToolDefinition,
|
||||
ToolResultContent,
|
||||
ToolUseContent,
|
||||
Usage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -283,6 +284,17 @@ class OpenAIProvider(BaseLLMProvider):
|
||||
"""
|
||||
events: list[StreamEvent] = []
|
||||
|
||||
# Handle usage information (comes in final chunk with empty choices)
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage = chunk.usage
|
||||
self.log_usage(
|
||||
Usage(
|
||||
input_tokens=usage.prompt_tokens,
|
||||
output_tokens=usage.completion_tokens,
|
||||
total_tokens=usage.total_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
if not chunk.choices:
|
||||
return events, current_tool_call
|
||||
|
||||
@ -337,6 +349,14 @@ class OpenAIProvider(BaseLLMProvider):
|
||||
|
||||
try:
|
||||
response = self.client.chat.completions.create(**kwargs)
|
||||
usage = response.usage
|
||||
self.log_usage(
|
||||
Usage(
|
||||
input_tokens=usage.prompt_tokens,
|
||||
output_tokens=usage.completion_tokens,
|
||||
total_tokens=usage.total_tokens,
|
||||
)
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
@ -355,6 +375,9 @@ class OpenAIProvider(BaseLLMProvider):
|
||||
messages, system_prompt, tools, settings, stream=True
|
||||
)
|
||||
|
||||
if kwargs.get("stream"):
|
||||
kwargs["stream_options"] = {"include_usage": True}
|
||||
|
||||
try:
|
||||
stream = self.client.chat.completions.create(**kwargs)
|
||||
current_tool_call: dict[str, Any] | None = None
|
||||
|
||||
316
src/memory/common/llms/usage_tracker.py
Normal file
316
src/memory/common/llms/usage_tracker.py
Normal file
@ -0,0 +1,316 @@
|
||||
"""LLM usage tracking utilities."""
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from threading import Lock
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RateLimitConfig:
|
||||
"""Configuration for a single rolling usage window."""
|
||||
|
||||
window: timedelta
|
||||
max_input_tokens: int | None = None
|
||||
max_output_tokens: int | None = None
|
||||
max_total_tokens: int | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.window <= timedelta(0):
|
||||
raise ValueError("window must be positive")
|
||||
if (
|
||||
self.max_input_tokens is None
|
||||
and self.max_output_tokens is None
|
||||
and self.max_total_tokens is None
|
||||
):
|
||||
raise ValueError(
|
||||
"At least one of max_input_tokens, max_output_tokens or "
|
||||
"max_total_tokens must be provided"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UsageEvent:
|
||||
timestamp: datetime
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class UsageState:
|
||||
events: deque[UsageEvent] = field(default_factory=deque)
|
||||
window_input_tokens: int = 0
|
||||
window_output_tokens: int = 0
|
||||
lifetime_input_tokens: int = 0
|
||||
lifetime_output_tokens: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenAllowance:
|
||||
"""Represents the tokens that can be consumed right now."""
|
||||
|
||||
input_tokens: int | None
|
||||
output_tokens: int | None
|
||||
total_tokens: int | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UsageBreakdown:
|
||||
"""Detailed usage statistics for a provider/model pair."""
|
||||
|
||||
window_input_tokens: int
|
||||
window_output_tokens: int
|
||||
window_total_tokens: int
|
||||
lifetime_input_tokens: int
|
||||
lifetime_output_tokens: int
|
||||
|
||||
@property
|
||||
def window_total(self) -> int:
|
||||
return self.window_total_tokens
|
||||
|
||||
@property
|
||||
def lifetime_total_tokens(self) -> int:
|
||||
return self.lifetime_input_tokens + self.lifetime_output_tokens
|
||||
|
||||
|
||||
def split_model_key(model: str) -> tuple[str, str]:
|
||||
if "/" not in model:
|
||||
raise ValueError(
|
||||
"model must be formatted as '<provider>/<model_name>'"
|
||||
)
|
||||
|
||||
provider, model_name = model.split("/", maxsplit=1)
|
||||
if not provider or not model_name:
|
||||
raise ValueError(
|
||||
"model must include both provider and model name separated by '/'"
|
||||
)
|
||||
return provider, model_name
|
||||
|
||||
|
||||
class UsageTracker:
|
||||
"""Base class for usage trackers that operate on provider/model keys."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
configs: dict[str, RateLimitConfig],
|
||||
default_config: RateLimitConfig | None = None,
|
||||
) -> None:
|
||||
self._configs = configs
|
||||
self._default_config = default_config
|
||||
self._lock = Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Storage hooks
|
||||
# ------------------------------------------------------------------
|
||||
def get_state(self, key: str) -> UsageState:
|
||||
raise NotImplementedError
|
||||
|
||||
def iter_state_items(self) -> Iterable[tuple[str, UsageState]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def save_state(self, key: str, state: UsageState) -> None:
|
||||
"""Persist the given state back to the underlying store."""
|
||||
del key, state
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
def record_usage(
|
||||
self,
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
timestamp: datetime | None = None,
|
||||
) -> None:
|
||||
"""Record token usage for the given provider/model pair."""
|
||||
|
||||
if input_tokens < 0 or output_tokens < 0:
|
||||
raise ValueError("Token counts must be non-negative")
|
||||
|
||||
timestamp = timestamp or datetime.now(timezone.utc)
|
||||
split_model_key(model)
|
||||
key = model
|
||||
|
||||
with self._lock:
|
||||
config = self._get_config(key)
|
||||
state = self.get_state(key)
|
||||
|
||||
state.lifetime_input_tokens += input_tokens
|
||||
state.lifetime_output_tokens += output_tokens
|
||||
|
||||
if config is None:
|
||||
self.save_state(key, state)
|
||||
return
|
||||
|
||||
state.events.append(UsageEvent(timestamp, input_tokens, output_tokens))
|
||||
state.window_input_tokens += input_tokens
|
||||
state.window_output_tokens += output_tokens
|
||||
|
||||
self._prune_expired_events(state, config, now=timestamp)
|
||||
self.save_state(key, state)
|
||||
|
||||
def is_rate_limited(
|
||||
self,
|
||||
model: str,
|
||||
timestamp: datetime | None = None,
|
||||
) -> bool:
|
||||
"""Return True if the pair currently exceeds its limits."""
|
||||
|
||||
allowance = self.get_available_tokens(model, timestamp=timestamp)
|
||||
if allowance is None:
|
||||
return False
|
||||
|
||||
limits = [
|
||||
allowance.input_tokens,
|
||||
allowance.output_tokens,
|
||||
allowance.total_tokens,
|
||||
]
|
||||
return any(limit is not None and limit <= 0 for limit in limits)
|
||||
|
||||
def get_available_tokens(
|
||||
self,
|
||||
model: str,
|
||||
timestamp: datetime | None = None,
|
||||
) -> TokenAllowance | None:
|
||||
"""Return the current token allowance for the provider/model pair.
|
||||
|
||||
If there is no configuration for the pair (or a default configuration),
|
||||
``None`` is returned to indicate that no limits are enforced.
|
||||
"""
|
||||
|
||||
split_model_key(model)
|
||||
key = model
|
||||
with self._lock:
|
||||
config = self._get_config(key)
|
||||
if config is None:
|
||||
return None
|
||||
|
||||
state = self.get_state(key)
|
||||
self._prune_expired_events(state, config, now=timestamp)
|
||||
self.save_state(key, state)
|
||||
|
||||
if config.max_total_tokens is None:
|
||||
total_remaining = None
|
||||
else:
|
||||
total_remaining = config.max_total_tokens - (
|
||||
state.window_input_tokens + state.window_output_tokens
|
||||
)
|
||||
|
||||
if config.max_input_tokens is None:
|
||||
input_remaining = None
|
||||
else:
|
||||
input_remaining = config.max_input_tokens - state.window_input_tokens
|
||||
|
||||
if config.max_output_tokens is None:
|
||||
output_remaining = None
|
||||
else:
|
||||
output_remaining = (
|
||||
config.max_output_tokens - state.window_output_tokens
|
||||
)
|
||||
|
||||
return TokenAllowance(
|
||||
input_tokens=clamp_non_negative(input_remaining),
|
||||
output_tokens=clamp_non_negative(output_remaining),
|
||||
total_tokens=clamp_non_negative(total_remaining),
|
||||
)
|
||||
|
||||
def get_usage_breakdown(
|
||||
self, provider: str | None = None, model: str | None = None
|
||||
) -> dict[str, dict[str, UsageBreakdown]]:
|
||||
"""Return usage statistics grouped by provider and model."""
|
||||
|
||||
with self._lock:
|
||||
providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
|
||||
for key, state in self.iter_state_items():
|
||||
prov, model_name = split_model_key(key)
|
||||
if provider and provider != prov:
|
||||
continue
|
||||
if model and model != model_name:
|
||||
continue
|
||||
|
||||
window_total = state.window_input_tokens + state.window_output_tokens
|
||||
breakdown = UsageBreakdown(
|
||||
window_input_tokens=state.window_input_tokens,
|
||||
window_output_tokens=state.window_output_tokens,
|
||||
window_total_tokens=window_total,
|
||||
lifetime_input_tokens=state.lifetime_input_tokens,
|
||||
lifetime_output_tokens=state.lifetime_output_tokens,
|
||||
)
|
||||
providers[prov][model_name] = breakdown
|
||||
|
||||
return providers
|
||||
|
||||
def iter_provider_totals(self) -> Iterable[tuple[str, UsageBreakdown]]:
|
||||
"""Yield aggregated totals for each provider across its models."""
|
||||
|
||||
breakdowns = self.get_usage_breakdown()
|
||||
for provider, models in breakdowns.items():
|
||||
window_input = sum(b.window_input_tokens for b in models.values())
|
||||
window_output = sum(b.window_output_tokens for b in models.values())
|
||||
lifetime_input = sum(b.lifetime_input_tokens for b in models.values())
|
||||
lifetime_output = sum(b.lifetime_output_tokens for b in models.values())
|
||||
|
||||
yield (
|
||||
provider,
|
||||
UsageBreakdown(
|
||||
window_input_tokens=window_input,
|
||||
window_output_tokens=window_output,
|
||||
window_total_tokens=window_input + window_output,
|
||||
lifetime_input_tokens=lifetime_input,
|
||||
lifetime_output_tokens=lifetime_output,
|
||||
),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _get_config(self, key: str) -> RateLimitConfig | None:
|
||||
return self._configs.get(key) or self._default_config
|
||||
|
||||
def _prune_expired_events(
|
||||
self,
|
||||
state: UsageState,
|
||||
config: RateLimitConfig,
|
||||
now: datetime | None = None,
|
||||
) -> None:
|
||||
if not state.events:
|
||||
return
|
||||
|
||||
now = now or datetime.now(timezone.utc)
|
||||
cutoff = now - config.window
|
||||
|
||||
for event in tuple(state.events):
|
||||
if event.timestamp > cutoff:
|
||||
break
|
||||
state.events.popleft()
|
||||
state.window_input_tokens -= event.input_tokens
|
||||
state.window_output_tokens -= event.output_tokens
|
||||
|
||||
state.window_input_tokens = max(state.window_input_tokens, 0)
|
||||
state.window_output_tokens = max(state.window_output_tokens, 0)
|
||||
|
||||
|
||||
class InMemoryUsageTracker(UsageTracker):
|
||||
"""Tracks LLM usage for providers and models within a rolling window."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
configs: dict[str, RateLimitConfig],
|
||||
default_config: RateLimitConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__(configs=configs, default_config=default_config)
|
||||
self._states: dict[str, UsageState] = {}
|
||||
|
||||
def get_state(self, key: str) -> UsageState:
|
||||
return self._states.setdefault(key, UsageState())
|
||||
|
||||
def iter_state_items(self) -> Iterable[tuple[str, UsageState]]:
|
||||
return tuple(self._states.items())
|
||||
|
||||
|
||||
def clamp_non_negative(value: int | None) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
return 0 if value < 0 else value
|
||||
|
||||
@ -34,15 +34,26 @@ DB_URL = os.getenv("DATABASE_URL", make_db_url())
|
||||
|
||||
# Broker settings
|
||||
CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory")
|
||||
CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "amqp").lower() # amqp or sqs
|
||||
CELERY_BROKER_USER = os.getenv("CELERY_BROKER_USER", "kb")
|
||||
CELERY_BROKER_PASSWORD = os.getenv("CELERY_BROKER_PASSWORD", "kb")
|
||||
CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower()
|
||||
REDIS_HOST = os.getenv("REDIS_HOST", "redis")
|
||||
REDIS_PORT = os.getenv("REDIS_PORT", "6379")
|
||||
REDIS_DB = os.getenv("REDIS_DB", "0")
|
||||
CELERY_BROKER_USER = os.getenv(
|
||||
"CELERY_BROKER_USER", "kb" if CELERY_BROKER_TYPE == "amqp" else ""
|
||||
)
|
||||
CELERY_BROKER_PASSWORD = os.getenv(
|
||||
"CELERY_BROKER_PASSWORD", "" if CELERY_BROKER_TYPE == "redis" else "kb"
|
||||
)
|
||||
|
||||
|
||||
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "")
|
||||
if not CELERY_BROKER_HOST and CELERY_BROKER_TYPE == "amqp":
|
||||
RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
|
||||
RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672")
|
||||
CELERY_BROKER_HOST = f"{RABBITMQ_HOST}:{RABBITMQ_PORT}//"
|
||||
if not CELERY_BROKER_HOST:
|
||||
if CELERY_BROKER_TYPE == "amqp":
|
||||
RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
|
||||
RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672")
|
||||
CELERY_BROKER_HOST = f"{RABBITMQ_HOST}:{RABBITMQ_PORT}//"
|
||||
elif CELERY_BROKER_TYPE == "redis":
|
||||
CELERY_BROKER_HOST = f"{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
|
||||
|
||||
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
|
||||
|
||||
@ -161,7 +172,7 @@ STATIC_DIR = pathlib.Path(
|
||||
)
|
||||
|
||||
# Discord notification settings
|
||||
DISCORD_BOT_TOKEN = os.getenv("DISCORD_BOT_TOKEN", "")
|
||||
DISCORD_BOT_ID = int(os.getenv("DISCORD_BOT_ID", "0"))
|
||||
DISCORD_ERROR_CHANNEL = os.getenv("DISCORD_ERROR_CHANNEL", "memory-errors")
|
||||
DISCORD_ACTIVITY_CHANNEL = os.getenv("DISCORD_ACTIVITY_CHANNEL", "memory-activity")
|
||||
DISCORD_DISCOVERY_CHANNEL = os.getenv("DISCORD_DISCOVERY_CHANNEL", "memory-discoveries")
|
||||
@ -169,9 +180,7 @@ DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat")
|
||||
|
||||
|
||||
# Enable Discord notifications if bot token is set
|
||||
DISCORD_NOTIFICATIONS_ENABLED = bool(
|
||||
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
|
||||
)
|
||||
DISCORD_NOTIFICATIONS_ENABLED = boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
DISCORD_PROCESS_MESSAGES = boolean_env("DISCORD_PROCESS_MESSAGES", True)
|
||||
DISCORD_MODEL = os.getenv("DISCORD_MODEL", "anthropic/claude-sonnet-4-5")
|
||||
DISCORD_MAX_TOOL_CALLS = int(os.getenv("DISCORD_MAX_TOOL_CALLS", 10))
|
||||
|
||||
@ -18,6 +18,7 @@ from memory.common.db.models import (
|
||||
DiscordChannel,
|
||||
DiscordUser,
|
||||
)
|
||||
from memory.discord.commands import register_slash_commands
|
||||
from memory.workers.tasks.discord import add_discord_message, edit_discord_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -199,6 +200,11 @@ class MessageCollector(commands.Bot):
|
||||
help_command=None, # Disable default help
|
||||
)
|
||||
|
||||
async def setup_hook(self):
|
||||
"""Register slash commands when the bot is ready."""
|
||||
|
||||
register_slash_commands(self)
|
||||
|
||||
async def on_ready(self):
|
||||
"""Called when bot connects to Discord"""
|
||||
logger.info(f"Discord collector connected as {self.user}")
|
||||
@ -207,6 +213,11 @@ class MessageCollector(commands.Bot):
|
||||
# Sync server and channel metadata
|
||||
await self.sync_servers_and_channels()
|
||||
|
||||
try:
|
||||
await self.tree.sync()
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.error("Failed to sync slash commands: %s", exc)
|
||||
|
||||
logger.info("Discord message collector ready")
|
||||
|
||||
async def on_message(self, message: discord.Message):
|
||||
|
||||
393
src/memory/discord/commands.py
Normal file
393
src/memory/discord/commands.py
Normal file
@ -0,0 +1,393 @@
|
||||
"""Lightweight slash-command helpers for the Discord collector."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Literal
|
||||
|
||||
import discord
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
||||
|
||||
ScopeLiteral = Literal["server", "channel", "user"]
|
||||
|
||||
|
||||
class CommandError(Exception):
|
||||
"""Raised when a user-facing error occurs while handling a command."""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class CommandResponse:
|
||||
"""Value object returned by handlers."""
|
||||
|
||||
content: str
|
||||
ephemeral: bool = True
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class CommandContext:
|
||||
"""All information a handler needs to fulfil a command."""
|
||||
|
||||
session: Session
|
||||
interaction: discord.Interaction
|
||||
actor: DiscordUser
|
||||
scope: ScopeLiteral
|
||||
target: DiscordServer | DiscordChannel | DiscordUser
|
||||
display_name: str
|
||||
|
||||
|
||||
CommandHandler = Callable[..., CommandResponse]
|
||||
|
||||
|
||||
def register_slash_commands(bot: discord.Client) -> None:
|
||||
"""Register the collector slash commands on the provided bot."""
|
||||
|
||||
if getattr(bot, "_memory_commands_registered", False):
|
||||
return
|
||||
|
||||
setattr(bot, "_memory_commands_registered", True)
|
||||
|
||||
if not hasattr(bot, "tree"):
|
||||
raise RuntimeError("Bot instance does not support app commands")
|
||||
|
||||
tree = bot.tree
|
||||
|
||||
@tree.command(name="memory_prompt", description="Show the current system prompt")
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def prompt_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_prompt,
|
||||
target_user=user,
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name="memory_chattiness",
|
||||
description="Show or update the chattiness threshold for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
value="Optional new threshold value between 0 and 100",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def chattiness_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
value: int | None = None,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_chattiness,
|
||||
target_user=user,
|
||||
value=value,
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name="memory_ignore",
|
||||
description="Toggle whether the bot should ignore messages for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to modify",
|
||||
enabled="Optional flag. Leave empty to enable ignoring.",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def ignore_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
enabled: bool | None = None,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_ignore,
|
||||
target_user=user,
|
||||
ignore_enabled=enabled,
|
||||
)
|
||||
|
||||
@tree.command(name="memory_summary", description="Show the stored summary for the target")
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def summary_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_summary,
|
||||
target_user=user,
|
||||
)
|
||||
|
||||
|
||||
async def _run_interaction_command(
|
||||
interaction: discord.Interaction,
|
||||
*,
|
||||
scope: ScopeLiteral,
|
||||
handler: CommandHandler,
|
||||
target_user: discord.User | None = None,
|
||||
**handler_kwargs,
|
||||
) -> None:
|
||||
"""Shared coroutine used by the registered slash commands."""
|
||||
|
||||
try:
|
||||
with make_session() as session:
|
||||
response = run_command(
|
||||
session,
|
||||
interaction,
|
||||
scope,
|
||||
handler=handler,
|
||||
target_user=target_user,
|
||||
**handler_kwargs,
|
||||
)
|
||||
session.commit()
|
||||
except CommandError as exc: # pragma: no cover - passthrough
|
||||
await interaction.response.send_message(str(exc), ephemeral=True)
|
||||
return
|
||||
|
||||
await interaction.response.send_message(
|
||||
response.content,
|
||||
ephemeral=response.ephemeral,
|
||||
)
|
||||
|
||||
|
||||
def run_command(
|
||||
session: Session,
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
*,
|
||||
handler: CommandHandler,
|
||||
target_user: discord.User | None = None,
|
||||
**handler_kwargs,
|
||||
) -> CommandResponse:
|
||||
"""Create a :class:`CommandContext` and execute the handler."""
|
||||
|
||||
context = _build_context(session, interaction, scope, target_user)
|
||||
return handler(context, **handler_kwargs)
|
||||
|
||||
|
||||
def _build_context(
|
||||
session: Session,
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
target_user: discord.User | None,
|
||||
) -> CommandContext:
|
||||
actor = _ensure_user(session, interaction.user)
|
||||
|
||||
if scope == "server":
|
||||
if interaction.guild is None:
|
||||
raise CommandError("This command can only be used inside a server.")
|
||||
|
||||
target = _ensure_server(session, interaction.guild)
|
||||
display_name = f"server **{target.name}**"
|
||||
return CommandContext(
|
||||
session=session,
|
||||
interaction=interaction,
|
||||
actor=actor,
|
||||
scope=scope,
|
||||
target=target,
|
||||
display_name=display_name,
|
||||
)
|
||||
|
||||
if scope == "channel":
|
||||
channel_obj = interaction.channel
|
||||
if channel_obj is None or not hasattr(channel_obj, "id"):
|
||||
raise CommandError("Unable to determine channel for this interaction.")
|
||||
|
||||
target = _ensure_channel(session, channel_obj, interaction.guild_id)
|
||||
display_name = f"channel **#{target.name}**"
|
||||
return CommandContext(
|
||||
session=session,
|
||||
interaction=interaction,
|
||||
actor=actor,
|
||||
scope=scope,
|
||||
target=target,
|
||||
display_name=display_name,
|
||||
)
|
||||
|
||||
if scope == "user":
|
||||
discord_user = target_user or interaction.user
|
||||
if discord_user is None:
|
||||
raise CommandError("A target user is required for this command.")
|
||||
|
||||
target = _ensure_user(session, discord_user)
|
||||
display_name = target.display_name or target.username
|
||||
return CommandContext(
|
||||
session=session,
|
||||
interaction=interaction,
|
||||
actor=actor,
|
||||
scope=scope,
|
||||
target=target,
|
||||
display_name=f"user **{display_name}**",
|
||||
)
|
||||
|
||||
raise CommandError(f"Unsupported scope '{scope}'.")
|
||||
|
||||
|
||||
def _ensure_server(session: Session, guild: discord.Guild) -> DiscordServer:
|
||||
server = session.get(DiscordServer, guild.id)
|
||||
if server is None:
|
||||
server = DiscordServer(
|
||||
id=guild.id,
|
||||
name=guild.name or f"Server {guild.id}",
|
||||
description=getattr(guild, "description", None),
|
||||
member_count=getattr(guild, "member_count", None),
|
||||
)
|
||||
session.add(server)
|
||||
session.flush()
|
||||
else:
|
||||
if guild.name and server.name != guild.name:
|
||||
server.name = guild.name
|
||||
description = getattr(guild, "description", None)
|
||||
if description and server.description != description:
|
||||
server.description = description
|
||||
member_count = getattr(guild, "member_count", None)
|
||||
if member_count is not None:
|
||||
server.member_count = member_count
|
||||
|
||||
return server
|
||||
|
||||
|
||||
def _ensure_channel(
|
||||
session: Session,
|
||||
channel: discord.abc.Messageable,
|
||||
guild_id: int | None,
|
||||
) -> DiscordChannel:
|
||||
channel_id = getattr(channel, "id", None)
|
||||
if channel_id is None:
|
||||
raise CommandError("Channel is missing an identifier.")
|
||||
|
||||
channel_model = session.get(DiscordChannel, channel_id)
|
||||
if channel_model is None:
|
||||
channel_model = DiscordChannel(
|
||||
id=channel_id,
|
||||
server_id=guild_id,
|
||||
name=getattr(channel, "name", f"Channel {channel_id}"),
|
||||
channel_type=_resolve_channel_type(channel),
|
||||
)
|
||||
session.add(channel_model)
|
||||
session.flush()
|
||||
else:
|
||||
name = getattr(channel, "name", None)
|
||||
if name and channel_model.name != name:
|
||||
channel_model.name = name
|
||||
|
||||
return channel_model
|
||||
|
||||
|
||||
def _ensure_user(session: Session, discord_user: discord.abc.User) -> DiscordUser:
|
||||
user = session.get(DiscordUser, discord_user.id)
|
||||
display_name = getattr(discord_user, "display_name", discord_user.name)
|
||||
if user is None:
|
||||
user = DiscordUser(
|
||||
id=discord_user.id,
|
||||
username=discord_user.name,
|
||||
display_name=display_name,
|
||||
)
|
||||
session.add(user)
|
||||
session.flush()
|
||||
else:
|
||||
if user.username != discord_user.name:
|
||||
user.username = discord_user.name
|
||||
if display_name and user.display_name != display_name:
|
||||
user.display_name = display_name
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def _resolve_channel_type(channel: discord.abc.Messageable) -> str:
|
||||
if isinstance(channel, discord.DMChannel):
|
||||
return "dm"
|
||||
if isinstance(channel, discord.GroupChannel):
|
||||
return "group_dm"
|
||||
if isinstance(channel, discord.Thread):
|
||||
return "thread"
|
||||
if isinstance(channel, discord.VoiceChannel):
|
||||
return "voice"
|
||||
if isinstance(channel, discord.TextChannel):
|
||||
return "text"
|
||||
return getattr(getattr(channel, "type", None), "name", "unknown")
|
||||
|
||||
|
||||
def handle_prompt(context: CommandContext) -> CommandResponse:
|
||||
prompt = getattr(context.target, "system_prompt", None)
|
||||
|
||||
if prompt:
|
||||
return CommandResponse(
|
||||
content=f"Current prompt for {context.display_name}:\n\n{prompt}",
|
||||
)
|
||||
|
||||
return CommandResponse(
|
||||
content=f"No prompt configured for {context.display_name}.",
|
||||
)
|
||||
|
||||
|
||||
def handle_chattiness(
|
||||
context: CommandContext,
|
||||
*,
|
||||
value: int | None,
|
||||
) -> CommandResponse:
|
||||
model = context.target
|
||||
|
||||
if value is None:
|
||||
return CommandResponse(
|
||||
content=(
|
||||
f"Chattiness threshold for {context.display_name}: "
|
||||
f"{getattr(model, 'chattiness_threshold', 'not set')}"
|
||||
)
|
||||
)
|
||||
|
||||
if not 0 <= value <= 100:
|
||||
raise CommandError("Chattiness threshold must be between 0 and 100.")
|
||||
|
||||
setattr(model, "chattiness_threshold", value)
|
||||
|
||||
return CommandResponse(
|
||||
content=(
|
||||
f"Updated chattiness threshold for {context.display_name} "
|
||||
f"to {value}."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def handle_ignore(
|
||||
context: CommandContext,
|
||||
*,
|
||||
ignore_enabled: bool | None,
|
||||
) -> CommandResponse:
|
||||
model = context.target
|
||||
new_value = True if ignore_enabled is None else ignore_enabled
|
||||
setattr(model, "ignore_messages", new_value)
|
||||
|
||||
verb = "now ignoring" if new_value else "no longer ignoring"
|
||||
return CommandResponse(
|
||||
content=f"The bot is {verb} messages for {context.display_name}.",
|
||||
)
|
||||
|
||||
|
||||
def handle_summary(context: CommandContext) -> CommandResponse:
|
||||
summary = getattr(context.target, "summary", None)
|
||||
|
||||
if summary:
|
||||
return CommandResponse(
|
||||
content=f"Summary for {context.display_name}:\n\n{summary}",
|
||||
)
|
||||
|
||||
return CommandResponse(
|
||||
content=f"No summary stored for {context.display_name}.",
|
||||
)
|
||||
147
tests/memory/common/llms/test_usage_tracker.py
Normal file
147
tests/memory/common/llms/test_usage_tracker.py
Normal file
@ -0,0 +1,147 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from memory.common.llms.usage_tracker import (
|
||||
InMemoryUsageTracker,
|
||||
RateLimitConfig,
|
||||
UsageTracker,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tracker() -> InMemoryUsageTracker:
|
||||
config = RateLimitConfig(
|
||||
window=timedelta(minutes=1),
|
||||
max_input_tokens=1_000,
|
||||
max_output_tokens=2_000,
|
||||
max_total_tokens=2_500,
|
||||
)
|
||||
return InMemoryUsageTracker(
|
||||
{
|
||||
"anthropic/claude-3": config,
|
||||
"anthropic/haiku": config,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"window, kwargs",
|
||||
[
|
||||
(timedelta(minutes=1), {}),
|
||||
(timedelta(seconds=0), {"max_total_tokens": 1}),
|
||||
],
|
||||
)
|
||||
def test_rate_limit_config_validation(window: timedelta, kwargs: dict[str, int]) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
RateLimitConfig(window=window, **kwargs)
|
||||
|
||||
|
||||
def test_allows_usage_within_limits(tracker: InMemoryUsageTracker) -> None:
|
||||
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
|
||||
|
||||
allowance = tracker.get_available_tokens(
|
||||
"anthropic/claude-3", timestamp=now
|
||||
)
|
||||
assert allowance is not None
|
||||
assert allowance.input_tokens == 900
|
||||
assert allowance.output_tokens == 1_800
|
||||
assert allowance.total_tokens == 2_200
|
||||
|
||||
|
||||
def test_rate_limited_when_over_budget(tracker: InMemoryUsageTracker) -> None:
|
||||
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
tracker.record_usage("anthropic/claude-3", 800, 1_700, timestamp=now)
|
||||
|
||||
assert tracker.is_rate_limited("anthropic/claude-3", timestamp=now)
|
||||
|
||||
|
||||
def test_recovers_after_window(tracker: InMemoryUsageTracker) -> None:
|
||||
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
tracker.record_usage("anthropic/claude-3", 800, 1_700, timestamp=now)
|
||||
|
||||
later = now + timedelta(minutes=2)
|
||||
allowance = tracker.get_available_tokens(
|
||||
"anthropic/claude-3", timestamp=later
|
||||
)
|
||||
assert allowance is not None
|
||||
assert allowance.input_tokens == 1_000
|
||||
assert allowance.output_tokens == 2_000
|
||||
assert allowance.total_tokens == 2_500
|
||||
assert not tracker.is_rate_limited("anthropic/claude-3", timestamp=later)
|
||||
|
||||
|
||||
def test_usage_breakdown_and_provider_totals(tracker: InMemoryUsageTracker) -> None:
|
||||
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
|
||||
tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)
|
||||
|
||||
breakdown = tracker.get_usage_breakdown()
|
||||
assert "anthropic" in breakdown
|
||||
assert "claude-3" in breakdown["anthropic"]
|
||||
claude_usage = breakdown["anthropic"]["claude-3"]
|
||||
assert claude_usage.window_input_tokens == 100
|
||||
assert claude_usage.window_output_tokens == 200
|
||||
|
||||
provider_totals = dict(tracker.iter_provider_totals())
|
||||
anthropic_totals = provider_totals["anthropic"]
|
||||
assert anthropic_totals.window_input_tokens == 150
|
||||
assert anthropic_totals.window_output_tokens == 275
|
||||
|
||||
|
||||
def test_get_usage_breakdown_filters(tracker: InMemoryUsageTracker) -> None:
|
||||
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
tracker.record_usage("anthropic/claude-3", 10, 20, timestamp=now)
|
||||
tracker.record_usage("openai/gpt-4o", 5, 5, timestamp=now)
|
||||
|
||||
filtered = tracker.get_usage_breakdown(provider="anthropic")
|
||||
assert set(filtered.keys()) == {"anthropic"}
|
||||
assert set(filtered["anthropic"].keys()) == {"claude-3"}
|
||||
|
||||
filtered_model = tracker.get_usage_breakdown(model="gpt-4o")
|
||||
assert set(filtered_model.keys()) == {"openai"}
|
||||
assert set(filtered_model["openai"].keys()) == {"gpt-4o"}
|
||||
|
||||
|
||||
def test_missing_configuration_records_lifetime_only() -> None:
|
||||
tracker = InMemoryUsageTracker(configs={})
|
||||
tracker.record_usage("openai/gpt-4o", 10, 20)
|
||||
|
||||
assert tracker.get_available_tokens("openai/gpt-4o") is None
|
||||
|
||||
breakdown = tracker.get_usage_breakdown()
|
||||
usage = breakdown["openai"]["gpt-4o"]
|
||||
assert usage.window_input_tokens == 0
|
||||
assert usage.lifetime_input_tokens == 10
|
||||
|
||||
|
||||
def test_default_configuration_is_used() -> None:
|
||||
default = RateLimitConfig(window=timedelta(minutes=1), max_total_tokens=100)
|
||||
tracker = InMemoryUsageTracker(configs={}, default_config=default)
|
||||
|
||||
tracker.record_usage("anthropic/claude-3", 10, 10)
|
||||
allowance = tracker.get_available_tokens("anthropic/claude-3")
|
||||
assert allowance is not None
|
||||
assert allowance.total_tokens == 80
|
||||
|
||||
|
||||
def test_record_usage_rejects_negative_values(tracker: InMemoryUsageTracker) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
tracker.record_usage("anthropic/claude-3", -1, 0)
|
||||
|
||||
|
||||
def test_is_rate_limited_when_only_output_exceeds_limit() -> None:
|
||||
config = RateLimitConfig(window=timedelta(minutes=1), max_output_tokens=50)
|
||||
tracker = InMemoryUsageTracker({"openai/gpt-4o": config})
|
||||
|
||||
tracker.record_usage("openai/gpt-4o", 0, 50)
|
||||
assert tracker.is_rate_limited("openai/gpt-4o")
|
||||
|
||||
|
||||
def test_usage_tracker_base_not_instantiable() -> None:
|
||||
class DummyTracker(UsageTracker):
|
||||
pass
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
DummyTracker({}).record_usage("provider/model", 1, 1)
|
||||
@ -493,10 +493,12 @@ async def test_on_ready():
|
||||
collector.user.name = "TestBot"
|
||||
collector.guilds = [Mock(), Mock()]
|
||||
collector.sync_servers_and_channels = AsyncMock()
|
||||
collector.tree.sync = AsyncMock()
|
||||
|
||||
await collector.on_ready()
|
||||
|
||||
collector.sync_servers_and_channels.assert_called_once()
|
||||
collector.tree.sync.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
172
tests/memory/discord_tests/test_commands.py
Normal file
172
tests/memory/discord_tests/test_commands.py
Normal file
@ -0,0 +1,172 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import discord
|
||||
|
||||
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
||||
from memory.discord.commands import (
|
||||
CommandError,
|
||||
CommandResponse,
|
||||
run_command,
|
||||
handle_prompt,
|
||||
handle_chattiness,
|
||||
handle_ignore,
|
||||
handle_summary,
|
||||
)
|
||||
|
||||
|
||||
class DummyInteraction:
|
||||
"""Lightweight stand-in for :class:`discord.Interaction` used in tests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
guild: discord.Guild | None,
|
||||
channel: discord.abc.Messageable | None,
|
||||
user: discord.abc.User,
|
||||
) -> None:
|
||||
self.guild = guild
|
||||
self.channel = channel
|
||||
self.user = user
|
||||
self.guild_id = getattr(guild, "id", None)
|
||||
self.channel_id = getattr(channel, "id", None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def guild() -> discord.Guild:
|
||||
guild = MagicMock(spec=discord.Guild)
|
||||
guild.id = 123
|
||||
guild.name = "Test Guild"
|
||||
guild.description = "Guild description"
|
||||
guild.member_count = 42
|
||||
return guild
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def text_channel(guild: discord.Guild) -> discord.TextChannel:
|
||||
channel = MagicMock(spec=discord.TextChannel)
|
||||
channel.id = 456
|
||||
channel.name = "general"
|
||||
channel.guild = guild
|
||||
channel.type = discord.ChannelType.text
|
||||
return channel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def discord_user() -> discord.User:
|
||||
user = MagicMock(spec=discord.User)
|
||||
user.id = 789
|
||||
user.name = "command-user"
|
||||
user.display_name = "Commander"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def interaction(guild, text_channel, discord_user) -> DummyInteraction:
|
||||
return DummyInteraction(guild=guild, channel=text_channel, user=discord_user)
|
||||
|
||||
|
||||
def test_handle_command_prompt_server(db_session, guild, interaction):
|
||||
server = DiscordServer(id=guild.id, name="Test Guild", system_prompt="Be helpful")
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="server",
|
||||
handler=handle_prompt,
|
||||
)
|
||||
|
||||
assert isinstance(response, CommandResponse)
|
||||
assert "Be helpful" in response.content
|
||||
|
||||
|
||||
def test_handle_command_prompt_channel_creates_channel(db_session, interaction, text_channel):
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="channel",
|
||||
handler=handle_prompt,
|
||||
)
|
||||
|
||||
assert "No prompt" in response.content
|
||||
channel = db_session.get(DiscordChannel, text_channel.id)
|
||||
assert channel is not None
|
||||
assert channel.name == text_channel.name
|
||||
|
||||
|
||||
def test_handle_command_chattiness_show(db_session, interaction, guild):
|
||||
server = DiscordServer(id=guild.id, name="Guild", chattiness_threshold=73)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="server",
|
||||
handler=handle_chattiness,
|
||||
)
|
||||
|
||||
assert str(server.chattiness_threshold) in response.content
|
||||
|
||||
|
||||
def test_handle_command_chattiness_update(db_session, interaction):
|
||||
user_model = DiscordUser(id=interaction.user.id, username="command-user", chattiness_threshold=15)
|
||||
db_session.add(user_model)
|
||||
db_session.commit()
|
||||
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="user",
|
||||
handler=handle_chattiness,
|
||||
value=80,
|
||||
)
|
||||
|
||||
db_session.flush()
|
||||
|
||||
assert "Updated" in response.content
|
||||
assert user_model.chattiness_threshold == 80
|
||||
|
||||
|
||||
def test_handle_command_chattiness_invalid_value(db_session, interaction):
|
||||
with pytest.raises(CommandError):
|
||||
run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="user",
|
||||
handler=handle_chattiness,
|
||||
value=150,
|
||||
)
|
||||
|
||||
|
||||
def test_handle_command_ignore_toggle(db_session, interaction, guild):
|
||||
channel = DiscordChannel(id=interaction.channel.id, name="general", channel_type="text", server_id=guild.id)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="channel",
|
||||
handler=handle_ignore,
|
||||
ignore_enabled=True,
|
||||
)
|
||||
|
||||
db_session.flush()
|
||||
|
||||
assert "no longer" not in response.content
|
||||
assert channel.ignore_messages is True
|
||||
|
||||
|
||||
def test_handle_command_summary_missing(db_session, interaction):
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="user",
|
||||
handler=handle_summary,
|
||||
)
|
||||
|
||||
assert "No summary" in response.content
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user