Mirror of https://github.com/mruwnik/memory.git (synced 2025-11-13 00:04:05 +01:00)
Base usage tracker
commit 07852f9ee7 (parent bcb470db9b)
AGENTS.md (new file, +5 lines)
@@ -0,0 +1,5 @@
# Agent Guidance

- Assume Python 3.10+ features are available; avoid `from __future__ import annotations` unless necessary.
- Treat LLM model identifiers as `<provider>/<model_name>` strings throughout the codebase (see the sketch below).
- Prefer straightforward control flow (`if`/`else`) over nested ternaries when it improves clarity.
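
A minimal sketch of the `<provider>/<model_name>` convention (illustrative values only; the `split_model_key` helper that enforces it appears in the tracker module below):

```python
model = "anthropic/claude-3"  # "<provider>/<model_name>"
provider, model_name = model.split("/", maxsplit=1)
assert (provider, model_name) == ("anthropic", "claude-3")
```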
src/memory/common/llms/__init__.py (path inferred from the hunk contents)
@@ -24,6 +24,13 @@ from memory.common.llms.base import (
)
from memory.common.llms.anthropic_provider import AnthropicProvider
from memory.common.llms.openai_provider import OpenAIProvider
from memory.common.llms.usage_tracker import (
    InMemoryUsageTracker,
    RateLimitConfig,
    TokenAllowance,
    UsageBreakdown,
    UsageTracker,
)
from memory.common import tokens

__all__ = [

@@ -42,6 +49,11 @@ __all__ = [
    "StreamEvent",
    "LLMSettings",
    "create_provider",
    "InMemoryUsageTracker",
    "RateLimitConfig",
    "TokenAllowance",
    "UsageBreakdown",
    "UsageTracker",
]

logger = logging.getLogger(__name__)
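
With those re-exports in place, downstream code can pull the tracker types from the package root; a minimal sketch (assuming the hunk above edits `memory/common/llms/__init__.py`, as inferred):

```python
from memory.common.llms import InMemoryUsageTracker, RateLimitConfig, UsageTracker
```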
src/memory/common/llms/usage_tracker.py (new file, +316 lines)
@@ -0,0 +1,316 @@
"""LLM usage tracking utilities."""

from collections import defaultdict, deque
from collections.abc import Iterable
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from threading import Lock


@dataclass(frozen=True)
class RateLimitConfig:
    """Configuration for a single rolling usage window."""

    window: timedelta
    max_input_tokens: int | None = None
    max_output_tokens: int | None = None
    max_total_tokens: int | None = None

    def __post_init__(self) -> None:
        if self.window <= timedelta(0):
            raise ValueError("window must be positive")
        if (
            self.max_input_tokens is None
            and self.max_output_tokens is None
            and self.max_total_tokens is None
        ):
            raise ValueError(
                "At least one of max_input_tokens, max_output_tokens or "
                "max_total_tokens must be provided"
            )


@dataclass
class UsageEvent:
    timestamp: datetime
    input_tokens: int
    output_tokens: int


@dataclass
class UsageState:
    events: deque[UsageEvent] = field(default_factory=deque)
    window_input_tokens: int = 0
    window_output_tokens: int = 0
    lifetime_input_tokens: int = 0
    lifetime_output_tokens: int = 0


@dataclass
class TokenAllowance:
    """Represents the tokens that can be consumed right now."""

    input_tokens: int | None
    output_tokens: int | None
    total_tokens: int | None


@dataclass
class UsageBreakdown:
    """Detailed usage statistics for a provider/model pair."""

    window_input_tokens: int
    window_output_tokens: int
    window_total_tokens: int
    lifetime_input_tokens: int
    lifetime_output_tokens: int

    @property
    def window_total(self) -> int:
        return self.window_total_tokens

    @property
    def lifetime_total_tokens(self) -> int:
        return self.lifetime_input_tokens + self.lifetime_output_tokens


def split_model_key(model: str) -> tuple[str, str]:
    if "/" not in model:
        raise ValueError(
            "model must be formatted as '<provider>/<model_name>'"
        )

    provider, model_name = model.split("/", maxsplit=1)
    if not provider or not model_name:
        raise ValueError(
            "model must include both provider and model name separated by '/'"
        )
    return provider, model_name
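# e.g. split_model_key("anthropic/claude-3") == ("anthropic", "claude-3"),
# while "claude-3" (no slash) or "anthropic/" (empty name) raises ValueError.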


class UsageTracker:
    """Base class for usage trackers that operate on provider/model keys."""

    def __init__(
        self,
        configs: dict[str, RateLimitConfig],
        default_config: RateLimitConfig | None = None,
    ) -> None:
        self._configs = configs
        self._default_config = default_config
        self._lock = Lock()

    # ------------------------------------------------------------------
    # Storage hooks
    # ------------------------------------------------------------------
    def get_state(self, key: str) -> UsageState:
        raise NotImplementedError

    def iter_state_items(self) -> Iterable[tuple[str, UsageState]]:
        raise NotImplementedError

    def save_state(self, key: str, state: UsageState) -> None:
        """Persist the given state back to the underlying store."""
        del key, state

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def record_usage(
        self,
        model: str,
        input_tokens: int,
        output_tokens: int,
        timestamp: datetime | None = None,
    ) -> None:
        """Record token usage for the given provider/model pair."""

        if input_tokens < 0 or output_tokens < 0:
            raise ValueError("Token counts must be non-negative")

        timestamp = timestamp or datetime.now(timezone.utc)
        split_model_key(model)
        key = model

        with self._lock:
            config = self._get_config(key)
            state = self.get_state(key)

            state.lifetime_input_tokens += input_tokens
            state.lifetime_output_tokens += output_tokens

            if config is None:
                self.save_state(key, state)
                return

            state.events.append(UsageEvent(timestamp, input_tokens, output_tokens))
            state.window_input_tokens += input_tokens
            state.window_output_tokens += output_tokens

            self._prune_expired_events(state, config, now=timestamp)
            self.save_state(key, state)

    def is_rate_limited(
        self,
        model: str,
        timestamp: datetime | None = None,
    ) -> bool:
        """Return True if the pair currently exceeds its limits."""

        allowance = self.get_available_tokens(model, timestamp=timestamp)
        if allowance is None:
            return False

        limits = [
            allowance.input_tokens,
            allowance.output_tokens,
            allowance.total_tokens,
        ]
        return any(limit is not None and limit <= 0 for limit in limits)

    def get_available_tokens(
        self,
        model: str,
        timestamp: datetime | None = None,
    ) -> TokenAllowance | None:
        """Return the current token allowance for the provider/model pair.

        If there is no configuration for the pair (or a default configuration),
        ``None`` is returned to indicate that no limits are enforced.
        """

        split_model_key(model)
        key = model
        with self._lock:
            config = self._get_config(key)
            if config is None:
                return None

            state = self.get_state(key)
            self._prune_expired_events(state, config, now=timestamp)
            self.save_state(key, state)

            if config.max_total_tokens is None:
                total_remaining = None
            else:
                total_remaining = config.max_total_tokens - (
                    state.window_input_tokens + state.window_output_tokens
                )

            if config.max_input_tokens is None:
                input_remaining = None
            else:
                input_remaining = config.max_input_tokens - state.window_input_tokens

            if config.max_output_tokens is None:
                output_remaining = None
            else:
                output_remaining = (
                    config.max_output_tokens - state.window_output_tokens
                )

            return TokenAllowance(
                input_tokens=clamp_non_negative(input_remaining),
                output_tokens=clamp_non_negative(output_remaining),
                total_tokens=clamp_non_negative(total_remaining),
            )

    def get_usage_breakdown(
        self, provider: str | None = None, model: str | None = None
    ) -> dict[str, dict[str, UsageBreakdown]]:
        """Return usage statistics grouped by provider and model."""

        with self._lock:
            providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
            for key, state in self.iter_state_items():
                prov, model_name = split_model_key(key)
                if provider and provider != prov:
                    continue
                if model and model != model_name:
                    continue

                window_total = state.window_input_tokens + state.window_output_tokens
                breakdown = UsageBreakdown(
                    window_input_tokens=state.window_input_tokens,
                    window_output_tokens=state.window_output_tokens,
                    window_total_tokens=window_total,
                    lifetime_input_tokens=state.lifetime_input_tokens,
                    lifetime_output_tokens=state.lifetime_output_tokens,
                )
                providers[prov][model_name] = breakdown

            return providers

    def iter_provider_totals(self) -> Iterable[tuple[str, UsageBreakdown]]:
        """Yield aggregated totals for each provider across its models."""

        breakdowns = self.get_usage_breakdown()
        for provider, models in breakdowns.items():
            window_input = sum(b.window_input_tokens for b in models.values())
            window_output = sum(b.window_output_tokens for b in models.values())
            lifetime_input = sum(b.lifetime_input_tokens for b in models.values())
            lifetime_output = sum(b.lifetime_output_tokens for b in models.values())

            yield (
                provider,
                UsageBreakdown(
                    window_input_tokens=window_input,
                    window_output_tokens=window_output,
                    window_total_tokens=window_input + window_output,
                    lifetime_input_tokens=lifetime_input,
                    lifetime_output_tokens=lifetime_output,
                ),
            )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _get_config(self, key: str) -> RateLimitConfig | None:
        return self._configs.get(key) or self._default_config

    def _prune_expired_events(
        self,
        state: UsageState,
        config: RateLimitConfig,
        now: datetime | None = None,
    ) -> None:
        if not state.events:
            return

        now = now or datetime.now(timezone.utc)
        cutoff = now - config.window

        for event in tuple(state.events):
            if event.timestamp > cutoff:
                break
            state.events.popleft()
            state.window_input_tokens -= event.input_tokens
            state.window_output_tokens -= event.output_tokens

        state.window_input_tokens = max(state.window_input_tokens, 0)
        state.window_output_tokens = max(state.window_output_tokens, 0)


class InMemoryUsageTracker(UsageTracker):
    """Tracks LLM usage for providers and models within a rolling window."""

    def __init__(
        self,
        configs: dict[str, RateLimitConfig],
        default_config: RateLimitConfig | None = None,
    ) -> None:
        super().__init__(configs=configs, default_config=default_config)
        self._states: dict[str, UsageState] = {}

    def get_state(self, key: str) -> UsageState:
        return self._states.setdefault(key, UsageState())

    def iter_state_items(self) -> Iterable[tuple[str, UsageState]]:
        return tuple(self._states.items())


def clamp_non_negative(value: int | None) -> int | None:
    if value is None:
        return None
    return 0 if value < 0 else value
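
Taken together, the tracker is used roughly like this (a sketch distilled from the test fixture below; the limits and token counts are the fixture's values):

```python
from datetime import timedelta

config = RateLimitConfig(
    window=timedelta(minutes=1),
    max_input_tokens=1_000,
    max_output_tokens=2_000,
    max_total_tokens=2_500,
)
tracker = InMemoryUsageTracker({"anthropic/claude-3": config})

tracker.record_usage("anthropic/claude-3", input_tokens=100, output_tokens=200)
allowance = tracker.get_available_tokens("anthropic/claude-3")
# allowance.input_tokens == 900, .output_tokens == 1_800, .total_tokens == 2_200
assert not tracker.is_rate_limited("anthropic/claude-3")
```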
tests/memory/common/llms/test_usage_tracker.py (new file, +147 lines)
@@ -0,0 +1,147 @@
from datetime import datetime, timedelta, timezone

import pytest

from memory.common.llms.usage_tracker import (
    InMemoryUsageTracker,
    RateLimitConfig,
    UsageTracker,
)


@pytest.fixture
def tracker() -> InMemoryUsageTracker:
    config = RateLimitConfig(
        window=timedelta(minutes=1),
        max_input_tokens=1_000,
        max_output_tokens=2_000,
        max_total_tokens=2_500,
    )
    return InMemoryUsageTracker(
        {
            "anthropic/claude-3": config,
            "anthropic/haiku": config,
        }
    )


@pytest.mark.parametrize(
    "window, kwargs",
    [
        (timedelta(minutes=1), {}),
        (timedelta(seconds=0), {"max_total_tokens": 1}),
    ],
)
def test_rate_limit_config_validation(window: timedelta, kwargs: dict[str, int]) -> None:
    with pytest.raises(ValueError):
        RateLimitConfig(window=window, **kwargs)


def test_allows_usage_within_limits(tracker: InMemoryUsageTracker) -> None:
    now = datetime(2024, 1, 1, tzinfo=timezone.utc)
    tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)

    allowance = tracker.get_available_tokens(
        "anthropic/claude-3", timestamp=now
    )
    assert allowance is not None
    assert allowance.input_tokens == 900
    assert allowance.output_tokens == 1_800
    assert allowance.total_tokens == 2_200


def test_rate_limited_when_over_budget(tracker: InMemoryUsageTracker) -> None:
    now = datetime(2024, 1, 1, tzinfo=timezone.utc)
    tracker.record_usage("anthropic/claude-3", 800, 1_700, timestamp=now)

    assert tracker.is_rate_limited("anthropic/claude-3", timestamp=now)


def test_recovers_after_window(tracker: InMemoryUsageTracker) -> None:
    now = datetime(2024, 1, 1, tzinfo=timezone.utc)
    tracker.record_usage("anthropic/claude-3", 800, 1_700, timestamp=now)

    later = now + timedelta(minutes=2)
    allowance = tracker.get_available_tokens(
        "anthropic/claude-3", timestamp=later
    )
    assert allowance is not None
    assert allowance.input_tokens == 1_000
    assert allowance.output_tokens == 2_000
    assert allowance.total_tokens == 2_500
    assert not tracker.is_rate_limited("anthropic/claude-3", timestamp=later)


def test_usage_breakdown_and_provider_totals(tracker: InMemoryUsageTracker) -> None:
    now = datetime(2024, 1, 1, tzinfo=timezone.utc)
    tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
    tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)

    breakdown = tracker.get_usage_breakdown()
    assert "anthropic" in breakdown
    assert "claude-3" in breakdown["anthropic"]
    claude_usage = breakdown["anthropic"]["claude-3"]
    assert claude_usage.window_input_tokens == 100
    assert claude_usage.window_output_tokens == 200

    provider_totals = dict(tracker.iter_provider_totals())
    anthropic_totals = provider_totals["anthropic"]
    assert anthropic_totals.window_input_tokens == 150
    assert anthropic_totals.window_output_tokens == 275


def test_get_usage_breakdown_filters(tracker: InMemoryUsageTracker) -> None:
    now = datetime(2024, 1, 1, tzinfo=timezone.utc)
    tracker.record_usage("anthropic/claude-3", 10, 20, timestamp=now)
    tracker.record_usage("openai/gpt-4o", 5, 5, timestamp=now)

    filtered = tracker.get_usage_breakdown(provider="anthropic")
    assert set(filtered.keys()) == {"anthropic"}
    assert set(filtered["anthropic"].keys()) == {"claude-3"}

    filtered_model = tracker.get_usage_breakdown(model="gpt-4o")
    assert set(filtered_model.keys()) == {"openai"}
    assert set(filtered_model["openai"].keys()) == {"gpt-4o"}


def test_missing_configuration_records_lifetime_only() -> None:
    tracker = InMemoryUsageTracker(configs={})
    tracker.record_usage("openai/gpt-4o", 10, 20)

    assert tracker.get_available_tokens("openai/gpt-4o") is None

    breakdown = tracker.get_usage_breakdown()
    usage = breakdown["openai"]["gpt-4o"]
    assert usage.window_input_tokens == 0
    assert usage.lifetime_input_tokens == 10


def test_default_configuration_is_used() -> None:
    default = RateLimitConfig(window=timedelta(minutes=1), max_total_tokens=100)
    tracker = InMemoryUsageTracker(configs={}, default_config=default)

    tracker.record_usage("anthropic/claude-3", 10, 10)
    allowance = tracker.get_available_tokens("anthropic/claude-3")
    assert allowance is not None
    assert allowance.total_tokens == 80


def test_record_usage_rejects_negative_values(tracker: InMemoryUsageTracker) -> None:
    with pytest.raises(ValueError):
        tracker.record_usage("anthropic/claude-3", -1, 0)


def test_is_rate_limited_when_only_output_exceeds_limit() -> None:
    config = RateLimitConfig(window=timedelta(minutes=1), max_output_tokens=50)
    tracker = InMemoryUsageTracker({"openai/gpt-4o": config})

    tracker.record_usage("openai/gpt-4o", 0, 50)
    assert tracker.is_rate_limited("openai/gpt-4o")


def test_usage_tracker_base_not_instantiable() -> None:
    class DummyTracker(UsageTracker):
        pass

    with pytest.raises(NotImplementedError):
        DummyTracker({}).record_usage("provider/model", 1, 1)
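
The storage hooks (`get_state`, `iter_state_items`, `save_state`) are the extension surface the base class leaves open; a hypothetical persistent subclass might look like this (the dict-backed store is an assumption standing in for a real database):

```python
class DictBackedUsageTracker(UsageTracker):
    """Sketch: keep UsageState objects in an external mapping."""

    def __init__(self, configs: dict[str, RateLimitConfig], store: dict[str, UsageState]):
        super().__init__(configs=configs)
        self._store = store  # stand-in for a real persistence layer

    def get_state(self, key: str) -> UsageState:
        return self._store.setdefault(key, UsageState())

    def iter_state_items(self) -> Iterable[tuple[str, UsageState]]:
        return tuple(self._store.items())

    def save_state(self, key: str, state: UsageState) -> None:
        self._store[key] = state  # a real store would serialize here
```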