From 57145ac7b4d71e6c20b75292d9609b5663951bff Mon Sep 17 00:00:00 2001 From: mruwnik Date: Sat, 1 Nov 2025 19:35:20 +0000 Subject: [PATCH] fix bugs --- src/memory/common/celery_app.py | 12 ++- src/memory/common/db/models/discord.py | 3 +- src/memory/common/db/models/users.py | 2 + src/memory/common/llms/anthropic_provider.py | 4 +- src/memory/common/llms/base.py | 12 ++- src/memory/common/llms/openai_provider.py | 2 + .../common/llms/usage/redis_usage_tracker.py | 9 +-- src/memory/common/llms/usage/usage_tracker.py | 33 +++++---- src/memory/common/settings.py | 45 ++++++------ src/memory/discord/collector.py | 2 +- src/memory/discord/commands.py | 73 +++++++++++++++---- src/memory/workers/tasks/discord.py | 9 ++- 12 files changed, 143 insertions(+), 63 deletions(-) diff --git a/src/memory/common/celery_app.py b/src/memory/common/celery_app.py index a415c19..03cb3b8 100644 --- a/src/memory/common/celery_app.py +++ b/src/memory/common/celery_app.py @@ -57,9 +57,17 @@ RUN_SCHEDULED_CALLS = f"{SCHEDULED_CALLS_ROOT}.run_scheduled_calls" def get_broker_url() -> str: protocol = settings.CELERY_BROKER_TYPE user = safequote(settings.CELERY_BROKER_USER) - password = safequote(settings.CELERY_BROKER_PASSWORD) + password = safequote(settings.CELERY_BROKER_PASSWORD or "") host = settings.CELERY_BROKER_HOST - return f"{protocol}://{user}:{password}@{host}" + + if password: + url = f"{protocol}://{user}:{password}@{host}" + else: + url = f"{protocol}://{host}" + + if protocol == "redis": + url += f"/{settings.REDIS_DB}" + return url app = Celery( diff --git a/src/memory/common/db/models/discord.py b/src/memory/common/db/models/discord.py index 353727e..01f727c 100644 --- a/src/memory/common/db/models/discord.py +++ b/src/memory/common/db/models/discord.py @@ -35,8 +35,7 @@ class MessageProcessor: ) chattiness_threshold = Column( Integer, - nullable=False, - default=50, + nullable=True, doc="The threshold for the bot to continue the conversation, between 0 and 100.", ) diff --git a/src/memory/common/db/models/users.py b/src/memory/common/db/models/users.py index 09f8dd8..a86cf03 100644 --- a/src/memory/common/db/models/users.py +++ b/src/memory/common/db/models/users.py @@ -138,6 +138,8 @@ class DiscordBotUser(BotUser): email: str, api_key: str | None = None, ) -> "DiscordBotUser": + if not discord_users: + raise ValueError("discord_users must be provided") bot = super().create_with_api_key(name, email, api_key) bot.discord_users = discord_users return bot diff --git a/src/memory/common/llms/anthropic_provider.py b/src/memory/common/llms/anthropic_provider.py index 0ef8691..a09c2ba 100644 --- a/src/memory/common/llms/anthropic_provider.py +++ b/src/memory/common/llms/anthropic_provider.py @@ -23,6 +23,8 @@ logger = logging.getLogger(__name__) class AnthropicProvider(BaseLLMProvider): """Anthropic LLM provider with streaming, tool support, and extended thinking.""" + provider = "anthropic" + # Models that support extended thinking THINKING_MODELS = { "claude-opus-4", @@ -262,7 +264,7 @@ class AnthropicProvider(BaseLLMProvider): Usage( input_tokens=usage.input_tokens, output_tokens=usage.output_tokens, - total_tokens=usage.total_tokens, + total_tokens=usage.input_tokens + usage.output_tokens, ) ) diff --git a/src/memory/common/llms/base.py b/src/memory/common/llms/base.py index 709d810..92f20a2 100644 --- a/src/memory/common/llms/base.py +++ b/src/memory/common/llms/base.py @@ -205,6 +205,8 @@ class LLMSettings: class BaseLLMProvider(ABC): """Base class for LLM providers.""" + provider: str = "" + def __init__( self, api_key: str, model: str, usage_tracker: UsageTracker | None = None ): @@ -234,8 +236,14 @@ class BaseLLMProvider(ABC): def log_usage(self, usage: Usage): """Log usage data.""" - logger.debug(f"Token usage: {usage.to_dict()}") - print(f"Token usage: {usage.to_dict()}") + logger.debug( + f"Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.total_tokens} total" + ) + self.usage_tracker.record_usage( + model=f"{self.provider}/{self.model}", + input_tokens=usage.input_tokens, + output_tokens=usage.output_tokens, + ) def execute_tool( self, diff --git a/src/memory/common/llms/openai_provider.py b/src/memory/common/llms/openai_provider.py index 811e0d0..3548beb 100644 --- a/src/memory/common/llms/openai_provider.py +++ b/src/memory/common/llms/openai_provider.py @@ -25,6 +25,8 @@ logger = logging.getLogger(__name__) class OpenAIProvider(BaseLLMProvider): """OpenAI LLM provider with streaming and tool support.""" + provider = "openai" + # Models that use max_completion_tokens instead of max_tokens # These are reasoning models with different parameter requirements NON_REASONING_MODELS = {"gpt-4o"} diff --git a/src/memory/common/llms/usage/redis_usage_tracker.py b/src/memory/common/llms/usage/redis_usage_tracker.py index 468e9da..54fda3a 100644 --- a/src/memory/common/llms/usage/redis_usage_tracker.py +++ b/src/memory/common/llms/usage/redis_usage_tracker.py @@ -33,7 +33,7 @@ class RedisUsageTracker(UsageTracker): def __init__( self, - configs: dict[str, RateLimitConfig], + configs: dict[str, RateLimitConfig] | None = None, default_config: RateLimitConfig | None = None, *, redis_client: RedisClientProtocol | None = None, @@ -41,12 +41,7 @@ class RedisUsageTracker(UsageTracker): ) -> None: super().__init__(configs=configs, default_config=default_config) if redis_client is None: - redis_client = redis.Redis( - host=settings.REDIS_HOST, - port=int(settings.REDIS_PORT), - db=int(settings.REDIS_DB), - decode_responses=False, - ) + redis_client = redis.Redis.from_url(settings.REDIS_URL) self._redis = redis_client prefix = key_prefix or settings.LLM_USAGE_REDIS_PREFIX self._key_prefix = prefix.rstrip(":") diff --git a/src/memory/common/llms/usage/usage_tracker.py b/src/memory/common/llms/usage/usage_tracker.py index fbf9750..1cda909 100644 --- a/src/memory/common/llms/usage/usage_tracker.py +++ b/src/memory/common/llms/usage/usage_tracker.py @@ -7,6 +7,8 @@ from datetime import datetime, timedelta, timezone from threading import Lock from typing import Any +from memory.common import settings + @dataclass(frozen=True) class RateLimitConfig: @@ -111,12 +113,14 @@ class UsageBreakdown: def split_model_key(model: str) -> tuple[str, str]: if "/" not in model: - raise ValueError("model must be formatted as '/'") + raise ValueError( + f"model must be formatted as '/': got '{model}'" + ) provider, model_name = model.split("/", maxsplit=1) if not provider or not model_name: raise ValueError( - "model must include both provider and model name separated by '/'" + f"model must include both provider and model name separated by '/': got '{model}'" ) return provider, model_name @@ -126,11 +130,15 @@ class UsageTracker: def __init__( self, - configs: dict[str, RateLimitConfig], + configs: dict[str, RateLimitConfig] | None = None, default_config: RateLimitConfig | None = None, ) -> None: - self._configs = configs - self._default_config = default_config + self._configs = configs or {} + self._default_config = default_config or RateLimitConfig( + window=timedelta(minutes=settings.DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES), + max_input_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS, + max_output_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS, + ) self._lock = Lock() # ------------------------------------------------------------------ @@ -213,15 +221,14 @@ class UsageTracker: """ split_model_key(model) - key = model with self._lock: - config = self._get_config(key) + config = self._get_config(model) if config is None: return None - state = self.get_state(key) + state = self.get_state(model) self._prune_expired_events(state, config, now=timestamp) - self.save_state(key, state) + self.save_state(model, state) if config.max_total_tokens is None: total_remaining = None @@ -253,8 +260,8 @@ class UsageTracker: with self._lock: providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict) - for key, state in self.iter_state_items(): - prov, model_name = split_model_key(key) + for model, state in self.iter_state_items(): + prov, model_name = split_model_key(model) if provider and provider != prov: continue if model and model != model_name: @@ -296,8 +303,8 @@ class UsageTracker: # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ - def _get_config(self, key: str) -> RateLimitConfig | None: - return self._configs.get(key) or self._default_config + def _get_config(self, model: str) -> RateLimitConfig | None: + return self._configs.get(model) or self._default_config def _prune_expired_events( self, diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py index 9ea8c64..5617f5b 100644 --- a/src/memory/common/settings.py +++ b/src/memory/common/settings.py @@ -31,34 +31,25 @@ def make_db_url( DB_URL = os.getenv("DATABASE_URL", make_db_url()) +# Redis settings +REDIS_HOST = os.getenv("REDIS_HOST", "redis") +REDIS_PORT = os.getenv("REDIS_PORT", "6379") +REDIS_DB = os.getenv("REDIS_DB", "0") +REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", None) +if REDIS_PASSWORD: + REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}" +else: + REDIS_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}" # Broker settings CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory") CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower() -REDIS_HOST = os.getenv("REDIS_HOST", "redis") -REDIS_PORT = os.getenv("REDIS_PORT", "6379") -REDIS_DB = os.getenv("REDIS_DB", "0") -LLM_USAGE_REDIS_PREFIX = os.getenv("LLM_USAGE_REDIS_PREFIX", "llm_usage") -CELERY_BROKER_USER = os.getenv( - "CELERY_BROKER_USER", "kb" if CELERY_BROKER_TYPE == "amqp" else "" -) -CELERY_BROKER_PASSWORD = os.getenv( - "CELERY_BROKER_PASSWORD", "" if CELERY_BROKER_TYPE == "redis" else "kb" -) - - -CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "") -if not CELERY_BROKER_HOST: - if CELERY_BROKER_TYPE == "amqp": - RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq") - RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672") - CELERY_BROKER_HOST = f"{RABBITMQ_HOST}:{RABBITMQ_PORT}//" - elif CELERY_BROKER_TYPE == "redis": - CELERY_BROKER_HOST = f"{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}" +CELERY_BROKER_USER = os.getenv("CELERY_BROKER_USER", "") +CELERY_BROKER_PASSWORD = os.getenv("CELERY_BROKER_PASSWORD", REDIS_PASSWORD) +CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "") or f"{REDIS_HOST}:{REDIS_PORT}" CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}") - # File storage settings FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files")) EBOOK_STORAGE_DIR = pathlib.Path( @@ -149,6 +140,18 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-haiku-4-5") RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307") MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000)) +DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES = int( + os.getenv("DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES", 30) +) +DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS = int( + os.getenv("DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS", 1_000_000) +) +DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS = int( + os.getenv("DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS", 1_000_000) +) +LLM_USAGE_REDIS_PREFIX = os.getenv("LLM_USAGE_REDIS_PREFIX", "llm_usage") + + # Search settings ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True) ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True) diff --git a/src/memory/discord/collector.py b/src/memory/discord/collector.py index ab0803c..a898458 100644 --- a/src/memory/discord/collector.py +++ b/src/memory/discord/collector.py @@ -203,7 +203,7 @@ class MessageCollector(commands.Bot): async def setup_hook(self): """Register slash commands when the bot is ready.""" - register_slash_commands(self) + register_slash_commands(self, name=self.user.name) async def on_ready(self): """Called when bot connects to Discord""" diff --git a/src/memory/discord/commands.py b/src/memory/discord/commands.py index 4697a3e..bd3a234 100644 --- a/src/memory/discord/commands.py +++ b/src/memory/discord/commands.py @@ -41,8 +41,13 @@ class CommandContext: CommandHandler = Callable[..., CommandResponse] -def register_slash_commands(bot: discord.Client) -> None: - """Register the collector slash commands on the provided bot.""" +def register_slash_commands(bot: discord.Client, name: str = "memory") -> None: + """Register the collector slash commands on the provided bot. + + Args: + bot: Discord bot client + name: Prefix for command names (e.g., "memory" creates "memory_prompt") + """ if getattr(bot, "_memory_commands_registered", False): return @@ -54,12 +59,14 @@ def register_slash_commands(bot: discord.Client) -> None: tree = bot.tree - @tree.command(name="memory_prompt", description="Show the current system prompt") + @tree.command( + name=f"{name}_show_prompt", description="Show the current system prompt" + ) @discord.app_commands.describe( scope="Which configuration to inspect", user="Target user when the scope is 'user'", ) - async def prompt_command( + async def show_prompt_command( interaction: discord.Interaction, scope: ScopeLiteral, user: discord.User | None = None, @@ -72,12 +79,35 @@ def register_slash_commands(bot: discord.Client) -> None: ) @tree.command( - name="memory_chattiness", - description="Show or update the chattiness threshold for the target", + name=f"{name}_set_prompt", + description="Set the system prompt for the target", + ) + @discord.app_commands.describe( + scope="Which configuration to modify", + prompt="The system prompt to set", + user="Target user when the scope is 'user'", + ) + async def set_prompt_command( + interaction: discord.Interaction, + scope: ScopeLiteral, + prompt: str, + user: discord.User | None = None, + ) -> None: + await _run_interaction_command( + interaction, + scope=scope, + handler=handle_set_prompt, + target_user=user, + prompt=prompt, + ) + + @tree.command( + name=f"{name}_chattiness", + description="Show or update the chattiness for the target", ) @discord.app_commands.describe( scope="Which configuration to inspect", - value="Optional new threshold value between 0 and 100", + value="Optional new chattiness value between 0 and 100", user="Target user when the scope is 'user'", ) async def chattiness_command( @@ -95,7 +125,7 @@ def register_slash_commands(bot: discord.Client) -> None: ) @tree.command( - name="memory_ignore", + name=f"{name}_ignore", description="Toggle whether the bot should ignore messages for the target", ) @discord.app_commands.describe( @@ -117,7 +147,10 @@ def register_slash_commands(bot: discord.Client) -> None: ignore_enabled=enabled, ) - @tree.command(name="memory_summary", description="Show the stored summary for the target") + @tree.command( + name=f"{name}_show_summary", + description="Show the stored summary for the target", + ) @discord.app_commands.describe( scope="Which configuration to inspect", user="Target user when the scope is 'user'", @@ -337,6 +370,18 @@ def handle_prompt(context: CommandContext) -> CommandResponse: ) +def handle_set_prompt( + context: CommandContext, + *, + prompt: str, +) -> CommandResponse: + setattr(context.target, "system_prompt", prompt) + + return CommandResponse( + content=f"Updated system prompt for {context.display_name}.", + ) + + def handle_chattiness( context: CommandContext, *, @@ -347,20 +392,22 @@ def handle_chattiness( if value is None: return CommandResponse( content=( - f"Chattiness threshold for {context.display_name}: " + f"Chattiness for {context.display_name}: " f"{getattr(model, 'chattiness_threshold', 'not set')}" ) ) if not 0 <= value <= 100: - raise CommandError("Chattiness threshold must be between 0 and 100.") + raise CommandError("Chattiness must be between 0 and 100.") setattr(model, "chattiness_threshold", value) return CommandResponse( content=( - f"Updated chattiness threshold for {context.display_name} " - f"to {value}." + f"Updated chattiness for {context.display_name} to {value}." + "\n" + "This can be treated as how much you want the bot to pipe up by itself, as a percentage, " + "where 0 is never and 100 is always." ) ) diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py index c1cdda0..b8a107c 100644 --- a/src/memory/workers/tasks/discord.py +++ b/src/memory/workers/tasks/discord.py @@ -128,12 +128,19 @@ def should_process(message: DiscordMessage) -> bool: "update_server_summary", ], ) + print("response", response) if not response: return False if not (res := re.search(r"(.*)", response)): return False try: - return int(res.group(1)) > message.chattiness_threshold + print( + "parsed", + int(res.group(1)), + message.chattiness_threshold, + 100 - message.chattiness_threshold, + ) + return int(res.group(1)) > 100 - message.chattiness_threshold except ValueError: return False