mirror of https://github.com/mruwnik/memory.git (synced 2025-11-13 00:04:05 +01:00)

commit 9639fa3dd7 (parent 8af07f0dac)

    use usage tracker

@@ -24,13 +24,6 @@ from memory.common.llms.base import (
 )
 from memory.common.llms.anthropic_provider import AnthropicProvider
 from memory.common.llms.openai_provider import OpenAIProvider
-from memory.common.llms.usage_tracker import (
-    InMemoryUsageTracker,
-    RateLimitConfig,
-    TokenAllowance,
-    UsageBreakdown,
-    UsageTracker,
-)
 from memory.common import tokens

 __all__ = [
@@ -49,11 +42,6 @@ __all__ = [
     "StreamEvent",
     "LLMSettings",
     "create_provider",
-    "InMemoryUsageTracker",
-    "RateLimitConfig",
-    "TokenAllowance",
-    "UsageBreakdown",
-    "UsageTracker",
 ]

 logger = logging.getLogger(__name__)
@@ -93,28 +81,3 @@ def truncate(content: str, target_tokens: int) -> str:
     if len(content) > target_chars:
         return content[:target_chars].rsplit(" ", 1)[0] + "..."
     return content
-
-
-# bla = 1
-# from memory.common.llms import *
-# from memory.common.llms.tools.discord import make_discord_tools
-# from memory.common.db.connection import make_session
-# from memory.common.db.models import *
-
-# model = "anthropic/claude-sonnet-4-5"
-# provider = create_provider(model=model)
-# with make_session() as session:
-#     bot = session.query(DiscordBotUser).first()
-#     server = session.query(DiscordServer).first()
-#     channel = server.channels[0]
-#     tools = make_discord_tools(bot, None, channel, model)
-
-# def demo(msg: str):
-#     messages = [
-#         Message(
-#             role=MessageRole.USER,
-#             content=msg,
-#         )
-#     ]
-#     for m in provider.stream_with_tools(messages, tools):
-#         print(m)

@@ -12,6 +12,7 @@ from PIL import Image

 from memory.common import settings
 from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
+from memory.common.llms.usage import UsageTracker, RedisUsageTracker

 logger = logging.getLogger(__name__)

@@ -204,7 +205,9 @@ class LLMSettings:
 class BaseLLMProvider(ABC):
     """Base class for LLM providers."""

-    def __init__(self, api_key: str, model: str):
+    def __init__(
+        self, api_key: str, model: str, usage_tracker: UsageTracker | None = None
+    ):
         """
         Initialize the LLM provider.

@@ -215,6 +218,7 @@ class BaseLLMProvider(ABC):
         self.api_key = api_key
         self.model = model
         self._client: Any = None
+        self.usage_tracker: UsageTracker = usage_tracker or RedisUsageTracker()

     @abstractmethod
     def _initialize_client(self) -> Any:

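As a reading aid, a minimal sketch of how a caller could pass an explicit tracker under the new constructor. It assumes AnthropicProvider simply forwards the base-class usage_tracker keyword and that InMemoryUsageTracker accepts the same configs mapping as RedisUsageTracker; the key, model name, and limits are illustrative, not taken from the commit.

# Sketch only (not part of the commit): wiring an explicit tracker into a provider.
from datetime import timedelta

from memory.common.llms.anthropic_provider import AnthropicProvider
from memory.common.llms.usage import InMemoryUsageTracker, RateLimitConfig

# Limits are keyed by "<provider>/<model_name>", mirroring the test fixtures below.
limits = RateLimitConfig(
    window=timedelta(minutes=1),
    max_input_tokens=1_000,
    max_output_tokens=2_000,
    max_total_tokens=2_500,
)

# Assumes AnthropicProvider keeps the BaseLLMProvider.__init__ signature shown above;
# when usage_tracker is omitted, the base class falls back to RedisUsageTracker().
provider = AnthropicProvider(
    api_key="sk-placeholder",  # illustrative key
    model="claude-3",          # illustrative model name
    usage_tracker=InMemoryUsageTracker({"anthropic/claude-3": limits}),
)
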
src/memory/common/llms/usage/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+from memory.common.llms.usage.redis_usage_tracker import RedisUsageTracker
+from memory.common.llms.usage.usage_tracker import (
+    InMemoryUsageTracker,
+    RateLimitConfig,
+    TokenAllowance,
+    UsageBreakdown,
+    UsageTracker,
+)
+
+__all__ = [
+    "InMemoryUsageTracker",
+    "RateLimitConfig",
+    "RedisUsageTracker",
+    "TokenAllowance",
+    "UsageBreakdown",
+    "UsageTracker",
+]

src/memory/common/llms/usage/redis_usage_tracker.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+"""Redis-backed usage tracker implementation."""
+
+import json
+from typing import Any, Iterable, Protocol
+
+import redis
+
+from memory.common import settings
+from memory.common.llms.usage.usage_tracker import (
+    RateLimitConfig,
+    UsageState,
+    UsageTracker,
+)
+
+
+class RedisClientProtocol(Protocol):
+    def get(self, key: str) -> Any:  # pragma: no cover - Protocol definition
+        ...
+
+    def set(
+        self, key: str, value: Any
+    ) -> Any:  # pragma: no cover - Protocol definition
+        ...
+
+    def scan_iter(
+        self, match: str
+    ) -> Iterable[Any]:  # pragma: no cover - Protocol definition
+        ...
+
+
+class RedisUsageTracker(UsageTracker):
+    """Tracks LLM usage for providers and models using Redis for persistence."""
+
+    def __init__(
+        self,
+        configs: dict[str, RateLimitConfig],
+        default_config: RateLimitConfig | None = None,
+        *,
+        redis_client: RedisClientProtocol | None = None,
+        key_prefix: str | None = None,
+    ) -> None:
+        super().__init__(configs=configs, default_config=default_config)
+        if redis_client is None:
+            redis_client = redis.Redis(
+                host=settings.REDIS_HOST,
+                port=int(settings.REDIS_PORT),
+                db=int(settings.REDIS_DB),
+                decode_responses=False,
+            )
+        self._redis = redis_client
+        prefix = key_prefix or settings.LLM_USAGE_REDIS_PREFIX
+        self._key_prefix = prefix.rstrip(":")
+
+    def get_state(self, model: str) -> UsageState:
+        redis_key = self._format_key(model)
+        payload = self._redis.get(redis_key)
+        if not payload:
+            return UsageState()
+        if isinstance(payload, bytes):
+            payload = payload.decode()
+        return UsageState.from_payload(json.loads(payload))
+
+    def iter_state_items(self) -> Iterable[tuple[str, UsageState]]:
+        pattern = f"{self._key_prefix}:*"
+        for redis_key in self._redis.scan_iter(match=pattern):
+            key = self._ensure_str(redis_key)
+            payload = self._redis.get(key)
+            if not payload:
+                continue
+            if isinstance(payload, bytes):
+                payload = payload.decode()
+            state = UsageState.from_payload(json.loads(payload))
+            yield key[len(self._key_prefix) + 1 :], state
+
+    def save_state(self, model: str, state: UsageState) -> None:
+        redis_key = self._format_key(model)
+        self._redis.set(
+            redis_key, json.dumps(state.to_payload(), separators=(",", ":"))
+        )
+
+    def _format_key(self, model: str) -> str:
+        return f"{self._key_prefix}:{model}"
+
+    @staticmethod
+    def _ensure_str(value: Any) -> str:
+        if isinstance(value, bytes):
+            return value.decode()
+        return str(value)

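For orientation, a short sketch of how the class above might be exercised against a real Redis instance; the connection details, limits, and model key are illustrative, and record_usage / get_available_tokens / is_rate_limited are the UsageTracker methods exercised by the tests further down.

# Sketch only: recording usage and checking allowances through RedisUsageTracker.
from datetime import datetime, timedelta, timezone

import redis

from memory.common.llms.usage import RateLimitConfig, RedisUsageTracker

limits = RateLimitConfig(
    window=timedelta(minutes=1),
    max_input_tokens=1_000,
    max_output_tokens=2_000,
    max_total_tokens=2_500,
)
tracker = RedisUsageTracker(
    {"anthropic/claude-3": limits},
    redis_client=redis.Redis(host="localhost", port=6379, db=0),  # illustrative connection
    key_prefix="llm_usage",  # state lands under keys like "llm_usage:anthropic/claude-3"
)

now = datetime.now(timezone.utc)
tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)  # input/output tokens
print(tracker.get_available_tokens("anthropic/claude-3", timestamp=now))  # remaining allowance
print(tracker.is_rate_limited("anthropic/claude-3"))  # False until a limit is exceeded
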
@@ -5,6 +5,7 @@ from collections.abc import Iterable
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta, timezone
 from threading import Lock
+from typing import Any


 @dataclass(frozen=True)
@@ -45,6 +46,40 @@ class UsageState:
     lifetime_input_tokens: int = 0
     lifetime_output_tokens: int = 0
+
+    def to_payload(self) -> dict[str, Any]:
+        return {
+            "events": [
+                {
+                    "timestamp": event.timestamp.isoformat(),
+                    "input_tokens": event.input_tokens,
+                    "output_tokens": event.output_tokens,
+                }
+                for event in self.events
+            ],
+            "window_input_tokens": self.window_input_tokens,
+            "window_output_tokens": self.window_output_tokens,
+            "lifetime_input_tokens": self.lifetime_input_tokens,
+            "lifetime_output_tokens": self.lifetime_output_tokens,
+        }
+
+    @classmethod
+    def from_payload(cls, payload: dict[str, Any]) -> "UsageState":
+        events = deque(
+            UsageEvent(
+                timestamp=datetime.fromisoformat(event["timestamp"]),
+                input_tokens=event["input_tokens"],
+                output_tokens=event["output_tokens"],
+            )
+            for event in payload.get("events", [])
+        )
+        return cls(
+            events=events,
+            window_input_tokens=payload.get("window_input_tokens", 0),
+            window_output_tokens=payload.get("window_output_tokens", 0),
+            lifetime_input_tokens=payload.get("lifetime_input_tokens", 0),
+            lifetime_output_tokens=payload.get("lifetime_output_tokens", 0),
+        )


 @dataclass
 class TokenAllowance:
@@ -76,9 +111,7 @@ class UsageBreakdown:

 def split_model_key(model: str) -> tuple[str, str]:
     if "/" not in model:
-        raise ValueError(
-            "model must be formatted as '<provider>/<model_name>'"
-        )
+        raise ValueError("model must be formatted as '<provider>/<model_name>'")

     provider, model_name = model.split("/", maxsplit=1)
     if not provider or not model_name:
@@ -205,9 +238,7 @@ class UsageTracker:
         if config.max_output_tokens is None:
             output_remaining = None
         else:
-            output_remaining = (
-                config.max_output_tokens - state.window_output_tokens
-            )
+            output_remaining = config.max_output_tokens - state.window_output_tokens

         return TokenAllowance(
             input_tokens=clamp_non_negative(input_remaining),
@@ -313,4 +344,3 @@ def clamp_non_negative(value: int | None) -> int | None:
     if value is None:
         return None
     return 0 if value < 0 else value
-

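To make the new serialization concrete, this is roughly the JSON document that save_state() writes and get_state() reads back for one model; the token counts are invented, and the commented round-trip assumes UsageState remains importable from the module shown in the diff above.

# Sketch only: the payload shape produced by UsageState.to_payload().
import json

payload = {
    "events": [
        {
            "timestamp": "2024-01-01T00:00:00+00:00",  # UsageEvent.timestamp.isoformat()
            "input_tokens": 100,
            "output_tokens": 200,
        }
    ],
    "window_input_tokens": 100,
    "window_output_tokens": 200,
    "lifetime_input_tokens": 100,
    "lifetime_output_tokens": 200,
}

# RedisUsageTracker.save_state() stores this, compactly encoded,
# under "<prefix>:<provider>/<model_name>":
encoded = json.dumps(payload, separators=(",", ":"))
print(encoded)

# get_state() reverses it on the way back in:
# from memory.common.llms.usage.usage_tracker import UsageState
# state = UsageState.from_payload(json.loads(encoded))
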
@@ -38,6 +38,7 @@ CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower()
 REDIS_HOST = os.getenv("REDIS_HOST", "redis")
 REDIS_PORT = os.getenv("REDIS_PORT", "6379")
 REDIS_DB = os.getenv("REDIS_DB", "0")
+LLM_USAGE_REDIS_PREFIX = os.getenv("LLM_USAGE_REDIS_PREFIX", "llm_usage")
 CELERY_BROKER_USER = os.getenv(
     "CELERY_BROKER_USER", "kb" if CELERY_BROKER_TYPE == "amqp" else ""
 )

@@ -58,6 +58,13 @@ def call_llm(
     msgs: list[str] = [],
     allowed_tools: list[str] = [],
 ) -> str | None:
+    provider = create_provider(model=model)
+    if provider.usage_tracker.is_rate_limited(model):
+        logger.error(
+            f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
+        )
+        return None
+
     tools = make_discord_tools(
         message.recipient_user.system_user,
         message.from_user,
@@ -73,7 +80,6 @@ def call_llm(
         system_prompt += comm_channel_prompt(
             session, message.recipient_user, message.channel
         )
-    provider = create_provider(model=model)
     messages = previous_messages(
         session,
         message.recipient_user and message.recipient_user.id,

@@ -1,7 +1,24 @@
 from datetime import datetime, timedelta, timezone
+from typing import Iterable

 import pytest
+
+try:
+    import redis  # noqa: F401 # pragma: no cover - optional test dependency
+except ModuleNotFoundError:  # pragma: no cover - import guard for test envs
+    import sys
+    from types import SimpleNamespace
+
+    class _RedisStub(SimpleNamespace):
+        class Redis:  # type: ignore[no-redef]
+            def __init__(self, *args: object, **kwargs: object) -> None:
+                raise ModuleNotFoundError(
+                    "The 'redis' package is required to use RedisUsageTracker"
+                )
+
+    sys.modules.setdefault("redis", _RedisStub())

+from memory.common.llms.redis_usage_tracker import RedisUsageTracker
 from memory.common.llms.usage_tracker import (
     InMemoryUsageTracker,
     RateLimitConfig,
@@ -9,6 +26,24 @@ from memory.common.llms.usage_tracker import (
 )


+class FakeRedis:
+    def __init__(self) -> None:
+        self._store: dict[str, str] = {}
+
+    def get(self, key: str) -> str | None:
+        return self._store.get(key)
+
+    def set(self, key: str, value: str) -> None:
+        self._store[key] = value
+
+    def scan_iter(self, match: str) -> Iterable[str]:
+        from fnmatch import fnmatch
+
+        for key in list(self._store.keys()):
+            if fnmatch(key, match):
+                yield key
+
+
 @pytest.fixture
 def tracker() -> InMemoryUsageTracker:
     config = RateLimitConfig(
@@ -25,6 +60,23 @@ def tracker() -> InMemoryUsageTracker:
     )


+@pytest.fixture
+def redis_tracker() -> RedisUsageTracker:
+    config = RateLimitConfig(
+        window=timedelta(minutes=1),
+        max_input_tokens=1_000,
+        max_output_tokens=2_000,
+        max_total_tokens=2_500,
+    )
+    return RedisUsageTracker(
+        {
+            "anthropic/claude-3": config,
+            "anthropic/haiku": config,
+        },
+        redis_client=FakeRedis(),
+    )
+
+
 @pytest.mark.parametrize(
     "window, kwargs",
     [
@@ -139,6 +191,22 @@ def test_is_rate_limited_when_only_output_exceeds_limit() -> None:
     assert tracker.is_rate_limited("openai/gpt-4o")


+def test_redis_usage_tracker_persists_state(redis_tracker: RedisUsageTracker) -> None:
+    now = datetime(2024, 1, 1, tzinfo=timezone.utc)
+    redis_tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
+    redis_tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)
+
+    allowance = redis_tracker.get_available_tokens("anthropic/claude-3", timestamp=now)
+    assert allowance is not None
+    assert allowance.input_tokens == 900
+
+    breakdown = redis_tracker.get_usage_breakdown()
+    assert breakdown["anthropic"]["claude-3"].window_output_tokens == 200
+
+    items = dict(redis_tracker.iter_state_items())
+    assert set(items.keys()) == {"anthropic/claude-3", "anthropic/haiku"}
+
+
 def test_usage_tracker_base_not_instantiable() -> None:
     class DummyTracker(UsageTracker):
         pass