Use usage tracker

This commit is contained in:
Daniel O'Connell 2025-11-01 18:49:06 +01:00
parent 8af07f0dac
commit 9639fa3dd7
8 changed files with 223 additions and 46 deletions

View File

@ -24,13 +24,6 @@ from memory.common.llms.base import (
) )
from memory.common.llms.anthropic_provider import AnthropicProvider from memory.common.llms.anthropic_provider import AnthropicProvider
from memory.common.llms.openai_provider import OpenAIProvider from memory.common.llms.openai_provider import OpenAIProvider
from memory.common.llms.usage_tracker import (
InMemoryUsageTracker,
RateLimitConfig,
TokenAllowance,
UsageBreakdown,
UsageTracker,
)
from memory.common import tokens from memory.common import tokens
__all__ = [ __all__ = [
@ -49,11 +42,6 @@ __all__ = [
"StreamEvent", "StreamEvent",
"LLMSettings", "LLMSettings",
"create_provider", "create_provider",
"InMemoryUsageTracker",
"RateLimitConfig",
"TokenAllowance",
"UsageBreakdown",
"UsageTracker",
] ]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -93,28 +81,3 @@ def truncate(content: str, target_tokens: int) -> str:
if len(content) > target_chars: if len(content) > target_chars:
return content[:target_chars].rsplit(" ", 1)[0] + "..." return content[:target_chars].rsplit(" ", 1)[0] + "..."
return content return content
# bla = 1
# from memory.common.llms import *
# from memory.common.llms.tools.discord import make_discord_tools
# from memory.common.db.connection import make_session
# from memory.common.db.models import *
# model = "anthropic/claude-sonnet-4-5"
# provider = create_provider(model=model)
# with make_session() as session:
# bot = session.query(DiscordBotUser).first()
# server = session.query(DiscordServer).first()
# channel = server.channels[0]
# tools = make_discord_tools(bot, None, channel, model)
# def demo(msg: str):
# messages = [
# Message(
# role=MessageRole.USER,
# content=msg,
# )
# ]
# for m in provider.stream_with_tools(messages, tools):
# print(m)

View File

@ -12,6 +12,7 @@ from PIL import Image
from memory.common import settings from memory.common import settings
from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
from memory.common.llms.usage import UsageTracker, RedisUsageTracker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -204,7 +205,9 @@ class LLMSettings:
class BaseLLMProvider(ABC): class BaseLLMProvider(ABC):
"""Base class for LLM providers.""" """Base class for LLM providers."""
def __init__(self, api_key: str, model: str): def __init__(
self, api_key: str, model: str, usage_tracker: UsageTracker | None = None
):
""" """
Initialize the LLM provider. Initialize the LLM provider.
@ -215,6 +218,7 @@ class BaseLLMProvider(ABC):
self.api_key = api_key self.api_key = api_key
self.model = model self.model = model
self._client: Any = None self._client: Any = None
self.usage_tracker: UsageTracker = usage_tracker or RedisUsageTracker()
@abstractmethod @abstractmethod
def _initialize_client(self) -> Any: def _initialize_client(self) -> Any:

View File

@ -0,0 +1,17 @@
"""Public interface of the LLM usage-tracking package.

Re-exports the rate-limit configuration and the in-memory and Redis-backed
tracker implementations so callers can import from ``memory.common.llms.usage``
directly.
"""

from memory.common.llms.usage.redis_usage_tracker import RedisUsageTracker
from memory.common.llms.usage.usage_tracker import (
    InMemoryUsageTracker,
    RateLimitConfig,
    TokenAllowance,
    UsageBreakdown,
    UsageTracker,
)

__all__ = [
    "InMemoryUsageTracker",
    "RateLimitConfig",
    "RedisUsageTracker",
    "TokenAllowance",
    "UsageBreakdown",
    "UsageTracker",
]

View File

@ -0,0 +1,88 @@
"""Redis-backed usage tracker implementation."""
import json
from typing import Any, Iterable, Protocol
import redis
from memory.common import settings
from memory.common.llms.usage.usage_tracker import (
RateLimitConfig,
UsageState,
UsageTracker,
)
class RedisClientProtocol(Protocol):
    """Structural type for the minimal Redis surface the tracker needs.

    Any object exposing ``get``/``set``/``scan_iter`` with these signatures
    (a real ``redis.Redis`` or an in-memory fake) satisfies this protocol.
    """

    def get(self, key: str) -> Any: ...  # pragma: no cover - Protocol definition

    def set(self, key: str, value: Any) -> Any: ...  # pragma: no cover - Protocol definition

    def scan_iter(self, match: str) -> Iterable[Any]: ...  # pragma: no cover - Protocol definition
class RedisUsageTracker(UsageTracker):
    """Tracks LLM usage for providers and models using Redis for persistence.

    State is stored as one JSON payload per model under
    ``<key_prefix>:<model>``, so every process sharing the same Redis
    database sees a consistent view of usage.
    """

    def __init__(
        self,
        configs: dict[str, RateLimitConfig] | None = None,
        default_config: RateLimitConfig | None = None,
        *,
        redis_client: RedisClientProtocol | None = None,
        key_prefix: str | None = None,
    ) -> None:
        """
        Args:
            configs: Per-model rate-limit configs. Optional so the tracker
                can be constructed with no arguments (as BaseLLMProvider does
                with ``RedisUsageTracker()``) and rely on ``default_config``.
            default_config: Fallback config for models without an entry.
            redis_client: Injectable client (e.g. a fake in tests); a real
                connection built from settings is created when omitted.
            key_prefix: Redis key namespace; defaults to
                ``settings.LLM_USAGE_REDIS_PREFIX``.
        """
        # NOTE: configs used to be a required positional, which made the
        # zero-argument construction in BaseLLMProvider raise TypeError.
        super().__init__(configs=configs or {}, default_config=default_config)
        if redis_client is None:
            redis_client = redis.Redis(
                host=settings.REDIS_HOST,
                port=int(settings.REDIS_PORT),
                db=int(settings.REDIS_DB),
                decode_responses=False,
            )
        self._redis = redis_client
        prefix = key_prefix or settings.LLM_USAGE_REDIS_PREFIX
        # Normalize so _format_key always yields exactly one ":" separator.
        self._key_prefix = prefix.rstrip(":")

    def get_state(self, model: str) -> UsageState:
        """Load the persisted state for *model*; a fresh state if absent."""
        redis_key = self._format_key(model)
        payload = self._redis.get(redis_key)
        if not payload:
            return UsageState()
        # decode_responses=False means real clients return bytes.
        if isinstance(payload, bytes):
            payload = payload.decode()
        return UsageState.from_payload(json.loads(payload))

    def iter_state_items(self) -> Iterable[tuple[str, UsageState]]:
        """Yield ``(model, state)`` pairs for every model under the prefix."""
        pattern = f"{self._key_prefix}:*"
        for redis_key in self._redis.scan_iter(match=pattern):
            key = self._ensure_str(redis_key)
            payload = self._redis.get(key)
            if not payload:
                # Key vanished between scan and get; skip it.
                continue
            if isinstance(payload, bytes):
                payload = payload.decode()
            state = UsageState.from_payload(json.loads(payload))
            # Strip "<prefix>:" to recover the model key.
            yield key[len(self._key_prefix) + 1 :], state

    def save_state(self, model: str, state: UsageState) -> None:
        """Persist *state* for *model* as compact JSON."""
        redis_key = self._format_key(model)
        self._redis.set(
            redis_key, json.dumps(state.to_payload(), separators=(",", ":"))
        )

    def _format_key(self, model: str) -> str:
        return f"{self._key_prefix}:{model}"

    @staticmethod
    def _ensure_str(value: Any) -> str:
        # scan_iter yields bytes from real clients (decode_responses=False).
        if isinstance(value, bytes):
            return value.decode()
        return str(value)

View File

@ -5,6 +5,7 @@ from collections.abc import Iterable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from threading import Lock from threading import Lock
from typing import Any
@dataclass(frozen=True) @dataclass(frozen=True)
@ -45,6 +46,40 @@ class UsageState:
lifetime_input_tokens: int = 0 lifetime_input_tokens: int = 0
lifetime_output_tokens: int = 0 lifetime_output_tokens: int = 0
def to_payload(self) -> dict[str, Any]:
    """Serialize this state to a JSON-compatible dict (inverse of from_payload)."""
    event_dicts = [
        {
            "timestamp": evt.timestamp.isoformat(),
            "input_tokens": evt.input_tokens,
            "output_tokens": evt.output_tokens,
        }
        for evt in self.events
    ]
    return {
        "events": event_dicts,
        "window_input_tokens": self.window_input_tokens,
        "window_output_tokens": self.window_output_tokens,
        "lifetime_input_tokens": self.lifetime_input_tokens,
        "lifetime_output_tokens": self.lifetime_output_tokens,
    }
@classmethod
def from_payload(cls, payload: dict[str, Any]) -> "UsageState":
    """Rebuild a UsageState from a dict produced by ``to_payload``.

    Missing keys fall back to empty/zero so older payloads still load.
    """
    raw_events = payload.get("events", [])
    events = deque(
        UsageEvent(
            timestamp=datetime.fromisoformat(item["timestamp"]),
            input_tokens=item["input_tokens"],
            output_tokens=item["output_tokens"],
        )
        for item in raw_events
    )
    read = payload.get
    return cls(
        events=events,
        window_input_tokens=read("window_input_tokens", 0),
        window_output_tokens=read("window_output_tokens", 0),
        lifetime_input_tokens=read("lifetime_input_tokens", 0),
        lifetime_output_tokens=read("lifetime_output_tokens", 0),
    )
@dataclass @dataclass
class TokenAllowance: class TokenAllowance:
@ -76,9 +111,7 @@ class UsageBreakdown:
def split_model_key(model: str) -> tuple[str, str]: def split_model_key(model: str) -> tuple[str, str]:
if "/" not in model: if "/" not in model:
raise ValueError( raise ValueError("model must be formatted as '<provider>/<model_name>'")
"model must be formatted as '<provider>/<model_name>'"
)
provider, model_name = model.split("/", maxsplit=1) provider, model_name = model.split("/", maxsplit=1)
if not provider or not model_name: if not provider or not model_name:
@ -205,9 +238,7 @@ class UsageTracker:
if config.max_output_tokens is None: if config.max_output_tokens is None:
output_remaining = None output_remaining = None
else: else:
output_remaining = ( output_remaining = config.max_output_tokens - state.window_output_tokens
config.max_output_tokens - state.window_output_tokens
)
return TokenAllowance( return TokenAllowance(
input_tokens=clamp_non_negative(input_remaining), input_tokens=clamp_non_negative(input_remaining),
@ -313,4 +344,3 @@ def clamp_non_negative(value: int | None) -> int | None:
if value is None: if value is None:
return None return None
return 0 if value < 0 else value return 0 if value < 0 else value

View File

@ -38,6 +38,7 @@ CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower()
REDIS_HOST = os.getenv("REDIS_HOST", "redis") REDIS_HOST = os.getenv("REDIS_HOST", "redis")
REDIS_PORT = os.getenv("REDIS_PORT", "6379") REDIS_PORT = os.getenv("REDIS_PORT", "6379")
REDIS_DB = os.getenv("REDIS_DB", "0") REDIS_DB = os.getenv("REDIS_DB", "0")
LLM_USAGE_REDIS_PREFIX = os.getenv("LLM_USAGE_REDIS_PREFIX", "llm_usage")
CELERY_BROKER_USER = os.getenv( CELERY_BROKER_USER = os.getenv(
"CELERY_BROKER_USER", "kb" if CELERY_BROKER_TYPE == "amqp" else "" "CELERY_BROKER_USER", "kb" if CELERY_BROKER_TYPE == "amqp" else ""
) )

View File

@ -58,6 +58,13 @@ def call_llm(
msgs: list[str] = [], msgs: list[str] = [],
allowed_tools: list[str] = [], allowed_tools: list[str] = [],
) -> str | None: ) -> str | None:
provider = create_provider(model=model)
if provider.usage_tracker.is_rate_limited(model):
logger.error(
f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
)
return None
tools = make_discord_tools( tools = make_discord_tools(
message.recipient_user.system_user, message.recipient_user.system_user,
message.from_user, message.from_user,
@ -73,7 +80,6 @@ def call_llm(
system_prompt += comm_channel_prompt( system_prompt += comm_channel_prompt(
session, message.recipient_user, message.channel session, message.recipient_user, message.channel
) )
provider = create_provider(model=model)
messages = previous_messages( messages = previous_messages(
session, session,
message.recipient_user and message.recipient_user.id, message.recipient_user and message.recipient_user.id,

View File

@ -1,7 +1,24 @@
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Iterable
import pytest import pytest
try:
    import redis  # noqa: F401 # pragma: no cover - optional test dependency
except ModuleNotFoundError:  # pragma: no cover - import guard for test envs
    import sys
    from types import SimpleNamespace

    # Install a stand-in "redis" module so importing RedisUsageTracker works
    # in environments without the real package. Instantiating the stub's
    # Redis class still fails loudly, so only code paths that inject a fake
    # client (like the tests below) can actually run.
    class _RedisStub(SimpleNamespace):
        class Redis:  # type: ignore[no-redef]
            def __init__(self, *args: object, **kwargs: object) -> None:
                raise ModuleNotFoundError(
                    "The 'redis' package is required to use RedisUsageTracker"
                )

    # setdefault: never clobber a real redis module if one was registered.
    sys.modules.setdefault("redis", _RedisStub())
from memory.common.llms.redis_usage_tracker import RedisUsageTracker
from memory.common.llms.usage_tracker import ( from memory.common.llms.usage_tracker import (
InMemoryUsageTracker, InMemoryUsageTracker,
RateLimitConfig, RateLimitConfig,
@ -9,6 +26,24 @@ from memory.common.llms.usage_tracker import (
) )
class FakeRedis:
def __init__(self) -> None:
self._store: dict[str, str] = {}
def get(self, key: str) -> str | None:
return self._store.get(key)
def set(self, key: str, value: str) -> None:
self._store[key] = value
def scan_iter(self, match: str) -> Iterable[str]:
from fnmatch import fnmatch
for key in list(self._store.keys()):
if fnmatch(key, match):
yield key
@pytest.fixture @pytest.fixture
def tracker() -> InMemoryUsageTracker: def tracker() -> InMemoryUsageTracker:
config = RateLimitConfig( config = RateLimitConfig(
@ -25,6 +60,23 @@ def tracker() -> InMemoryUsageTracker:
) )
@pytest.fixture
def redis_tracker() -> RedisUsageTracker:
    """Tracker wired to an in-memory FakeRedis with two configured models."""
    shared_config = RateLimitConfig(
        window=timedelta(minutes=1),
        max_input_tokens=1_000,
        max_output_tokens=2_000,
        max_total_tokens=2_500,
    )
    configs = {
        model: shared_config
        for model in ("anthropic/claude-3", "anthropic/haiku")
    }
    return RedisUsageTracker(configs, redis_client=FakeRedis())
@pytest.mark.parametrize( @pytest.mark.parametrize(
"window, kwargs", "window, kwargs",
[ [
@ -139,6 +191,22 @@ def test_is_rate_limited_when_only_output_exceeds_limit() -> None:
assert tracker.is_rate_limited("openai/gpt-4o") assert tracker.is_rate_limited("openai/gpt-4o")
def test_redis_usage_tracker_persists_state(redis_tracker: RedisUsageTracker) -> None:
    """Usage recorded through the tracker round-trips via the redis backend."""
    now = datetime(2024, 1, 1, tzinfo=timezone.utc)

    redis_tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
    redis_tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)

    allowance = redis_tracker.get_available_tokens("anthropic/claude-3", timestamp=now)
    assert allowance is not None
    assert allowance.input_tokens == 900

    breakdown = redis_tracker.get_usage_breakdown()
    assert breakdown["anthropic"]["claude-3"].window_output_tokens == 200

    tracked_models = {model for model, _ in redis_tracker.iter_state_items()}
    assert tracked_models == {"anthropic/claude-3", "anthropic/haiku"}
def test_usage_tracker_base_not_instantiable() -> None: def test_usage_tracker_base_not_instantiable() -> None:
class DummyTracker(UsageTracker): class DummyTracker(UsageTracker):
pass pass