diff --git a/src/memory/common/llms/__init__.py b/src/memory/common/llms/__init__.py
index 9255f63..b2d37cc 100644
--- a/src/memory/common/llms/__init__.py
+++ b/src/memory/common/llms/__init__.py
@@ -24,13 +24,6 @@ from memory.common.llms.base import (
 )
 from memory.common.llms.anthropic_provider import AnthropicProvider
 from memory.common.llms.openai_provider import OpenAIProvider
-from memory.common.llms.usage_tracker import (
-    InMemoryUsageTracker,
-    RateLimitConfig,
-    TokenAllowance,
-    UsageBreakdown,
-    UsageTracker,
-)
 from memory.common import tokens
 
 __all__ = [
@@ -49,11 +42,6 @@ __all__ = [
     "StreamEvent",
     "LLMSettings",
     "create_provider",
-    "InMemoryUsageTracker",
-    "RateLimitConfig",
-    "TokenAllowance",
-    "UsageBreakdown",
-    "UsageTracker",
 ]
 
 logger = logging.getLogger(__name__)
@@ -93,28 +81,3 @@ def truncate(content: str, target_tokens: int) -> str:
     if len(content) > target_chars:
         return content[:target_chars].rsplit(" ", 1)[0] + "..."
     return content
-
-
-# bla = 1
-# from memory.common.llms import *
-# from memory.common.llms.tools.discord import make_discord_tools
-# from memory.common.db.connection import make_session
-# from memory.common.db.models import *
-
-# model = "anthropic/claude-sonnet-4-5"
-# provider = create_provider(model=model)
-# with make_session() as session:
-#     bot = session.query(DiscordBotUser).first()
-#     server = session.query(DiscordServer).first()
-#     channel = server.channels[0]
-#     tools = make_discord_tools(bot, None, channel, model)
-
-# def demo(msg: str):
-#     messages = [
-#         Message(
-#             role=MessageRole.USER,
-#             content=msg,
-#         )
-#     ]
-#     for m in provider.stream_with_tools(messages, tools):
-#         print(m)
diff --git a/src/memory/common/llms/base.py b/src/memory/common/llms/base.py
index f62eb4d..709d810 100644
--- a/src/memory/common/llms/base.py
+++ b/src/memory/common/llms/base.py
@@ -12,6 +12,7 @@ from PIL import Image
 
 from memory.common import settings
 from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
+from memory.common.llms.usage import UsageTracker, RedisUsageTracker
 
 logger = logging.getLogger(__name__)
 
@@ -204,7 +205,9 @@ class LLMSettings:
 class BaseLLMProvider(ABC):
     """Base class for LLM providers."""
 
-    def __init__(self, api_key: str, model: str):
+    def __init__(
+        self, api_key: str, model: str, usage_tracker: UsageTracker | None = None
+    ):
         """
         Initialize the LLM provider.
 
@@ -215,6 +218,7 @@ class BaseLLMProvider(ABC):
         self.api_key = api_key
         self.model = model
         self._client: Any = None
+        self.usage_tracker: UsageTracker = usage_tracker or RedisUsageTracker()
 
     @abstractmethod
     def _initialize_client(self) -> Any:
diff --git a/src/memory/common/llms/usage/__init__.py b/src/memory/common/llms/usage/__init__.py
new file mode 100644
index 0000000..fde43e1
--- /dev/null
+++ b/src/memory/common/llms/usage/__init__.py
@@ -0,0 +1,17 @@
+from memory.common.llms.usage.redis_usage_tracker import RedisUsageTracker
+from memory.common.llms.usage.usage_tracker import (
+    InMemoryUsageTracker,
+    RateLimitConfig,
+    TokenAllowance,
+    UsageBreakdown,
+    UsageTracker,
+)
+
+__all__ = [
+    "InMemoryUsageTracker",
+    "RateLimitConfig",
+    "RedisUsageTracker",
+    "TokenAllowance",
+    "UsageBreakdown",
+    "UsageTracker",
+]
diff --git a/src/memory/common/llms/usage/redis_usage_tracker.py b/src/memory/common/llms/usage/redis_usage_tracker.py
new file mode 100644
index 0000000..468e9da
--- /dev/null
+++ b/src/memory/common/llms/usage/redis_usage_tracker.py
@@ -0,0 +1,88 @@
+"""Redis-backed usage tracker implementation."""
+
+import json
+from typing import Any, Iterable, Protocol
+
+import redis
+
+from memory.common import settings
+from memory.common.llms.usage.usage_tracker import (
+    RateLimitConfig,
+    UsageState,
+    UsageTracker,
+)
+
+
+class RedisClientProtocol(Protocol):
+    def get(self, key: str) -> Any:  # pragma: no cover - Protocol definition
+        ...
+
+    def set(
+        self, key: str, value: Any
+    ) -> Any:  # pragma: no cover - Protocol definition
+        ...
+
+    def scan_iter(
+        self, match: str
+    ) -> Iterable[Any]:  # pragma: no cover - Protocol definition
+        ...
+
+
+class RedisUsageTracker(UsageTracker):
+    """Tracks LLM usage for providers and models using Redis for persistence."""
+
+    def __init__(
+        self,
+        configs: dict[str, RateLimitConfig],
+        default_config: RateLimitConfig | None = None,
+        *,
+        redis_client: RedisClientProtocol | None = None,
+        key_prefix: str | None = None,
+    ) -> None:
+        super().__init__(configs=configs, default_config=default_config)
+        if redis_client is None:
+            redis_client = redis.Redis(
+                host=settings.REDIS_HOST,
+                port=int(settings.REDIS_PORT),
+                db=int(settings.REDIS_DB),
+                decode_responses=False,
+            )
+        self._redis = redis_client
+        prefix = key_prefix or settings.LLM_USAGE_REDIS_PREFIX
+        self._key_prefix = prefix.rstrip(":")
+
+    def get_state(self, model: str) -> UsageState:
+        redis_key = self._format_key(model)
+        payload = self._redis.get(redis_key)
+        if not payload:
+            return UsageState()
+        if isinstance(payload, bytes):
+            payload = payload.decode()
+        return UsageState.from_payload(json.loads(payload))
+
+    def iter_state_items(self) -> Iterable[tuple[str, UsageState]]:
+        pattern = f"{self._key_prefix}:*"
+        for redis_key in self._redis.scan_iter(match=pattern):
+            key = self._ensure_str(redis_key)
+            payload = self._redis.get(key)
+            if not payload:
+                continue
+            if isinstance(payload, bytes):
+                payload = payload.decode()
+            state = UsageState.from_payload(json.loads(payload))
+            yield key[len(self._key_prefix) + 1 :], state
+
+    def save_state(self, model: str, state: UsageState) -> None:
+        redis_key = self._format_key(model)
+        self._redis.set(
+            redis_key, json.dumps(state.to_payload(), separators=(",", ":"))
+        )
+
+    def _format_key(self, model: str) -> str:
+        return f"{self._key_prefix}:{model}"
+
+    @staticmethod
+    def _ensure_str(value: Any) -> str:
+        if isinstance(value, bytes):
+            return value.decode()
+        return str(value)
diff --git a/src/memory/common/llms/usage_tracker.py b/src/memory/common/llms/usage/usage_tracker.py
similarity index 87%
rename from src/memory/common/llms/usage_tracker.py
rename to src/memory/common/llms/usage/usage_tracker.py
index 90b1fb6..fbf9750 100644
--- a/src/memory/common/llms/usage_tracker.py
+++ b/src/memory/common/llms/usage/usage_tracker.py
@@ -5,6 +5,7 @@ from collections.abc import Iterable
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta, timezone
 from threading import Lock
+from typing import Any
 
 
 @dataclass(frozen=True)
@@ -45,6 +46,40 @@ class UsageState:
     lifetime_input_tokens: int = 0
     lifetime_output_tokens: int = 0
 
+    def to_payload(self) -> dict[str, Any]:
+        return {
+            "events": [
+                {
+                    "timestamp": event.timestamp.isoformat(),
+                    "input_tokens": event.input_tokens,
+                    "output_tokens": event.output_tokens,
+                }
+                for event in self.events
+            ],
+            "window_input_tokens": self.window_input_tokens,
+            "window_output_tokens": self.window_output_tokens,
+            "lifetime_input_tokens": self.lifetime_input_tokens,
+            "lifetime_output_tokens": self.lifetime_output_tokens,
+        }
+
+    @classmethod
+    def from_payload(cls, payload: dict[str, Any]) -> "UsageState":
+        events = deque(
+            UsageEvent(
+                timestamp=datetime.fromisoformat(event["timestamp"]),
+                input_tokens=event["input_tokens"],
+                output_tokens=event["output_tokens"],
+            )
+            for event in payload.get("events", [])
+        )
+        return cls(
+            events=events,
+            window_input_tokens=payload.get("window_input_tokens", 0),
+            window_output_tokens=payload.get("window_output_tokens", 0),
+            lifetime_input_tokens=payload.get("lifetime_input_tokens", 0),
+            lifetime_output_tokens=payload.get("lifetime_output_tokens", 0),
+        )
+
 
 @dataclass
 class TokenAllowance:
@@ -76,9 +111,7 @@ class UsageBreakdown:
 
 def split_model_key(model: str) -> tuple[str, str]:
     if "/" not in model:
-        raise ValueError(
-            "model must be formatted as '<provider>/<model>'"
-        )
+        raise ValueError("model must be formatted as '<provider>/<model>'")
 
     provider, model_name = model.split("/", maxsplit=1)
     if not provider or not model_name:
@@ -205,9 +238,7 @@ class UsageTracker:
         if config.max_output_tokens is None:
             output_remaining = None
         else:
-            output_remaining = (
-                config.max_output_tokens - state.window_output_tokens
-            )
+            output_remaining = config.max_output_tokens - state.window_output_tokens
 
         return TokenAllowance(
             input_tokens=clamp_non_negative(input_remaining),
@@ -313,4 +344,3 @@ def clamp_non_negative(value: int | None) -> int | None:
     if value is None:
         return None
     return 0 if value < 0 else value
-
diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py
index 842a32c..9ea8c64 100644
--- a/src/memory/common/settings.py
+++ b/src/memory/common/settings.py
@@ -38,6 +38,7 @@ CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower()
 REDIS_HOST = os.getenv("REDIS_HOST", "redis")
 REDIS_PORT = os.getenv("REDIS_PORT", "6379")
 REDIS_DB = os.getenv("REDIS_DB", "0")
+LLM_USAGE_REDIS_PREFIX = os.getenv("LLM_USAGE_REDIS_PREFIX", "llm_usage")
 CELERY_BROKER_USER = os.getenv(
     "CELERY_BROKER_USER", "kb" if CELERY_BROKER_TYPE == "amqp" else ""
 )
diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py
index 8f6cfb4..102fbd7 100644
--- a/src/memory/workers/tasks/discord.py
+++ b/src/memory/workers/tasks/discord.py
@@ -58,6 +58,13 @@ def call_llm(
     msgs: list[str] = [],
     allowed_tools: list[str] = [],
 ) -> str | None:
+    provider = create_provider(model=model)
+    if provider.usage_tracker.is_rate_limited(model):
+        logger.error(
+            f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
+        )
+        return None
+
     tools = make_discord_tools(
         message.recipient_user.system_user,
         message.from_user,
@@ -73,7 +80,6 @@
         system_prompt += comm_channel_prompt(
             session, message.recipient_user, message.channel
        )
-    provider = create_provider(model=model)
     messages = previous_messages(
         session,
         message.recipient_user and message.recipient_user.id,
diff --git a/tests/memory/common/llms/test_usage_tracker.py b/tests/memory/common/llms/test_usage_tracker.py
index e0df8ce..dbb9bc1 100644
--- a/tests/memory/common/llms/test_usage_tracker.py
+++ b/tests/memory/common/llms/test_usage_tracker.py
@@ -1,7 +1,24 @@
 from datetime import datetime, timedelta, timezone
+from typing import Iterable
 
 import pytest
 
+try:
+    import redis  # noqa: F401 # pragma: no cover - optional test dependency
+except ModuleNotFoundError:  # pragma: no cover - import guard for test envs
+    import sys
+    from types import SimpleNamespace
+
+    class _RedisStub(SimpleNamespace):
+        class Redis:  # type: ignore[no-redef]
+            def __init__(self, *args: object, **kwargs: object) -> None:
+                raise ModuleNotFoundError(
+                    "The 'redis' package is required to use RedisUsageTracker"
+                )
+
+    sys.modules.setdefault("redis", _RedisStub())
+
+from memory.common.llms.usage.redis_usage_tracker import RedisUsageTracker
 from memory.common.llms.usage_tracker import (
     InMemoryUsageTracker,
     RateLimitConfig,
@@ -9,6 +26,24 @@ from memory.common.llms.usage_tracker import (
     UsageTracker,
 )
 
 
+class FakeRedis:
+    def __init__(self) -> None:
+        self._store: dict[str, str] = {}
+
+    def get(self, key: str) -> str | None:
+        return self._store.get(key)
+
+    def set(self, key: str, value: str) -> None:
+        self._store[key] = value
+
+    def scan_iter(self, match: str) -> Iterable[str]:
+        from fnmatch import fnmatch
+
+        for key in list(self._store.keys()):
+            if fnmatch(key, match):
+                yield key
+
+
 @pytest.fixture
 def tracker() -> InMemoryUsageTracker:
     config = RateLimitConfig(
         window=timedelta(minutes=1),
         max_input_tokens=1_000,
         max_output_tokens=2_000,
         max_total_tokens=2_500,
     )
     return InMemoryUsageTracker(
         {
             "anthropic/claude-3": config,
             "openai/gpt-4o": config,
         },
     )
 
+
+@pytest.fixture
+def redis_tracker() -> RedisUsageTracker:
+    config = RateLimitConfig(
+        window=timedelta(minutes=1),
+        max_input_tokens=1_000,
+        max_output_tokens=2_000,
+        max_total_tokens=2_500,
+    )
+    return RedisUsageTracker(
+        {
+            "anthropic/claude-3": config,
+            "anthropic/haiku": config,
+        },
+        redis_client=FakeRedis(),
+    )
+
+
 @pytest.mark.parametrize(
     "window, kwargs",
     [
@@ -139,6 +191,22 @@ def test_is_rate_limited_when_only_output_exceeds_limit() -> None:
     assert tracker.is_rate_limited("openai/gpt-4o")
 
 
+def test_redis_usage_tracker_persists_state(redis_tracker: RedisUsageTracker) -> None:
+    now = datetime(2024, 1, 1, tzinfo=timezone.utc)
+    redis_tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
+    redis_tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)
+
+    allowance = redis_tracker.get_available_tokens("anthropic/claude-3", timestamp=now)
+    assert allowance is not None
+    assert allowance.input_tokens == 900
+
+    breakdown = redis_tracker.get_usage_breakdown()
+    assert breakdown["anthropic"]["claude-3"].window_output_tokens == 200
+
+    items = dict(redis_tracker.iter_state_items())
+    assert set(items.keys()) == {"anthropic/claude-3", "anthropic/haiku"}
+
+
 def test_usage_tracker_base_not_instantiable() -> None:
     class DummyTracker(UsageTracker):
         pass
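
Not part of the diff: a minimal, hypothetical sketch of how the new Redis-backed tracker is meant to be exercised, pieced together only from calls that do appear above (record_usage, get_available_tokens, is_rate_limited, get_usage_breakdown). The DictRedis helper is an assumption that mirrors the tests' FakeRedis so the snippet runs without a Redis server; in the real wiring, BaseLLMProvider falls back to RedisUsageTracker(), which builds a redis.Redis client from the REDIS_* settings.

# sketch.py - assumes this PR's modules; DictRedis is a hypothetical stand-in
from datetime import datetime, timedelta, timezone
from fnmatch import fnmatch

from memory.common.llms.usage import RateLimitConfig, RedisUsageTracker


class DictRedis:
    """Dict-backed client covering the three methods RedisUsageTracker uses."""

    def __init__(self) -> None:
        self._store: dict[str, str] = {}

    def get(self, key: str) -> str | None:
        return self._store.get(key)

    def set(self, key: str, value: str) -> None:
        self._store[key] = value

    def scan_iter(self, match: str):
        return (key for key in list(self._store) if fnmatch(key, match))


# One window/limit config per "provider/model" key, as in the tests.
config = RateLimitConfig(
    window=timedelta(minutes=1),
    max_input_tokens=1_000,
    max_output_tokens=2_000,
    max_total_tokens=2_500,
)
tracker = RedisUsageTracker({"anthropic/claude-3": config}, redis_client=DictRedis())

# Record one call's token counts, then run the same check call_llm() performs
# before building a response.
now = datetime.now(timezone.utc)
tracker.record_usage("anthropic/claude-3", 400, 150, timestamp=now)
print(tracker.is_rate_limited("anthropic/claude-3"))                      # False while under the limits
print(tracker.get_available_tokens("anthropic/claude-3", timestamp=now))  # remaining TokenAllowance
print(tracker.get_usage_breakdown())                                      # per-provider usage summary

Because state is stored under '<prefix>:<provider>/<model>' keys, separate worker processes such as the Discord task can observe the same usage window, which the in-memory tracker cannot offer.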