mirror of
https://github.com/mruwnik/memory.git
synced 2025-12-16 09:01:17 +01:00
Compare commits
4 Commits
4fedd8fe04
...
8af07f0dac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8af07f0dac | ||
|
|
c296f3b533 | ||
|
|
07852f9ee7 | ||
|
|
bcb470db9b |
5
AGENTS.md
Normal file
5
AGENTS.md
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# Agent Guidance
|
||||||
|
|
||||||
|
- Assume Python 3.10+ features are available; avoid `from __future__ import annotations` unless necessary.
|
||||||
|
- Treat LLM model identifiers as `<provider>/<model_name>` strings throughout the codebase.
|
||||||
|
- Prefer straightforward control flow (`if`/`else`) instead of nested ternaries when clarity is improved.
|
||||||
@ -19,17 +19,17 @@ secrets:
|
|||||||
volumes:
|
volumes:
|
||||||
db_data: {} # Postgres
|
db_data: {} # Postgres
|
||||||
qdrant_data: {} # Qdrant
|
qdrant_data: {} # Qdrant
|
||||||
rabbitmq_data: {} # RabbitMQ
|
redis_data: {} # Redis
|
||||||
|
|
||||||
# ------------------------------ X-templates ----------------------------
|
# ------------------------------ X-templates ----------------------------
|
||||||
x-common-env: &env
|
x-common-env: &env
|
||||||
RABBITMQ_USER: kb
|
REDIS_HOST: redis
|
||||||
RABBITMQ_HOST: rabbitmq
|
REDIS_PORT: 6379
|
||||||
|
REDIS_DB: 0
|
||||||
CELERY_BROKER_PASSWORD: ${CELERY_BROKER_PASSWORD}
|
CELERY_BROKER_PASSWORD: ${CELERY_BROKER_PASSWORD}
|
||||||
QDRANT_HOST: qdrant
|
QDRANT_HOST: qdrant
|
||||||
DB_HOST: postgres
|
DB_HOST: postgres
|
||||||
DB_PORT: 5432
|
DB_PORT: 5432
|
||||||
RABBITMQ_PORT: 5672
|
|
||||||
FILE_STORAGE_DIR: /app/memory_files
|
FILE_STORAGE_DIR: /app/memory_files
|
||||||
TZ: "Etc/UTC"
|
TZ: "Etc/UTC"
|
||||||
|
|
||||||
@ -40,7 +40,7 @@ x-worker-base: &worker-base
|
|||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
networks: [ kbnet ]
|
networks: [ kbnet ]
|
||||||
security_opt: [ "no-new-privileges=true" ]
|
security_opt: [ "no-new-privileges=true" ]
|
||||||
depends_on: [ postgres, rabbitmq, qdrant ]
|
depends_on: [ postgres, redis, qdrant ]
|
||||||
env_file: [ .env ]
|
env_file: [ .env ]
|
||||||
environment: &worker-env
|
environment: &worker-env
|
||||||
<<: *env
|
<<: *env
|
||||||
@ -103,22 +103,21 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
- ./db:/app/db:ro
|
- ./db:/app/db:ro
|
||||||
|
|
||||||
rabbitmq:
|
redis:
|
||||||
image: rabbitmq:3.13-management
|
image: redis:7.2-alpine
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
networks: [ kbnet ]
|
networks: [ kbnet ]
|
||||||
environment:
|
command: ["redis-server", "--save", "", "--appendonly", "no", "--requirepass", "${CELERY_BROKER_PASSWORD}"]
|
||||||
<<: *env
|
|
||||||
RABBITMQ_DEFAULT_USER: "kb"
|
|
||||||
RABBITMQ_DEFAULT_PASS: "${CELERY_BROKER_PASSWORD}"
|
|
||||||
volumes:
|
volumes:
|
||||||
- rabbitmq_data:/var/lib/rabbitmq:rw
|
- redis_data:/data:rw
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: [ "CMD", "rabbitmq-diagnostics", "ping" ]
|
test: [ "CMD", "redis-cli", "--pass", "${CELERY_BROKER_PASSWORD}", "ping" ]
|
||||||
interval: 15s
|
interval: 15s
|
||||||
timeout: 5s
|
timeout: 5s
|
||||||
retries: 5
|
retries: 5
|
||||||
security_opt: [ "no-new-privileges=true" ]
|
security_opt: [ "no-new-privileges=true" ]
|
||||||
|
cap_drop: [ ALL ]
|
||||||
|
user: redis
|
||||||
|
|
||||||
qdrant:
|
qdrant:
|
||||||
image: qdrant/qdrant:v1.14.0
|
image: qdrant/qdrant:v1.14.0
|
||||||
@ -148,7 +147,7 @@ services:
|
|||||||
SESSION_COOKIE_NAME: "${SESSION_COOKIE_NAME:-session_id}"
|
SESSION_COOKIE_NAME: "${SESSION_COOKIE_NAME:-session_id}"
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
networks: [kbnet]
|
networks: [kbnet]
|
||||||
depends_on: [postgres, rabbitmq, qdrant]
|
depends_on: [postgres, redis, qdrant]
|
||||||
environment:
|
environment:
|
||||||
<<: *env
|
<<: *env
|
||||||
POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
|
POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
|
||||||
@ -186,7 +185,7 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
<<: *worker-env
|
<<: *worker-env
|
||||||
DISCORD_BOT_TOKEN: ${DISCORD_BOT_TOKEN}
|
DISCORD_BOT_TOKEN: ${DISCORD_BOT_TOKEN}
|
||||||
DISCORD_NOTIFICATIONS_ENABLED: true
|
DISCORD_NOTIFICATIONS_ENABLED: ${DISCORD_NOTIFICATIONS_ENABLED:-true}
|
||||||
DISCORD_COLLECTOR_ENABLED: true
|
DISCORD_COLLECTOR_ENABLED: true
|
||||||
volumes:
|
volumes:
|
||||||
- ./memory_files:/app/memory_files:rw
|
- ./memory_files:/app/memory_files:rw
|
||||||
|
|||||||
@ -9,4 +9,4 @@ anthropic==0.69.0
|
|||||||
openai==2.3.0
|
openai==2.3.0
|
||||||
# Pin the httpx version, as newer versions break the anthropic client
|
# Pin the httpx version, as newer versions break the anthropic client
|
||||||
httpx==0.27.0
|
httpx==0.27.0
|
||||||
celery[sqs]==5.3.6
|
celery[redis,sqs]==5.3.6
|
||||||
|
|||||||
@ -24,6 +24,13 @@ from memory.common.llms.base import (
|
|||||||
)
|
)
|
||||||
from memory.common.llms.anthropic_provider import AnthropicProvider
|
from memory.common.llms.anthropic_provider import AnthropicProvider
|
||||||
from memory.common.llms.openai_provider import OpenAIProvider
|
from memory.common.llms.openai_provider import OpenAIProvider
|
||||||
|
from memory.common.llms.usage_tracker import (
|
||||||
|
InMemoryUsageTracker,
|
||||||
|
RateLimitConfig,
|
||||||
|
TokenAllowance,
|
||||||
|
UsageBreakdown,
|
||||||
|
UsageTracker,
|
||||||
|
)
|
||||||
from memory.common import tokens
|
from memory.common import tokens
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -42,6 +49,11 @@ __all__ = [
|
|||||||
"StreamEvent",
|
"StreamEvent",
|
||||||
"LLMSettings",
|
"LLMSettings",
|
||||||
"create_provider",
|
"create_provider",
|
||||||
|
"InMemoryUsageTracker",
|
||||||
|
"RateLimitConfig",
|
||||||
|
"TokenAllowance",
|
||||||
|
"UsageBreakdown",
|
||||||
|
"UsageTracker",
|
||||||
]
|
]
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from memory.common.llms.base import (
|
|||||||
MessageRole,
|
MessageRole,
|
||||||
StreamEvent,
|
StreamEvent,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
|
Usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -255,6 +256,16 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
return StreamEvent(type="tool_use", data=tool_data), None
|
return StreamEvent(type="tool_use", data=tool_data), None
|
||||||
|
|
||||||
elif event_type == "message_delta":
|
elif event_type == "message_delta":
|
||||||
|
# Handle token usage information
|
||||||
|
if usage := getattr(event, "usage", None):
|
||||||
|
self.log_usage(
|
||||||
|
Usage(
|
||||||
|
input_tokens=usage.input_tokens,
|
||||||
|
output_tokens=usage.output_tokens,
|
||||||
|
total_tokens=usage.total_tokens,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
delta = getattr(event, "delta", None)
|
delta = getattr(event, "delta", None)
|
||||||
if delta:
|
if delta:
|
||||||
stop_reason = getattr(delta, "stop_reason", None)
|
stop_reason = getattr(delta, "stop_reason", None)
|
||||||
@ -263,22 +274,6 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
type="error", data="Max tokens reached"
|
type="error", data="Max tokens reached"
|
||||||
), current_tool_use
|
), current_tool_use
|
||||||
|
|
||||||
# Handle token usage information
|
|
||||||
usage = getattr(event, "usage", None)
|
|
||||||
if usage:
|
|
||||||
usage_data = {
|
|
||||||
"input_tokens": getattr(usage, "input_tokens", 0),
|
|
||||||
"output_tokens": getattr(usage, "output_tokens", 0),
|
|
||||||
"cache_creation_input_tokens": getattr(
|
|
||||||
usage, "cache_creation_input_tokens", None
|
|
||||||
),
|
|
||||||
"cache_read_input_tokens": getattr(
|
|
||||||
usage, "cache_read_input_tokens", None
|
|
||||||
),
|
|
||||||
}
|
|
||||||
# Could emit this as a separate event type if needed
|
|
||||||
logger.debug(f"Token usage: {usage_data}")
|
|
||||||
|
|
||||||
return None, current_tool_use
|
return None, current_tool_use
|
||||||
|
|
||||||
elif event_type == "message_stop":
|
elif event_type == "message_stop":
|
||||||
|
|||||||
@ -25,6 +25,15 @@ class MessageRole(str, Enum):
|
|||||||
TOOL = "tool"
|
TOOL = "tool"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Usage:
    """Usage data for an LLM call.

    All counts default to zero so a provider can fill in only the figures
    it reports.
    """

    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0

    def to_dict(self) -> dict[str, int]:
        """Return the token counts as a plain dictionary.

        ``BaseLLMProvider.log_usage`` formats usage through this method, so
        it has to exist for usage logging to work at all.
        """
        return {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "total_tokens": self.total_tokens,
        }
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TextContent:
|
class TextContent:
|
||||||
"""Text content in a message."""
|
"""Text content in a message."""
|
||||||
@ -219,6 +228,11 @@ class BaseLLMProvider(ABC):
|
|||||||
self._client = self._initialize_client()
|
self._client = self._initialize_client()
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
|
def log_usage(self, usage: Usage):
    """Log the token usage of a single LLM call at debug level.

    Args:
        usage: Token counts reported by the provider for the call.
    """
    # Imported locally so this method does not depend on the module's
    # top-level import list.
    from dataclasses import asdict

    # ``Usage`` is a plain dataclass, so convert it with ``asdict`` rather
    # than relying on a ``to_dict`` method. Output goes through the module
    # logger only; nothing is written to stdout.
    logger.debug("Token usage: %s", asdict(usage))
|
||||||
|
|
||||||
def execute_tool(
|
def execute_tool(
|
||||||
self,
|
self,
|
||||||
tool_call: ToolCall,
|
tool_call: ToolCall,
|
||||||
|
|||||||
@ -16,6 +16,7 @@ from memory.common.llms.base import (
|
|||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
ToolResultContent,
|
ToolResultContent,
|
||||||
ToolUseContent,
|
ToolUseContent,
|
||||||
|
Usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -283,6 +284,17 @@ class OpenAIProvider(BaseLLMProvider):
|
|||||||
"""
|
"""
|
||||||
events: list[StreamEvent] = []
|
events: list[StreamEvent] = []
|
||||||
|
|
||||||
|
# Handle usage information (comes in final chunk with empty choices)
|
||||||
|
if hasattr(chunk, "usage") and chunk.usage:
|
||||||
|
usage = chunk.usage
|
||||||
|
self.log_usage(
|
||||||
|
Usage(
|
||||||
|
input_tokens=usage.prompt_tokens,
|
||||||
|
output_tokens=usage.completion_tokens,
|
||||||
|
total_tokens=usage.total_tokens,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if not chunk.choices:
|
if not chunk.choices:
|
||||||
return events, current_tool_call
|
return events, current_tool_call
|
||||||
|
|
||||||
@ -337,6 +349,14 @@ class OpenAIProvider(BaseLLMProvider):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.client.chat.completions.create(**kwargs)
|
response = self.client.chat.completions.create(**kwargs)
|
||||||
|
usage = response.usage
|
||||||
|
self.log_usage(
|
||||||
|
Usage(
|
||||||
|
input_tokens=usage.prompt_tokens,
|
||||||
|
output_tokens=usage.completion_tokens,
|
||||||
|
total_tokens=usage.total_tokens,
|
||||||
|
)
|
||||||
|
)
|
||||||
return response.choices[0].message.content or ""
|
return response.choices[0].message.content or ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"OpenAI API error: {e}")
|
logger.error(f"OpenAI API error: {e}")
|
||||||
@ -355,6 +375,9 @@ class OpenAIProvider(BaseLLMProvider):
|
|||||||
messages, system_prompt, tools, settings, stream=True
|
messages, system_prompt, tools, settings, stream=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if kwargs.get("stream"):
|
||||||
|
kwargs["stream_options"] = {"include_usage": True}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stream = self.client.chat.completions.create(**kwargs)
|
stream = self.client.chat.completions.create(**kwargs)
|
||||||
current_tool_call: dict[str, Any] | None = None
|
current_tool_call: dict[str, Any] | None = None
|
||||||
|
|||||||
316
src/memory/common/llms/usage_tracker.py
Normal file
316
src/memory/common/llms/usage_tracker.py
Normal file
@ -0,0 +1,316 @@
|
|||||||
|
"""LLM usage tracking utilities."""
|
||||||
|
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RateLimitConfig:
|
||||||
|
"""Configuration for a single rolling usage window."""
|
||||||
|
|
||||||
|
window: timedelta
|
||||||
|
max_input_tokens: int | None = None
|
||||||
|
max_output_tokens: int | None = None
|
||||||
|
max_total_tokens: int | None = None
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
if self.window <= timedelta(0):
|
||||||
|
raise ValueError("window must be positive")
|
||||||
|
if (
|
||||||
|
self.max_input_tokens is None
|
||||||
|
and self.max_output_tokens is None
|
||||||
|
and self.max_total_tokens is None
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"At least one of max_input_tokens, max_output_tokens or "
|
||||||
|
"max_total_tokens must be provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class UsageEvent:
    """One recorded LLM call, kept in ``UsageState.events`` until it expires."""

    # Time of the call; ``record_usage`` falls back to the current UTC time
    # when the caller does not supply one.
    timestamp: datetime
    # Tokens consumed by the call; subtracted from the window totals again
    # when the event is pruned.
    input_tokens: int
    output_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class UsageState:
    """Mutable usage counters for a single provider/model key."""

    # Events appended by ``record_usage`` and removed from the left by
    # ``_prune_expired_events`` once they fall out of the window.
    events: deque[UsageEvent] = field(default_factory=deque)
    # Running sums of the events currently held in ``events``.
    window_input_tokens: int = 0
    window_output_tokens: int = 0
    # Totals that are only ever incremented, independent of any window.
    lifetime_input_tokens: int = 0
    lifetime_output_tokens: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TokenAllowance:
    """Represents the tokens that can be consumed right now."""

    # Each field is ``None`` when the matching ``max_*`` limit is not
    # configured; otherwise it is the remaining budget in the current
    # window, clamped so it never goes below zero.
    input_tokens: int | None
    output_tokens: int | None
    total_tokens: int | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class UsageBreakdown:
    """Usage figures for one provider/model pair, or a provider-level sum.

    The ``window_*`` fields describe the rolling window, the ``lifetime_*``
    fields the totals accumulated since tracking started.
    """

    window_input_tokens: int
    window_output_tokens: int
    window_total_tokens: int
    lifetime_input_tokens: int
    lifetime_output_tokens: int

    @property
    def lifetime_total_tokens(self) -> int:
        """Lifetime input and output tokens added together."""
        return sum((self.lifetime_input_tokens, self.lifetime_output_tokens))

    @property
    def window_total(self) -> int:
        """Shorthand for ``window_total_tokens``."""
        return self.window_total_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def split_model_key(model: str) -> tuple[str, str]:
    """Split a ``<provider>/<model_name>`` identifier at its first slash.

    Args:
        model: Identifier such as ``"anthropic/claude-sonnet-4-5"``. Any
            further slashes stay part of the model name.

    Returns:
        The ``(provider, model_name)`` pair.

    Raises:
        ValueError: if there is no ``/`` at all, or if either side of the
            first ``/`` is empty.
    """
    provider, separator, model_name = model.partition("/")
    if not separator:
        raise ValueError(
            "model must be formatted as '<provider>/<model_name>'"
        )

    if not (provider and model_name):
        raise ValueError(
            "model must include both provider and model name separated by '/'"
        )
    return provider, model_name
|
||||||
|
|
||||||
|
|
||||||
|
class UsageTracker:
    """Base class for usage trackers that operate on provider/model keys.

    Keys are full ``<provider>/<model_name>`` strings. Subclasses supply the
    storage by implementing ``get_state`` and ``iter_state_items`` (and
    ``save_state`` when state has to be written back). Every public method
    that touches state does so while holding ``self._lock``.
    """

    def __init__(
        self,
        configs: dict[str, RateLimitConfig],
        default_config: RateLimitConfig | None = None,
    ) -> None:
        """Store the per-key limits and the optional fallback limit.

        Args:
            configs: Rate limit configuration keyed by
                ``<provider>/<model_name>``.
            default_config: Configuration used for keys missing from
                ``configs``; ``None`` leaves such keys unlimited.
        """
        self._configs = configs
        self._default_config = default_config
        self._lock = Lock()

    # ------------------------------------------------------------------
    # Storage hooks
    # ------------------------------------------------------------------
    def get_state(self, key: str) -> UsageState:
        """Return the state object for ``key``; must be overridden."""
        raise NotImplementedError

    def iter_state_items(self) -> Iterable[tuple[str, UsageState]]:
        """Return all ``(key, state)`` pairs; must be overridden."""
        raise NotImplementedError

    def save_state(self, key: str, state: UsageState) -> None:
        """Persist the given state back to the underlying store."""
        # Default is a no-op; the ``del`` only marks the arguments as used.
        del key, state

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def record_usage(
        self,
        model: str,
        input_tokens: int,
        output_tokens: int,
        timestamp: datetime | None = None,
    ) -> None:
        """Record token usage for the given provider/model pair.

        Lifetime totals are always updated. Window totals and the event log
        are only maintained when a rate limit configuration applies to the
        key.

        Args:
            model: ``<provider>/<model_name>`` identifier.
            input_tokens: Input tokens consumed by the call.
            output_tokens: Output tokens produced by the call.
            timestamp: Time of the call; defaults to the current UTC time.
                It is also used as "now" when pruning expired events.

        Raises:
            ValueError: if a token count is negative or ``model`` is not a
                valid ``<provider>/<model_name>`` string.
        """

        if input_tokens < 0 or output_tokens < 0:
            raise ValueError("Token counts must be non-negative")

        timestamp = timestamp or datetime.now(timezone.utc)
        # Called for validation only; the full string is the storage key.
        split_model_key(model)
        key = model

        with self._lock:
            config = self._get_config(key)
            state = self.get_state(key)

            state.lifetime_input_tokens += input_tokens
            state.lifetime_output_tokens += output_tokens

            if config is None:
                self.save_state(key, state)
                return

            state.events.append(UsageEvent(timestamp, input_tokens, output_tokens))
            state.window_input_tokens += input_tokens
            state.window_output_tokens += output_tokens

            self._prune_expired_events(state, config, now=timestamp)
            self.save_state(key, state)

    def is_rate_limited(
        self,
        model: str,
        timestamp: datetime | None = None,
    ) -> bool:
        """Return True if the pair currently exceeds its limits.

        A pair counts as limited as soon as any configured allowance has
        dropped to zero. Pairs without a configuration are never limited.
        """

        allowance = self.get_available_tokens(model, timestamp=timestamp)
        if allowance is None:
            return False

        limits = [
            allowance.input_tokens,
            allowance.output_tokens,
            allowance.total_tokens,
        ]
        return any(limit is not None and limit <= 0 for limit in limits)

    def get_available_tokens(
        self,
        model: str,
        timestamp: datetime | None = None,
    ) -> TokenAllowance | None:
        """Return the current token allowance for the provider/model pair.

        If there is no configuration for the pair (or a default configuration),
        ``None`` is returned to indicate that no limits are enforced.

        Expired events are pruned (relative to ``timestamp``, or the current
        UTC time when omitted) before the allowance is computed, and each
        remaining figure is clamped so it is never negative.

        Raises:
            ValueError: if ``model`` is not a valid
                ``<provider>/<model_name>`` string.
        """

        split_model_key(model)
        key = model
        with self._lock:
            config = self._get_config(key)
            if config is None:
                return None

            state = self.get_state(key)
            self._prune_expired_events(state, config, now=timestamp)
            self.save_state(key, state)

            if config.max_total_tokens is None:
                total_remaining = None
            else:
                total_remaining = config.max_total_tokens - (
                    state.window_input_tokens + state.window_output_tokens
                )

            if config.max_input_tokens is None:
                input_remaining = None
            else:
                input_remaining = config.max_input_tokens - state.window_input_tokens

            if config.max_output_tokens is None:
                output_remaining = None
            else:
                output_remaining = (
                    config.max_output_tokens - state.window_output_tokens
                )

            return TokenAllowance(
                input_tokens=clamp_non_negative(input_remaining),
                output_tokens=clamp_non_negative(output_remaining),
                total_tokens=clamp_non_negative(total_remaining),
            )

    def get_usage_breakdown(
        self, provider: str | None = None, model: str | None = None
    ) -> dict[str, dict[str, UsageBreakdown]]:
        """Return usage statistics grouped by provider and model.

        Args:
            provider: When given, only this provider is included.
            model: When given, only models with this name (the part after
                the first ``/``) are included.

        Returns:
            A mapping ``provider -> model_name -> UsageBreakdown``.

        Note:
            This method does not prune expired events, so the window figures
            are those left by the most recent ``record_usage`` or
            ``get_available_tokens`` call for each key.
        """

        with self._lock:
            providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
            for key, state in self.iter_state_items():
                prov, model_name = split_model_key(key)
                if provider and provider != prov:
                    continue
                if model and model != model_name:
                    continue

                window_total = state.window_input_tokens + state.window_output_tokens
                breakdown = UsageBreakdown(
                    window_input_tokens=state.window_input_tokens,
                    window_output_tokens=state.window_output_tokens,
                    window_total_tokens=window_total,
                    lifetime_input_tokens=state.lifetime_input_tokens,
                    lifetime_output_tokens=state.lifetime_output_tokens,
                )
                providers[prov][model_name] = breakdown

            return providers

    def iter_provider_totals(self) -> Iterable[tuple[str, UsageBreakdown]]:
        """Yield aggregated totals for each provider across its models."""

        breakdowns = self.get_usage_breakdown()
        for provider, models in breakdowns.items():
            window_input = sum(b.window_input_tokens for b in models.values())
            window_output = sum(b.window_output_tokens for b in models.values())
            lifetime_input = sum(b.lifetime_input_tokens for b in models.values())
            lifetime_output = sum(b.lifetime_output_tokens for b in models.values())

            yield (
                provider,
                UsageBreakdown(
                    window_input_tokens=window_input,
                    window_output_tokens=window_output,
                    window_total_tokens=window_input + window_output,
                    lifetime_input_tokens=lifetime_input,
                    lifetime_output_tokens=lifetime_output,
                ),
            )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _get_config(self, key: str) -> RateLimitConfig | None:
        """Return the config for ``key``, else the default, else ``None``."""
        return self._configs.get(key) or self._default_config

    def _prune_expired_events(
        self,
        state: UsageState,
        config: RateLimitConfig,
        now: datetime | None = None,
    ) -> None:
        """Drop events at or before the window cutoff and adjust the totals.

        Events are inspected from the oldest end of the deque and pruning
        stops at the first event newer than ``now - config.window``.

        Args:
            state: State whose events and window totals are updated in place.
            config: Supplies the window length.
            now: Reference time; defaults to the current UTC time.
        """
        if not state.events:
            return

        now = now or datetime.now(timezone.utc)
        cutoff = now - config.window

        # Pop straight from the head instead of snapshotting the whole deque:
        # same result, but the cost is proportional to the number of expired
        # events rather than to the number of stored events.
        events = state.events
        while events and events[0].timestamp <= cutoff:
            expired = events.popleft()
            state.window_input_tokens -= expired.input_tokens
            state.window_output_tokens -= expired.output_tokens

        # Guard against the running sums drifting below zero.
        state.window_input_tokens = max(state.window_input_tokens, 0)
        state.window_output_tokens = max(state.window_output_tokens, 0)
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryUsageTracker(UsageTracker):
    """Usage tracker that keeps every ``UsageState`` in a process-local dict."""

    def __init__(
        self,
        configs: dict[str, RateLimitConfig],
        default_config: RateLimitConfig | None = None,
    ) -> None:
        """Initialise the base tracker and start with no recorded state."""
        super().__init__(configs=configs, default_config=default_config)
        self._states: dict[str, UsageState] = {}

    def get_state(self, key: str) -> UsageState:
        """Return the state stored for ``key``, creating an empty one if needed."""
        try:
            return self._states[key]
        except KeyError:
            fresh = UsageState()
            self._states[key] = fresh
            return fresh

    def iter_state_items(self) -> Iterable[tuple[str, UsageState]]:
        """Return a snapshot tuple of all ``(key, state)`` pairs."""
        return (*self._states.items(),)
|
||||||
|
|
||||||
|
|
||||||
|
def clamp_non_negative(value: int | None) -> int | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return 0 if value < 0 else value
|
||||||
|
|
||||||
@ -34,15 +34,26 @@ DB_URL = os.getenv("DATABASE_URL", make_db_url())
|
|||||||
|
|
||||||
# Broker settings
|
# Broker settings
|
||||||
CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory")
|
CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory")
|
||||||
CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "amqp").lower() # amqp or sqs
|
CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower()
|
||||||
CELERY_BROKER_USER = os.getenv("CELERY_BROKER_USER", "kb")
|
REDIS_HOST = os.getenv("REDIS_HOST", "redis")
|
||||||
CELERY_BROKER_PASSWORD = os.getenv("CELERY_BROKER_PASSWORD", "kb")
|
REDIS_PORT = os.getenv("REDIS_PORT", "6379")
|
||||||
|
REDIS_DB = os.getenv("REDIS_DB", "0")
|
||||||
|
CELERY_BROKER_USER = os.getenv(
|
||||||
|
"CELERY_BROKER_USER", "kb" if CELERY_BROKER_TYPE == "amqp" else ""
|
||||||
|
)
|
||||||
|
CELERY_BROKER_PASSWORD = os.getenv(
|
||||||
|
"CELERY_BROKER_PASSWORD", "" if CELERY_BROKER_TYPE == "redis" else "kb"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "")
|
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "")
|
||||||
if not CELERY_BROKER_HOST and CELERY_BROKER_TYPE == "amqp":
|
if not CELERY_BROKER_HOST:
|
||||||
|
if CELERY_BROKER_TYPE == "amqp":
|
||||||
RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
|
RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
|
||||||
RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672")
|
RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672")
|
||||||
CELERY_BROKER_HOST = f"{RABBITMQ_HOST}:{RABBITMQ_PORT}//"
|
CELERY_BROKER_HOST = f"{RABBITMQ_HOST}:{RABBITMQ_PORT}//"
|
||||||
|
elif CELERY_BROKER_TYPE == "redis":
|
||||||
|
CELERY_BROKER_HOST = f"{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
|
||||||
|
|
||||||
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
|
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
|
||||||
|
|
||||||
@ -161,7 +172,7 @@ STATIC_DIR = pathlib.Path(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Discord notification settings
|
# Discord notification settings
|
||||||
DISCORD_BOT_TOKEN = os.getenv("DISCORD_BOT_TOKEN", "")
|
DISCORD_BOT_ID = int(os.getenv("DISCORD_BOT_ID", "0"))
|
||||||
DISCORD_ERROR_CHANNEL = os.getenv("DISCORD_ERROR_CHANNEL", "memory-errors")
|
DISCORD_ERROR_CHANNEL = os.getenv("DISCORD_ERROR_CHANNEL", "memory-errors")
|
||||||
DISCORD_ACTIVITY_CHANNEL = os.getenv("DISCORD_ACTIVITY_CHANNEL", "memory-activity")
|
DISCORD_ACTIVITY_CHANNEL = os.getenv("DISCORD_ACTIVITY_CHANNEL", "memory-activity")
|
||||||
DISCORD_DISCOVERY_CHANNEL = os.getenv("DISCORD_DISCOVERY_CHANNEL", "memory-discoveries")
|
DISCORD_DISCOVERY_CHANNEL = os.getenv("DISCORD_DISCOVERY_CHANNEL", "memory-discoveries")
|
||||||
@ -169,9 +180,7 @@ DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat")
|
|||||||
|
|
||||||
|
|
||||||
# Enable Discord notifications if bot token is set
|
# Enable Discord notifications unless disabled via DISCORD_NOTIFICATIONS_ENABLED
|
||||||
DISCORD_NOTIFICATIONS_ENABLED = bool(
|
DISCORD_NOTIFICATIONS_ENABLED = boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||||
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
|
|
||||||
)
|
|
||||||
DISCORD_PROCESS_MESSAGES = boolean_env("DISCORD_PROCESS_MESSAGES", True)
|
DISCORD_PROCESS_MESSAGES = boolean_env("DISCORD_PROCESS_MESSAGES", True)
|
||||||
DISCORD_MODEL = os.getenv("DISCORD_MODEL", "anthropic/claude-sonnet-4-5")
|
DISCORD_MODEL = os.getenv("DISCORD_MODEL", "anthropic/claude-sonnet-4-5")
|
||||||
DISCORD_MAX_TOOL_CALLS = int(os.getenv("DISCORD_MAX_TOOL_CALLS", 10))
|
DISCORD_MAX_TOOL_CALLS = int(os.getenv("DISCORD_MAX_TOOL_CALLS", 10))
|
||||||
|
|||||||
@ -18,6 +18,7 @@ from memory.common.db.models import (
|
|||||||
DiscordChannel,
|
DiscordChannel,
|
||||||
DiscordUser,
|
DiscordUser,
|
||||||
)
|
)
|
||||||
|
from memory.discord.commands import register_slash_commands
|
||||||
from memory.workers.tasks.discord import add_discord_message, edit_discord_message
|
from memory.workers.tasks.discord import add_discord_message, edit_discord_message
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -199,6 +200,11 @@ class MessageCollector(commands.Bot):
|
|||||||
help_command=None, # Disable default help
|
help_command=None, # Disable default help
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def setup_hook(self):
    """Register the memory slash commands on this bot.

    Registration only attaches the commands to the bot's command tree;
    ``on_ready`` later calls ``self.tree.sync()``.

    NOTE(review): discord.py appears to run ``setup_hook`` once after login
    and before the gateway connection, i.e. earlier than ``on_ready`` —
    confirm against the library docs.
    """

    register_slash_commands(self)
|
||||||
|
|
||||||
async def on_ready(self):
|
async def on_ready(self):
|
||||||
"""Called when bot connects to Discord"""
|
"""Called when bot connects to Discord"""
|
||||||
logger.info(f"Discord collector connected as {self.user}")
|
logger.info(f"Discord collector connected as {self.user}")
|
||||||
@ -207,6 +213,11 @@ class MessageCollector(commands.Bot):
|
|||||||
# Sync server and channel metadata
|
# Sync server and channel metadata
|
||||||
await self.sync_servers_and_channels()
|
await self.sync_servers_and_channels()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.tree.sync()
|
||||||
|
except Exception as exc: # pragma: no cover - defensive
|
||||||
|
logger.error("Failed to sync slash commands: %s", exc)
|
||||||
|
|
||||||
logger.info("Discord message collector ready")
|
logger.info("Discord message collector ready")
|
||||||
|
|
||||||
async def on_message(self, message: discord.Message):
|
async def on_message(self, message: discord.Message):
|
||||||
|
|||||||
393
src/memory/discord/commands.py
Normal file
393
src/memory/discord/commands.py
Normal file
@ -0,0 +1,393 @@
|
|||||||
|
"""Lightweight slash-command helpers for the Discord collector."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, Literal
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
||||||
|
|
||||||
|
ScopeLiteral = Literal["server", "channel", "user"]
|
||||||
|
|
||||||
|
|
||||||
|
class CommandError(Exception):
|
||||||
|
"""Raised when a user-facing error occurs while handling a command."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class CommandResponse:
|
||||||
|
"""Value object returned by handlers."""
|
||||||
|
|
||||||
|
content: str
|
||||||
|
ephemeral: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class CommandContext:
|
||||||
|
"""All information a handler needs to fulfil a command."""
|
||||||
|
|
||||||
|
session: Session
|
||||||
|
interaction: discord.Interaction
|
||||||
|
actor: DiscordUser
|
||||||
|
scope: ScopeLiteral
|
||||||
|
target: DiscordServer | DiscordChannel | DiscordUser
|
||||||
|
display_name: str
|
||||||
|
|
||||||
|
|
||||||
|
CommandHandler = Callable[..., CommandResponse]
|
||||||
|
|
||||||
|
|
||||||
|
def register_slash_commands(bot: discord.Client) -> None:
    """Register the collector slash commands on the provided bot.

    The call is idempotent: a ``_memory_commands_registered`` flag is stored
    on the bot, and later calls return without doing anything.

    Args:
        bot: Client whose ``tree`` attribute receives the commands.

    Raises:
        RuntimeError: if ``bot`` has no ``tree`` attribute.
    """

    if getattr(bot, "_memory_commands_registered", False):
        return

    # Validate before marking the bot as registered; otherwise a failed call
    # would leave the flag set and later calls would silently do nothing.
    if not hasattr(bot, "tree"):
        raise RuntimeError("Bot instance does not support app commands")

    setattr(bot, "_memory_commands_registered", True)

    tree = bot.tree

    @tree.command(name="memory_prompt", description="Show the current system prompt")
    @discord.app_commands.describe(
        scope="Which configuration to inspect",
        user="Target user when the scope is 'user'",
    )
    async def prompt_command(
        interaction: discord.Interaction,
        scope: ScopeLiteral,
        user: discord.User | None = None,
    ) -> None:
        await _run_interaction_command(
            interaction,
            scope=scope,
            handler=handle_prompt,
            target_user=user,
        )

    @tree.command(
        name="memory_chattiness",
        description="Show or update the chattiness threshold for the target",
    )
    @discord.app_commands.describe(
        scope="Which configuration to inspect",
        value="Optional new threshold value between 0 and 100",
        user="Target user when the scope is 'user'",
    )
    async def chattiness_command(
        interaction: discord.Interaction,
        scope: ScopeLiteral,
        value: int | None = None,
        user: discord.User | None = None,
    ) -> None:
        await _run_interaction_command(
            interaction,
            scope=scope,
            handler=handle_chattiness,
            target_user=user,
            value=value,
        )

    @tree.command(
        name="memory_ignore",
        description="Toggle whether the bot should ignore messages for the target",
    )
    @discord.app_commands.describe(
        scope="Which configuration to modify",
        enabled="Optional flag. Leave empty to enable ignoring.",
        user="Target user when the scope is 'user'",
    )
    async def ignore_command(
        interaction: discord.Interaction,
        scope: ScopeLiteral,
        enabled: bool | None = None,
        user: discord.User | None = None,
    ) -> None:
        await _run_interaction_command(
            interaction,
            scope=scope,
            handler=handle_ignore,
            target_user=user,
            ignore_enabled=enabled,
        )

    @tree.command(
        name="memory_summary",
        description="Show the stored summary for the target",
    )
    @discord.app_commands.describe(
        scope="Which configuration to inspect",
        user="Target user when the scope is 'user'",
    )
    async def summary_command(
        interaction: discord.Interaction,
        scope: ScopeLiteral,
        user: discord.User | None = None,
    ) -> None:
        await _run_interaction_command(
            interaction,
            scope=scope,
            handler=handle_summary,
            target_user=user,
        )
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_interaction_command(
    interaction: discord.Interaction,
    *,
    scope: ScopeLiteral,
    handler: CommandHandler,
    target_user: discord.User | None = None,
    **handler_kwargs,
) -> None:
    """Shared coroutine used by the registered slash commands.

    Opens a database session, runs ``handler`` through ``run_command``,
    commits, and replies to the interaction with the handler's response.

    Args:
        interaction: The interaction to answer.
        scope: Target kind forwarded to ``run_command``.
        handler: Command handler to execute.
        target_user: Optional explicit user target.
        **handler_kwargs: Extra keyword arguments forwarded to ``handler``.

    A ``CommandError`` raised while running the command is reported back to
    the user as an ephemeral message instead of propagating.
    """
    # Discord rejects message content longer than 2000 characters, and stored
    # prompts or summaries can easily exceed that.
    max_content_length = 2000

    try:
        with make_session() as session:
            response = run_command(
                session,
                interaction,
                scope,
                handler=handler,
                target_user=target_user,
                **handler_kwargs,
            )
            session.commit()
    except CommandError as exc:  # pragma: no cover - passthrough
        await interaction.response.send_message(str(exc), ephemeral=True)
        return

    content = response.content
    if len(content) > max_content_length:
        # Keep the reply deliverable; mark the cut with an ellipsis.
        content = content[: max_content_length - 1] + "…"

    await interaction.response.send_message(
        content,
        ephemeral=response.ephemeral,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def run_command(
    session: Session,
    interaction: discord.Interaction,
    scope: ScopeLiteral,
    *,
    handler: CommandHandler,
    target_user: discord.User | None = None,
    **handler_kwargs,
) -> CommandResponse:
    """Build the :class:`CommandContext` for this call and run ``handler`` with it.

    Any extra keyword arguments are passed straight through to the handler.
    """
    return handler(
        _build_context(session, interaction, scope, target_user),
        **handler_kwargs,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _build_context(
    session: Session,
    interaction: discord.Interaction,
    scope: ScopeLiteral,
    target_user: discord.User | None,
) -> CommandContext:
    """Resolve the actor and the scoped target, then wrap them in a context.

    Raises:
        CommandError: If the scope needs a guild or channel that the
            interaction does not provide, if no user can be determined, or if
            ``scope`` is not one of ``"server"``, ``"channel"``, ``"user"``.
    """
    # The invoking user is always recorded, whatever the scope.
    actor = _ensure_user(session, interaction.user)

    target: DiscordServer | DiscordChannel | DiscordUser
    if scope == "server":
        if interaction.guild is None:
            raise CommandError("This command can only be used inside a server.")
        target = _ensure_server(session, interaction.guild)
        label = f"server **{target.name}**"
    elif scope == "channel":
        channel_obj = interaction.channel
        if channel_obj is None or not hasattr(channel_obj, "id"):
            raise CommandError("Unable to determine channel for this interaction.")
        target = _ensure_channel(session, channel_obj, interaction.guild_id)
        label = f"channel **#{target.name}**"
    elif scope == "user":
        # Fall back to the invoking user when no explicit target was given.
        discord_user = target_user or interaction.user
        if discord_user is None:
            raise CommandError("A target user is required for this command.")
        target = _ensure_user(session, discord_user)
        shown_name = target.display_name or target.username
        label = f"user **{shown_name}**"
    else:
        raise CommandError(f"Unsupported scope '{scope}'.")

    return CommandContext(
        session=session,
        interaction=interaction,
        actor=actor,
        scope=scope,
        target=target,
        display_name=label,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_server(session: Session, guild: discord.Guild) -> DiscordServer:
    """Return the stored server for ``guild``, creating or refreshing it.

    A missing record is created and flushed. An existing record has its name,
    description, and member count refreshed from the guild when the guild
    supplies a usable value.
    """
    record = session.get(DiscordServer, guild.id)
    guild_description = getattr(guild, "description", None)
    guild_member_count = getattr(guild, "member_count", None)

    if record is None:
        record = DiscordServer(
            id=guild.id,
            name=guild.name or f"Server {guild.id}",
            description=guild_description,
            member_count=guild_member_count,
        )
        session.add(record)
        session.flush()
        return record

    # Existing record: only overwrite fields the guild has a value for.
    if guild.name and record.name != guild.name:
        record.name = guild.name
    if guild_description and record.description != guild_description:
        record.description = guild_description
    if guild_member_count is not None:
        record.member_count = guild_member_count

    return record
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_channel(
    session: Session,
    channel: discord.abc.Messageable,
    guild_id: int | None,
) -> DiscordChannel:
    """Return the stored channel for ``channel``, creating or refreshing it.

    Args:
        session: Database session used for the lookup and insert.
        channel: Discord channel object; must expose an ``id``.
        guild_id: Identifier stored as the new record's ``server_id``.

    Raises:
        CommandError: If ``channel`` has no ``id``.
    """
    channel_id = getattr(channel, "id", None)
    if channel_id is None:
        raise CommandError("Channel is missing an identifier.")

    # The attribute can be absent or present-but-None (e.g. an unnamed group
    # DM), so a getattr default alone is not enough to guarantee a name.
    name = getattr(channel, "name", None)

    channel_model = session.get(DiscordChannel, channel_id)
    if channel_model is None:
        channel_model = DiscordChannel(
            id=channel_id,
            server_id=guild_id,
            # Same fallback style as the server record's name.
            name=name or f"Channel {channel_id}",
            channel_type=_resolve_channel_type(channel),
        )
        session.add(channel_model)
        session.flush()
    elif name and channel_model.name != name:
        channel_model.name = name

    return channel_model
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_user(session: Session, discord_user: discord.abc.User) -> DiscordUser:
    """Return the stored user for ``discord_user``, creating or refreshing it.

    A missing record is created and flushed. An existing record has its
    username and display name brought in line with the Discord object.
    """
    record = session.get(DiscordUser, discord_user.id)
    # Objects without a display name fall back to the plain username.
    shown_name = getattr(discord_user, "display_name", discord_user.name)

    if record is None:
        record = DiscordUser(
            id=discord_user.id,
            username=discord_user.name,
            display_name=shown_name,
        )
        session.add(record)
        session.flush()
        return record

    if record.username != discord_user.name:
        record.username = discord_user.name
    if shown_name and record.display_name != shown_name:
        record.display_name = shown_name

    return record
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_channel_type(channel: discord.abc.Messageable) -> str:
    """Map a Discord channel object to the short type label that gets stored.

    Known channel classes are checked in a fixed order. Anything else falls
    back to the name of the channel's ``type`` attribute, or ``"unknown"``.
    """
    labelled_classes = (
        (discord.DMChannel, "dm"),
        (discord.GroupChannel, "group_dm"),
        (discord.Thread, "thread"),
        (discord.VoiceChannel, "voice"),
        (discord.TextChannel, "text"),
    )
    for channel_cls, label in labelled_classes:
        if isinstance(channel, channel_cls):
            return label

    channel_kind = getattr(channel, "type", None)
    return getattr(channel_kind, "name", "unknown")
|
||||||
|
|
||||||
|
|
||||||
|
def handle_prompt(context: CommandContext) -> CommandResponse:
    """Report the target's ``system_prompt``, or say that none is configured."""
    prompt = getattr(context.target, "system_prompt", None)

    if not prompt:
        return CommandResponse(
            content=f"No prompt configured for {context.display_name}.",
        )

    return CommandResponse(
        content=f"Current prompt for {context.display_name}:\n\n{prompt}",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def handle_chattiness(
    context: CommandContext,
    *,
    value: int | None = None,
) -> CommandResponse:
    """Show or update the target's ``chattiness_threshold``.

    Args:
        context: Command context holding the target record.
        value: New threshold between 0 and 100 inclusive. When omitted or
            ``None`` the current threshold is reported instead.

    Raises:
        CommandError: If ``value`` lies outside the 0-100 range.
    """
    model = context.target

    if value is None:
        current = getattr(model, "chattiness_threshold", None)
        # A stored None means "never configured"; do not print "None".
        shown = "not set" if current is None else current
        return CommandResponse(
            content=f"Chattiness threshold for {context.display_name}: {shown}"
        )

    if not 0 <= value <= 100:
        raise CommandError("Chattiness threshold must be between 0 and 100.")

    setattr(model, "chattiness_threshold", value)

    return CommandResponse(
        content=(
            f"Updated chattiness threshold for {context.display_name} "
            f"to {value}."
        )
    )
|
||||||
|
|
||||||
|
|
||||||
|
def handle_ignore(
    context: CommandContext,
    *,
    ignore_enabled: bool | None = None,
) -> CommandResponse:
    """Set the target's ``ignore_messages`` flag and describe the new state.

    Args:
        context: Command context holding the target record.
        ignore_enabled: Desired flag value. Omitted or ``None`` means
            "start ignoring", matching the slash command's help text.
    """
    model = context.target

    if ignore_enabled is None:
        new_value = True
    else:
        new_value = ignore_enabled
    setattr(model, "ignore_messages", new_value)

    verb = "now ignoring" if new_value else "no longer ignoring"
    return CommandResponse(
        content=f"The bot is {verb} messages for {context.display_name}.",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def handle_summary(context: CommandContext) -> CommandResponse:
    """Report the target's stored ``summary``, or say that there is none."""
    summary = getattr(context.target, "summary", None)

    if not summary:
        return CommandResponse(
            content=f"No summary stored for {context.display_name}.",
        )

    return CommandResponse(
        content=f"Summary for {context.display_name}:\n\n{summary}",
    )
|
||||||
147
tests/memory/common/llms/test_usage_tracker.py
Normal file
147
tests/memory/common/llms/test_usage_tracker.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from memory.common.llms.usage_tracker import (
|
||||||
|
InMemoryUsageTracker,
|
||||||
|
RateLimitConfig,
|
||||||
|
UsageTracker,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def tracker() -> InMemoryUsageTracker:
    """Tracker with two Anthropic models sharing one one-minute rate limit."""
    shared_limits = RateLimitConfig(
        window=timedelta(minutes=1),
        max_input_tokens=1_000,
        max_output_tokens=2_000,
        max_total_tokens=2_500,
    )
    configured_models = ("anthropic/claude-3", "anthropic/haiku")
    return InMemoryUsageTracker(
        {model: shared_limits for model in configured_models}
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "window, kwargs",
    [
        # A window with no token limits at all.
        pytest.param(timedelta(minutes=1), {}, id="no-limits"),
        # A zero-length window with a limit.
        pytest.param(timedelta(seconds=0), {"max_total_tokens": 1}, id="zero-window"),
    ],
)
def test_rate_limit_config_validation(window: timedelta, kwargs: dict[str, int]) -> None:
    """Both parametrized configurations must be rejected with ``ValueError``."""
    with pytest.raises(ValueError):
        RateLimitConfig(window=window, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_allows_usage_within_limits(tracker: InMemoryUsageTracker) -> None:
    """Recorded usage is subtracted from each of the three budgets."""
    model = "anthropic/claude-3"
    moment = datetime(2024, 1, 1, tzinfo=timezone.utc)
    tracker.record_usage(model, 100, 200, timestamp=moment)

    remaining = tracker.get_available_tokens(model, timestamp=moment)

    assert remaining is not None
    assert remaining.input_tokens == 900
    assert remaining.output_tokens == 1_800
    assert remaining.total_tokens == 2_200
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limited_when_over_budget(tracker: InMemoryUsageTracker) -> None:
    """Using 800 + 1,700 tokens consumes the 2,500 total budget and rate-limits."""
    model = "anthropic/claude-3"
    moment = datetime(2024, 1, 1, tzinfo=timezone.utc)

    tracker.record_usage(model, 800, 1_700, timestamp=moment)

    assert tracker.is_rate_limited(model, timestamp=moment)
|
||||||
|
|
||||||
|
|
||||||
|
def test_recovers_after_window(tracker: InMemoryUsageTracker) -> None:
    """Two minutes after the usage, the one-minute window is empty again."""
    model = "anthropic/claude-3"
    start = datetime(2024, 1, 1, tzinfo=timezone.utc)
    tracker.record_usage(model, 800, 1_700, timestamp=start)

    after_window = start + timedelta(minutes=2)
    remaining = tracker.get_available_tokens(model, timestamp=after_window)

    assert remaining is not None
    assert remaining.input_tokens == 1_000
    assert remaining.output_tokens == 2_000
    assert remaining.total_tokens == 2_500
    assert not tracker.is_rate_limited(model, timestamp=after_window)
|
||||||
|
|
||||||
|
|
||||||
|
def test_usage_breakdown_and_provider_totals(tracker: InMemoryUsageTracker) -> None:
    """Per-model usage is reported individually and summed per provider."""
    moment = datetime(2024, 1, 1, tzinfo=timezone.utc)
    tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=moment)
    tracker.record_usage("anthropic/haiku", 50, 75, timestamp=moment)

    by_provider = tracker.get_usage_breakdown()
    assert "anthropic" in by_provider
    assert "claude-3" in by_provider["anthropic"]

    claude = by_provider["anthropic"]["claude-3"]
    assert claude.window_input_tokens == 100
    assert claude.window_output_tokens == 200

    totals = dict(tracker.iter_provider_totals())
    anthropic = totals["anthropic"]
    assert anthropic.window_input_tokens == 150
    assert anthropic.window_output_tokens == 275
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_usage_breakdown_filters(tracker: InMemoryUsageTracker) -> None:
    """The breakdown can be narrowed by provider or by model name."""
    moment = datetime(2024, 1, 1, tzinfo=timezone.utc)
    tracker.record_usage("anthropic/claude-3", 10, 20, timestamp=moment)
    tracker.record_usage("openai/gpt-4o", 5, 5, timestamp=moment)

    by_provider = tracker.get_usage_breakdown(provider="anthropic")
    assert set(by_provider.keys()) == {"anthropic"}
    assert set(by_provider["anthropic"].keys()) == {"claude-3"}

    by_model = tracker.get_usage_breakdown(model="gpt-4o")
    assert set(by_model.keys()) == {"openai"}
    assert set(by_model["openai"].keys()) == {"gpt-4o"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_configuration_records_lifetime_only() -> None:
    """Without a config there is no allowance and only lifetime totals grow."""
    model = "openai/gpt-4o"
    unconfigured = InMemoryUsageTracker(configs={})
    unconfigured.record_usage(model, 10, 20)

    assert unconfigured.get_available_tokens(model) is None

    gpt_usage = unconfigured.get_usage_breakdown()["openai"]["gpt-4o"]
    assert gpt_usage.window_input_tokens == 0
    assert gpt_usage.lifetime_input_tokens == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_configuration_is_used() -> None:
    """A model without its own config is limited by ``default_config``."""
    model = "anthropic/claude-3"
    fallback = RateLimitConfig(window=timedelta(minutes=1), max_total_tokens=100)
    defaulted = InMemoryUsageTracker(configs={}, default_config=fallback)

    defaulted.record_usage(model, 10, 10)
    remaining = defaulted.get_available_tokens(model)

    assert remaining is not None
    assert remaining.total_tokens == 80
|
||||||
|
|
||||||
|
|
||||||
|
def test_record_usage_rejects_negative_values(tracker: InMemoryUsageTracker) -> None:
    """A negative input-token count is refused with ``ValueError``."""
    negative_input = -1
    with pytest.raises(ValueError):
        tracker.record_usage("anthropic/claude-3", negative_input, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_rate_limited_when_only_output_exceeds_limit() -> None:
    """Using up an output-only budget is enough to be rate limited."""
    model = "openai/gpt-4o"
    output_only = RateLimitConfig(window=timedelta(minutes=1), max_output_tokens=50)
    limited = InMemoryUsageTracker({model: output_only})

    limited.record_usage(model, 0, 50)

    assert limited.is_rate_limited(model)
|
||||||
|
|
||||||
|
|
||||||
|
def test_usage_tracker_base_not_instantiable() -> None:
    """A subclass that overrides nothing raises ``NotImplementedError`` when used."""

    class BareTracker(UsageTracker):
        pass

    with pytest.raises(NotImplementedError):
        BareTracker({}).record_usage("provider/model", 1, 1)
|
||||||
@ -493,10 +493,12 @@ async def test_on_ready():
|
|||||||
collector.user.name = "TestBot"
|
collector.user.name = "TestBot"
|
||||||
collector.guilds = [Mock(), Mock()]
|
collector.guilds = [Mock(), Mock()]
|
||||||
collector.sync_servers_and_channels = AsyncMock()
|
collector.sync_servers_and_channels = AsyncMock()
|
||||||
|
collector.tree.sync = AsyncMock()
|
||||||
|
|
||||||
await collector.on_ready()
|
await collector.on_ready()
|
||||||
|
|
||||||
collector.sync_servers_and_channels.assert_called_once()
|
collector.sync_servers_and_channels.assert_called_once()
|
||||||
|
collector.tree.sync.assert_awaited()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
172
tests/memory/discord_tests/test_commands.py
Normal file
172
tests/memory/discord_tests/test_commands.py
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import discord
|
||||||
|
|
||||||
|
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
||||||
|
from memory.discord.commands import (
|
||||||
|
CommandError,
|
||||||
|
CommandResponse,
|
||||||
|
run_command,
|
||||||
|
handle_prompt,
|
||||||
|
handle_chattiness,
|
||||||
|
handle_ignore,
|
||||||
|
handle_summary,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyInteraction:
    """Minimal substitute for :class:`discord.Interaction` in these tests.

    Carries only the guild, channel, user, and the two id attributes derived
    from them.
    """

    def __init__(
        self,
        *,
        guild: discord.Guild | None,
        channel: discord.abc.Messageable | None,
        user: discord.abc.User,
    ) -> None:
        self.user = user
        self.guild = guild
        self.channel = channel
        # The ids are None when the corresponding object is None or has no id.
        self.guild_id = getattr(guild, "id", None)
        self.channel_id = getattr(channel, "id", None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def guild() -> discord.Guild:
    """Mocked guild with a fixed id, name, description, and member count."""
    fake_guild = MagicMock(spec=discord.Guild)
    # ``name`` is assigned directly: as a constructor/configure keyword it
    # would name the mock itself rather than set the attribute.
    fake_guild.name = "Test Guild"
    fake_guild.configure_mock(
        id=123,
        description="Guild description",
        member_count=42,
    )
    return fake_guild
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def text_channel(guild: discord.Guild) -> discord.TextChannel:
    """Mocked text channel named "general" that belongs to the mocked guild."""
    fake_channel = MagicMock(spec=discord.TextChannel)
    fake_channel.name = "general"
    fake_channel.configure_mock(
        id=456,
        guild=guild,
        type=discord.ChannelType.text,
    )
    return fake_channel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def discord_user() -> discord.User:
    """Mocked Discord user with a username and a separate display name."""
    fake_user = MagicMock(spec=discord.User)
    fake_user.name = "command-user"
    fake_user.configure_mock(id=789, display_name="Commander")
    return fake_user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def interaction(guild, text_channel, discord_user) -> DummyInteraction:
    """Interaction stand-in wired to the mocked guild, channel, and user."""
    return DummyInteraction(
        guild=guild,
        channel=text_channel,
        user=discord_user,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_command_prompt_server(db_session, guild, interaction):
    """The server's stored system prompt appears in the response."""
    stored_server = DiscordServer(
        id=guild.id,
        name="Test Guild",
        system_prompt="Be helpful",
    )
    db_session.add(stored_server)
    db_session.commit()

    result = run_command(
        db_session,
        interaction,
        scope="server",
        handler=handle_prompt,
    )

    assert isinstance(result, CommandResponse)
    assert "Be helpful" in result.content
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_command_prompt_channel_creates_channel(db_session, interaction, text_channel):
    """Running a channel-scoped command stores the channel if it was unknown."""
    result = run_command(
        db_session,
        interaction,
        scope="channel",
        handler=handle_prompt,
    )
    assert "No prompt" in result.content

    stored_channel = db_session.get(DiscordChannel, text_channel.id)
    assert stored_channel is not None
    assert stored_channel.name == text_channel.name
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_command_chattiness_show(db_session, interaction, guild):
    """With no new value, the handler reports the stored threshold."""
    server = DiscordServer(id=guild.id, name="Guild", chattiness_threshold=73)
    db_session.add(server)
    db_session.commit()

    response = run_command(
        db_session,
        interaction,
        scope="server",
        handler=handle_chattiness,
        # Passed explicitly: the handler takes ``value`` as a keyword-only
        # argument, and None selects the "show" path.
        value=None,
    )

    assert str(server.chattiness_threshold) in response.content
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_command_chattiness_update(db_session, interaction):
    """Supplying a valid value overwrites the user's stored threshold."""
    stored_user = DiscordUser(
        id=interaction.user.id,
        username="command-user",
        chattiness_threshold=15,
    )
    db_session.add(stored_user)
    db_session.commit()

    result = run_command(
        db_session,
        interaction,
        scope="user",
        handler=handle_chattiness,
        value=80,
    )
    db_session.flush()

    assert "Updated" in result.content
    assert stored_user.chattiness_threshold == 80
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_command_chattiness_invalid_value(db_session, interaction):
    """A threshold of 150 is rejected with ``CommandError``."""
    out_of_range = 150
    with pytest.raises(CommandError):
        run_command(
            db_session,
            interaction,
            scope="user",
            handler=handle_chattiness,
            value=out_of_range,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_command_ignore_toggle(db_session, interaction, guild):
    """Enabling ignore on a channel sets its flag and reports it as ignored."""
    stored_channel = DiscordChannel(
        id=interaction.channel.id,
        name="general",
        channel_type="text",
        server_id=guild.id,
    )
    db_session.add(stored_channel)
    db_session.commit()

    result = run_command(
        db_session,
        interaction,
        scope="channel",
        handler=handle_ignore,
        ignore_enabled=True,
    )
    db_session.flush()

    assert "no longer" not in result.content
    assert stored_channel.ignore_messages is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_command_summary_missing(db_session, interaction):
    """A user without a stored summary gets the "No summary" message."""
    result = run_command(
        db_session,
        interaction,
        scope="user",
        handler=handle_summary,
    )

    assert "No summary" in result.content
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user