From 07852f9ee791033f7f5d082f2c81b7fdc166e0b0 Mon Sep 17 00:00:00 2001
From: Daniel O'Connell
Date: Sat, 1 Nov 2025 16:22:40 +0100
Subject: [PATCH] Base usage tracker

---
 AGENTS.md                                    |   5 +
 src/memory/common/llms/__init__.py           |  12 +
 src/memory/common/llms/usage_tracker.py      | 316 ++++++++++++++++++
 .../memory/common/llms/test_usage_tracker.py | 147 ++++++++
 4 files changed, 480 insertions(+)
 create mode 100644 AGENTS.md
 create mode 100644 src/memory/common/llms/usage_tracker.py
 create mode 100644 tests/memory/common/llms/test_usage_tracker.py

diff --git a/AGENTS.md b/AGENTS.md
new file mode 100644
index 0000000..00c6668
--- /dev/null
+++ b/AGENTS.md
@@ -0,0 +1,5 @@
+# Agent Guidance
+
+- Assume Python 3.10+ features are available; avoid `from __future__ import annotations` unless necessary.
+- Treat LLM model identifiers as `provider/model` strings throughout the codebase.
+- Prefer straightforward control flow (`if`/`else`) instead of nested ternaries when it improves clarity.
diff --git a/src/memory/common/llms/__init__.py b/src/memory/common/llms/__init__.py
index b84e337..9255f63 100644
--- a/src/memory/common/llms/__init__.py
+++ b/src/memory/common/llms/__init__.py
@@ -24,6 +24,13 @@ from memory.common.llms.base import (
 )
 from memory.common.llms.anthropic_provider import AnthropicProvider
 from memory.common.llms.openai_provider import OpenAIProvider
+from memory.common.llms.usage_tracker import (
+    InMemoryUsageTracker,
+    RateLimitConfig,
+    TokenAllowance,
+    UsageBreakdown,
+    UsageTracker,
+)
 from memory.common import tokens

 __all__ = [
@@ -42,6 +49,11 @@ __all__ = [
     "StreamEvent",
     "LLMSettings",
     "create_provider",
+    "InMemoryUsageTracker",
+    "RateLimitConfig",
+    "TokenAllowance",
+    "UsageBreakdown",
+    "UsageTracker",
 ]

 logger = logging.getLogger(__name__)
diff --git a/src/memory/common/llms/usage_tracker.py b/src/memory/common/llms/usage_tracker.py
new file mode 100644
index 0000000..90b1fb6
--- /dev/null
+++ b/src/memory/common/llms/usage_tracker.py
@@ -0,0 +1,316 @@
+"""LLM usage tracking utilities."""
+
+from collections import defaultdict, deque
+from collections.abc import Iterable
+from dataclasses import dataclass, field
+from datetime import datetime, timedelta, timezone
+from threading import Lock
+
+
+@dataclass(frozen=True)
+class RateLimitConfig:
+    """Configuration for a single rolling usage window."""
+
+    window: timedelta
+    max_input_tokens: int | None = None
+    max_output_tokens: int | None = None
+    max_total_tokens: int | None = None
+
+    def __post_init__(self) -> None:
+        if self.window <= timedelta(0):
+            raise ValueError("window must be positive")
+        if (
+            self.max_input_tokens is None
+            and self.max_output_tokens is None
+            and self.max_total_tokens is None
+        ):
+            raise ValueError(
+                "At least one of max_input_tokens, max_output_tokens or "
+                "max_total_tokens must be provided"
+            )
+
+
+@dataclass
+class UsageEvent:
+    timestamp: datetime
+    input_tokens: int
+    output_tokens: int
+
+
+@dataclass
+class UsageState:
+    events: deque[UsageEvent] = field(default_factory=deque)
+    window_input_tokens: int = 0
+    window_output_tokens: int = 0
+    lifetime_input_tokens: int = 0
+    lifetime_output_tokens: int = 0
+
+
+@dataclass
+class TokenAllowance:
+    """Represents the tokens that can be consumed right now."""
+
+    input_tokens: int | None
+    output_tokens: int | None
+    total_tokens: int | None
+
+
+@dataclass
+class UsageBreakdown:
+    """Detailed usage statistics for a provider/model pair."""
+
+    window_input_tokens: int
+    window_output_tokens: int
+    window_total_tokens: int
+    lifetime_input_tokens: int
+    lifetime_output_tokens: int
+
+    @property
+    def window_total(self) -> int:
+        return self.window_total_tokens
+
+    @property
+    def lifetime_total_tokens(self) -> int:
+        return self.lifetime_input_tokens + self.lifetime_output_tokens
+
+
+def split_model_key(model: str) -> tuple[str, str]:
+    if "/" not in model:
+        raise ValueError(
+            "model must be formatted as 'provider/model'"
+        )
+
+    provider, model_name = model.split("/", maxsplit=1)
+    if not provider or not model_name:
+        raise ValueError(
+            "model must include both provider and model name separated by '/'"
+        )
+    return provider, model_name
+
+
+class UsageTracker:
+    """Base class for usage trackers that operate on provider/model keys."""
+
+    def __init__(
+        self,
+        configs: dict[str, RateLimitConfig],
+        default_config: RateLimitConfig | None = None,
+    ) -> None:
+        self._configs = configs
+        self._default_config = default_config
+        self._lock = Lock()
+
+    # ------------------------------------------------------------------
+    # Storage hooks
+    # ------------------------------------------------------------------
+    def get_state(self, key: str) -> UsageState:
+        raise NotImplementedError
+
+    def iter_state_items(self) -> Iterable[tuple[str, UsageState]]:
+        raise NotImplementedError
+
+    def save_state(self, key: str, state: UsageState) -> None:
+        """Persist the given state back to the underlying store."""
+        del key, state
+
+    # ------------------------------------------------------------------
+    # Public API
+    # ------------------------------------------------------------------
+    def record_usage(
+        self,
+        model: str,
+        input_tokens: int,
+        output_tokens: int,
+        timestamp: datetime | None = None,
+    ) -> None:
+        """Record token usage for the given provider/model pair."""
+
+        if input_tokens < 0 or output_tokens < 0:
+            raise ValueError("Token counts must be non-negative")
+
+        timestamp = timestamp or datetime.now(timezone.utc)
+        split_model_key(model)
+        key = model
+
+        with self._lock:
+            config = self._get_config(key)
+            state = self.get_state(key)
+
+            state.lifetime_input_tokens += input_tokens
+            state.lifetime_output_tokens += output_tokens
+
+            if config is None:
+                self.save_state(key, state)
+                return
+
+            state.events.append(UsageEvent(timestamp, input_tokens, output_tokens))
+            state.window_input_tokens += input_tokens
+            state.window_output_tokens += output_tokens
+
+            self._prune_expired_events(state, config, now=timestamp)
+            self.save_state(key, state)
+
+    def is_rate_limited(
+        self,
+        model: str,
+        timestamp: datetime | None = None,
+    ) -> bool:
+        """Return True if the provider/model pair has reached any of its limits."""
+
+        allowance = self.get_available_tokens(model, timestamp=timestamp)
+        if allowance is None:
+            return False
+
+        limits = [
+            allowance.input_tokens,
+            allowance.output_tokens,
+            allowance.total_tokens,
+        ]
+        return any(limit is not None and limit <= 0 for limit in limits)
+
+    def get_available_tokens(
+        self,
+        model: str,
+        timestamp: datetime | None = None,
+    ) -> TokenAllowance | None:
+        """Return the current token allowance for the provider/model pair.
+
+        If there is no configuration for the pair (or a default configuration),
+        ``None`` is returned to indicate that no limits are enforced.
+ """ + + split_model_key(model) + key = model + with self._lock: + config = self._get_config(key) + if config is None: + return None + + state = self.get_state(key) + self._prune_expired_events(state, config, now=timestamp) + self.save_state(key, state) + + if config.max_total_tokens is None: + total_remaining = None + else: + total_remaining = config.max_total_tokens - ( + state.window_input_tokens + state.window_output_tokens + ) + + if config.max_input_tokens is None: + input_remaining = None + else: + input_remaining = config.max_input_tokens - state.window_input_tokens + + if config.max_output_tokens is None: + output_remaining = None + else: + output_remaining = ( + config.max_output_tokens - state.window_output_tokens + ) + + return TokenAllowance( + input_tokens=clamp_non_negative(input_remaining), + output_tokens=clamp_non_negative(output_remaining), + total_tokens=clamp_non_negative(total_remaining), + ) + + def get_usage_breakdown( + self, provider: str | None = None, model: str | None = None + ) -> dict[str, dict[str, UsageBreakdown]]: + """Return usage statistics grouped by provider and model.""" + + with self._lock: + providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict) + for key, state in self.iter_state_items(): + prov, model_name = split_model_key(key) + if provider and provider != prov: + continue + if model and model != model_name: + continue + + window_total = state.window_input_tokens + state.window_output_tokens + breakdown = UsageBreakdown( + window_input_tokens=state.window_input_tokens, + window_output_tokens=state.window_output_tokens, + window_total_tokens=window_total, + lifetime_input_tokens=state.lifetime_input_tokens, + lifetime_output_tokens=state.lifetime_output_tokens, + ) + providers[prov][model_name] = breakdown + + return providers + + def iter_provider_totals(self) -> Iterable[tuple[str, UsageBreakdown]]: + """Yield aggregated totals for each provider across its models.""" + + breakdowns = self.get_usage_breakdown() + for provider, models in breakdowns.items(): + window_input = sum(b.window_input_tokens for b in models.values()) + window_output = sum(b.window_output_tokens for b in models.values()) + lifetime_input = sum(b.lifetime_input_tokens for b in models.values()) + lifetime_output = sum(b.lifetime_output_tokens for b in models.values()) + + yield ( + provider, + UsageBreakdown( + window_input_tokens=window_input, + window_output_tokens=window_output, + window_total_tokens=window_input + window_output, + lifetime_input_tokens=lifetime_input, + lifetime_output_tokens=lifetime_output, + ), + ) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + def _get_config(self, key: str) -> RateLimitConfig | None: + return self._configs.get(key) or self._default_config + + def _prune_expired_events( + self, + state: UsageState, + config: RateLimitConfig, + now: datetime | None = None, + ) -> None: + if not state.events: + return + + now = now or datetime.now(timezone.utc) + cutoff = now - config.window + + for event in tuple(state.events): + if event.timestamp > cutoff: + break + state.events.popleft() + state.window_input_tokens -= event.input_tokens + state.window_output_tokens -= event.output_tokens + + state.window_input_tokens = max(state.window_input_tokens, 0) + state.window_output_tokens = max(state.window_output_tokens, 0) + + +class InMemoryUsageTracker(UsageTracker): + """Tracks LLM usage for providers and models within 
a rolling window.""" + + def __init__( + self, + configs: dict[str, RateLimitConfig], + default_config: RateLimitConfig | None = None, + ) -> None: + super().__init__(configs=configs, default_config=default_config) + self._states: dict[str, UsageState] = {} + + def get_state(self, key: str) -> UsageState: + return self._states.setdefault(key, UsageState()) + + def iter_state_items(self) -> Iterable[tuple[str, UsageState]]: + return tuple(self._states.items()) + + +def clamp_non_negative(value: int | None) -> int | None: + if value is None: + return None + return 0 if value < 0 else value + diff --git a/tests/memory/common/llms/test_usage_tracker.py b/tests/memory/common/llms/test_usage_tracker.py new file mode 100644 index 0000000..e0df8ce --- /dev/null +++ b/tests/memory/common/llms/test_usage_tracker.py @@ -0,0 +1,147 @@ +from datetime import datetime, timedelta, timezone + +import pytest + +from memory.common.llms.usage_tracker import ( + InMemoryUsageTracker, + RateLimitConfig, + UsageTracker, +) + + +@pytest.fixture +def tracker() -> InMemoryUsageTracker: + config = RateLimitConfig( + window=timedelta(minutes=1), + max_input_tokens=1_000, + max_output_tokens=2_000, + max_total_tokens=2_500, + ) + return InMemoryUsageTracker( + { + "anthropic/claude-3": config, + "anthropic/haiku": config, + } + ) + + +@pytest.mark.parametrize( + "window, kwargs", + [ + (timedelta(minutes=1), {}), + (timedelta(seconds=0), {"max_total_tokens": 1}), + ], +) +def test_rate_limit_config_validation(window: timedelta, kwargs: dict[str, int]) -> None: + with pytest.raises(ValueError): + RateLimitConfig(window=window, **kwargs) + + +def test_allows_usage_within_limits(tracker: InMemoryUsageTracker) -> None: + now = datetime(2024, 1, 1, tzinfo=timezone.utc) + tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now) + + allowance = tracker.get_available_tokens( + "anthropic/claude-3", timestamp=now + ) + assert allowance is not None + assert allowance.input_tokens == 900 + assert allowance.output_tokens == 1_800 + assert allowance.total_tokens == 2_200 + + +def test_rate_limited_when_over_budget(tracker: InMemoryUsageTracker) -> None: + now = datetime(2024, 1, 1, tzinfo=timezone.utc) + tracker.record_usage("anthropic/claude-3", 800, 1_700, timestamp=now) + + assert tracker.is_rate_limited("anthropic/claude-3", timestamp=now) + + +def test_recovers_after_window(tracker: InMemoryUsageTracker) -> None: + now = datetime(2024, 1, 1, tzinfo=timezone.utc) + tracker.record_usage("anthropic/claude-3", 800, 1_700, timestamp=now) + + later = now + timedelta(minutes=2) + allowance = tracker.get_available_tokens( + "anthropic/claude-3", timestamp=later + ) + assert allowance is not None + assert allowance.input_tokens == 1_000 + assert allowance.output_tokens == 2_000 + assert allowance.total_tokens == 2_500 + assert not tracker.is_rate_limited("anthropic/claude-3", timestamp=later) + + +def test_usage_breakdown_and_provider_totals(tracker: InMemoryUsageTracker) -> None: + now = datetime(2024, 1, 1, tzinfo=timezone.utc) + tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now) + tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now) + + breakdown = tracker.get_usage_breakdown() + assert "anthropic" in breakdown + assert "claude-3" in breakdown["anthropic"] + claude_usage = breakdown["anthropic"]["claude-3"] + assert claude_usage.window_input_tokens == 100 + assert claude_usage.window_output_tokens == 200 + + provider_totals = dict(tracker.iter_provider_totals()) + anthropic_totals = 
provider_totals["anthropic"] + assert anthropic_totals.window_input_tokens == 150 + assert anthropic_totals.window_output_tokens == 275 + + +def test_get_usage_breakdown_filters(tracker: InMemoryUsageTracker) -> None: + now = datetime(2024, 1, 1, tzinfo=timezone.utc) + tracker.record_usage("anthropic/claude-3", 10, 20, timestamp=now) + tracker.record_usage("openai/gpt-4o", 5, 5, timestamp=now) + + filtered = tracker.get_usage_breakdown(provider="anthropic") + assert set(filtered.keys()) == {"anthropic"} + assert set(filtered["anthropic"].keys()) == {"claude-3"} + + filtered_model = tracker.get_usage_breakdown(model="gpt-4o") + assert set(filtered_model.keys()) == {"openai"} + assert set(filtered_model["openai"].keys()) == {"gpt-4o"} + + +def test_missing_configuration_records_lifetime_only() -> None: + tracker = InMemoryUsageTracker(configs={}) + tracker.record_usage("openai/gpt-4o", 10, 20) + + assert tracker.get_available_tokens("openai/gpt-4o") is None + + breakdown = tracker.get_usage_breakdown() + usage = breakdown["openai"]["gpt-4o"] + assert usage.window_input_tokens == 0 + assert usage.lifetime_input_tokens == 10 + + +def test_default_configuration_is_used() -> None: + default = RateLimitConfig(window=timedelta(minutes=1), max_total_tokens=100) + tracker = InMemoryUsageTracker(configs={}, default_config=default) + + tracker.record_usage("anthropic/claude-3", 10, 10) + allowance = tracker.get_available_tokens("anthropic/claude-3") + assert allowance is not None + assert allowance.total_tokens == 80 + + +def test_record_usage_rejects_negative_values(tracker: InMemoryUsageTracker) -> None: + with pytest.raises(ValueError): + tracker.record_usage("anthropic/claude-3", -1, 0) + + +def test_is_rate_limited_when_only_output_exceeds_limit() -> None: + config = RateLimitConfig(window=timedelta(minutes=1), max_output_tokens=50) + tracker = InMemoryUsageTracker({"openai/gpt-4o": config}) + + tracker.record_usage("openai/gpt-4o", 0, 50) + assert tracker.is_rate_limited("openai/gpt-4o") + + +def test_usage_tracker_base_not_instantiable() -> None: + class DummyTracker(UsageTracker): + pass + + with pytest.raises(NotImplementedError): + DummyTracker({}).record_usage("provider/model", 1, 1)