This commit is contained in:
mruwnik 2025-11-01 19:35:20 +00:00
parent 814090dccb
commit 57145ac7b4
12 changed files with 143 additions and 63 deletions

View File

@ -57,9 +57,17 @@ RUN_SCHEDULED_CALLS = f"{SCHEDULED_CALLS_ROOT}.run_scheduled_calls"
def get_broker_url() -> str: def get_broker_url() -> str:
protocol = settings.CELERY_BROKER_TYPE protocol = settings.CELERY_BROKER_TYPE
user = safequote(settings.CELERY_BROKER_USER) user = safequote(settings.CELERY_BROKER_USER)
password = safequote(settings.CELERY_BROKER_PASSWORD) password = safequote(settings.CELERY_BROKER_PASSWORD or "")
host = settings.CELERY_BROKER_HOST host = settings.CELERY_BROKER_HOST
return f"{protocol}://{user}:{password}@{host}"
if password:
url = f"{protocol}://{user}:{password}@{host}"
else:
url = f"{protocol}://{host}"
if protocol == "redis":
url += f"/{settings.REDIS_DB}"
return url
app = Celery( app = Celery(

View File

@ -35,8 +35,7 @@ class MessageProcessor:
) )
chattiness_threshold = Column( chattiness_threshold = Column(
Integer, Integer,
nullable=False, nullable=True,
default=50,
doc="The threshold for the bot to continue the conversation, between 0 and 100.", doc="The threshold for the bot to continue the conversation, between 0 and 100.",
) )

View File

@ -138,6 +138,8 @@ class DiscordBotUser(BotUser):
email: str, email: str,
api_key: str | None = None, api_key: str | None = None,
) -> "DiscordBotUser": ) -> "DiscordBotUser":
if not discord_users:
raise ValueError("discord_users must be provided")
bot = super().create_with_api_key(name, email, api_key) bot = super().create_with_api_key(name, email, api_key)
bot.discord_users = discord_users bot.discord_users = discord_users
return bot return bot

View File

@ -23,6 +23,8 @@ logger = logging.getLogger(__name__)
class AnthropicProvider(BaseLLMProvider): class AnthropicProvider(BaseLLMProvider):
"""Anthropic LLM provider with streaming, tool support, and extended thinking.""" """Anthropic LLM provider with streaming, tool support, and extended thinking."""
provider = "anthropic"
# Models that support extended thinking # Models that support extended thinking
THINKING_MODELS = { THINKING_MODELS = {
"claude-opus-4", "claude-opus-4",
@ -262,7 +264,7 @@ class AnthropicProvider(BaseLLMProvider):
Usage( Usage(
input_tokens=usage.input_tokens, input_tokens=usage.input_tokens,
output_tokens=usage.output_tokens, output_tokens=usage.output_tokens,
total_tokens=usage.total_tokens, total_tokens=usage.input_tokens + usage.output_tokens,
) )
) )

View File

@ -205,6 +205,8 @@ class LLMSettings:
class BaseLLMProvider(ABC): class BaseLLMProvider(ABC):
"""Base class for LLM providers.""" """Base class for LLM providers."""
provider: str = ""
def __init__( def __init__(
self, api_key: str, model: str, usage_tracker: UsageTracker | None = None self, api_key: str, model: str, usage_tracker: UsageTracker | None = None
): ):
@ -234,8 +236,14 @@ class BaseLLMProvider(ABC):
def log_usage(self, usage: Usage): def log_usage(self, usage: Usage):
"""Log usage data.""" """Log usage data."""
logger.debug(f"Token usage: {usage.to_dict()}") logger.debug(
print(f"Token usage: {usage.to_dict()}") f"Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.total_tokens} total"
)
self.usage_tracker.record_usage(
model=f"{self.provider}/{self.model}",
input_tokens=usage.input_tokens,
output_tokens=usage.output_tokens,
)
def execute_tool( def execute_tool(
self, self,

View File

@ -25,6 +25,8 @@ logger = logging.getLogger(__name__)
class OpenAIProvider(BaseLLMProvider): class OpenAIProvider(BaseLLMProvider):
"""OpenAI LLM provider with streaming and tool support.""" """OpenAI LLM provider with streaming and tool support."""
provider = "openai"
# Models that use max_completion_tokens instead of max_tokens # Models that use max_completion_tokens instead of max_tokens
# These are reasoning models with different parameter requirements # These are reasoning models with different parameter requirements
NON_REASONING_MODELS = {"gpt-4o"} NON_REASONING_MODELS = {"gpt-4o"}

View File

@ -33,7 +33,7 @@ class RedisUsageTracker(UsageTracker):
def __init__( def __init__(
self, self,
configs: dict[str, RateLimitConfig], configs: dict[str, RateLimitConfig] | None = None,
default_config: RateLimitConfig | None = None, default_config: RateLimitConfig | None = None,
*, *,
redis_client: RedisClientProtocol | None = None, redis_client: RedisClientProtocol | None = None,
@ -41,12 +41,7 @@ class RedisUsageTracker(UsageTracker):
) -> None: ) -> None:
super().__init__(configs=configs, default_config=default_config) super().__init__(configs=configs, default_config=default_config)
if redis_client is None: if redis_client is None:
redis_client = redis.Redis( redis_client = redis.Redis.from_url(settings.REDIS_URL)
host=settings.REDIS_HOST,
port=int(settings.REDIS_PORT),
db=int(settings.REDIS_DB),
decode_responses=False,
)
self._redis = redis_client self._redis = redis_client
prefix = key_prefix or settings.LLM_USAGE_REDIS_PREFIX prefix = key_prefix or settings.LLM_USAGE_REDIS_PREFIX
self._key_prefix = prefix.rstrip(":") self._key_prefix = prefix.rstrip(":")

View File

@ -7,6 +7,8 @@ from datetime import datetime, timedelta, timezone
from threading import Lock from threading import Lock
from typing import Any from typing import Any
from memory.common import settings
@dataclass(frozen=True) @dataclass(frozen=True)
class RateLimitConfig: class RateLimitConfig:
@ -111,12 +113,14 @@ class UsageBreakdown:
def split_model_key(model: str) -> tuple[str, str]: def split_model_key(model: str) -> tuple[str, str]:
if "/" not in model: if "/" not in model:
raise ValueError("model must be formatted as '<provider>/<model_name>'") raise ValueError(
f"model must be formatted as '<provider>/<model_name>': got '{model}'"
)
provider, model_name = model.split("/", maxsplit=1) provider, model_name = model.split("/", maxsplit=1)
if not provider or not model_name: if not provider or not model_name:
raise ValueError( raise ValueError(
"model must include both provider and model name separated by '/'" f"model must include both provider and model name separated by '/': got '{model}'"
) )
return provider, model_name return provider, model_name
@ -126,11 +130,15 @@ class UsageTracker:
def __init__( def __init__(
self, self,
configs: dict[str, RateLimitConfig], configs: dict[str, RateLimitConfig] | None = None,
default_config: RateLimitConfig | None = None, default_config: RateLimitConfig | None = None,
) -> None: ) -> None:
self._configs = configs self._configs = configs or {}
self._default_config = default_config self._default_config = default_config or RateLimitConfig(
window=timedelta(minutes=settings.DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES),
max_input_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS,
max_output_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS,
)
self._lock = Lock() self._lock = Lock()
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@ -213,15 +221,14 @@ class UsageTracker:
""" """
split_model_key(model) split_model_key(model)
key = model
with self._lock: with self._lock:
config = self._get_config(key) config = self._get_config(model)
if config is None: if config is None:
return None return None
state = self.get_state(key) state = self.get_state(model)
self._prune_expired_events(state, config, now=timestamp) self._prune_expired_events(state, config, now=timestamp)
self.save_state(key, state) self.save_state(model, state)
if config.max_total_tokens is None: if config.max_total_tokens is None:
total_remaining = None total_remaining = None
@ -253,8 +260,8 @@ class UsageTracker:
with self._lock: with self._lock:
providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict) providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
for key, state in self.iter_state_items(): for model, state in self.iter_state_items():
prov, model_name = split_model_key(key) prov, model_name = split_model_key(model)
if provider and provider != prov: if provider and provider != prov:
continue continue
if model and model != model_name: if model and model != model_name:
@ -296,8 +303,8 @@ class UsageTracker:
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Internal helpers # Internal helpers
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def _get_config(self, key: str) -> RateLimitConfig | None: def _get_config(self, model: str) -> RateLimitConfig | None:
return self._configs.get(key) or self._default_config return self._configs.get(model) or self._default_config
def _prune_expired_events( def _prune_expired_events(
self, self,

View File

@ -31,34 +31,25 @@ def make_db_url(
DB_URL = os.getenv("DATABASE_URL", make_db_url()) DB_URL = os.getenv("DATABASE_URL", make_db_url())
# Redis settings
REDIS_HOST = os.getenv("REDIS_HOST", "redis")
REDIS_PORT = os.getenv("REDIS_PORT", "6379")
REDIS_DB = os.getenv("REDIS_DB", "0")
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", None)
if REDIS_PASSWORD:
REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
else:
REDIS_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
# Broker settings # Broker settings
CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory") CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory")
CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower() CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower()
REDIS_HOST = os.getenv("REDIS_HOST", "redis") CELERY_BROKER_USER = os.getenv("CELERY_BROKER_USER", "")
REDIS_PORT = os.getenv("REDIS_PORT", "6379") CELERY_BROKER_PASSWORD = os.getenv("CELERY_BROKER_PASSWORD", REDIS_PASSWORD)
REDIS_DB = os.getenv("REDIS_DB", "0")
LLM_USAGE_REDIS_PREFIX = os.getenv("LLM_USAGE_REDIS_PREFIX", "llm_usage")
CELERY_BROKER_USER = os.getenv(
"CELERY_BROKER_USER", "kb" if CELERY_BROKER_TYPE == "amqp" else ""
)
CELERY_BROKER_PASSWORD = os.getenv(
"CELERY_BROKER_PASSWORD", "" if CELERY_BROKER_TYPE == "redis" else "kb"
)
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "")
if not CELERY_BROKER_HOST:
if CELERY_BROKER_TYPE == "amqp":
RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672")
CELERY_BROKER_HOST = f"{RABBITMQ_HOST}:{RABBITMQ_PORT}//"
elif CELERY_BROKER_TYPE == "redis":
CELERY_BROKER_HOST = f"{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "") or f"{REDIS_HOST}:{REDIS_PORT}"
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}") CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
# File storage settings # File storage settings
FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files")) FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files"))
EBOOK_STORAGE_DIR = pathlib.Path( EBOOK_STORAGE_DIR = pathlib.Path(
@ -149,6 +140,18 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-haiku-4-5")
RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307") RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000)) MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES = int(
os.getenv("DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES", 30)
)
DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS = int(
os.getenv("DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS", 1_000_000)
)
DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS = int(
os.getenv("DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS", 1_000_000)
)
LLM_USAGE_REDIS_PREFIX = os.getenv("LLM_USAGE_REDIS_PREFIX", "llm_usage")
# Search settings # Search settings
ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True) ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True) ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)

View File

@ -203,7 +203,7 @@ class MessageCollector(commands.Bot):
async def setup_hook(self): async def setup_hook(self):
"""Register slash commands when the bot is ready.""" """Register slash commands when the bot is ready."""
register_slash_commands(self) register_slash_commands(self, name=self.user.name)
async def on_ready(self): async def on_ready(self):
"""Called when bot connects to Discord""" """Called when bot connects to Discord"""

View File

@ -41,8 +41,13 @@ class CommandContext:
CommandHandler = Callable[..., CommandResponse] CommandHandler = Callable[..., CommandResponse]
def register_slash_commands(bot: discord.Client) -> None: def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
"""Register the collector slash commands on the provided bot.""" """Register the collector slash commands on the provided bot.
Args:
bot: Discord bot client
name: Prefix for command names (e.g., "memory" creates "memory_prompt")
"""
if getattr(bot, "_memory_commands_registered", False): if getattr(bot, "_memory_commands_registered", False):
return return
@ -54,12 +59,14 @@ def register_slash_commands(bot: discord.Client) -> None:
tree = bot.tree tree = bot.tree
@tree.command(name="memory_prompt", description="Show the current system prompt") @tree.command(
name=f"{name}_show_prompt", description="Show the current system prompt"
)
@discord.app_commands.describe( @discord.app_commands.describe(
scope="Which configuration to inspect", scope="Which configuration to inspect",
user="Target user when the scope is 'user'", user="Target user when the scope is 'user'",
) )
async def prompt_command( async def show_prompt_command(
interaction: discord.Interaction, interaction: discord.Interaction,
scope: ScopeLiteral, scope: ScopeLiteral,
user: discord.User | None = None, user: discord.User | None = None,
@ -72,12 +79,35 @@ def register_slash_commands(bot: discord.Client) -> None:
) )
@tree.command( @tree.command(
name="memory_chattiness", name=f"{name}_set_prompt",
description="Show or update the chattiness threshold for the target", description="Set the system prompt for the target",
)
@discord.app_commands.describe(
scope="Which configuration to modify",
prompt="The system prompt to set",
user="Target user when the scope is 'user'",
)
async def set_prompt_command(
interaction: discord.Interaction,
scope: ScopeLiteral,
prompt: str,
user: discord.User | None = None,
) -> None:
await _run_interaction_command(
interaction,
scope=scope,
handler=handle_set_prompt,
target_user=user,
prompt=prompt,
)
@tree.command(
name=f"{name}_chattiness",
description="Show or update the chattiness for the target",
) )
@discord.app_commands.describe( @discord.app_commands.describe(
scope="Which configuration to inspect", scope="Which configuration to inspect",
value="Optional new threshold value between 0 and 100", value="Optional new chattiness value between 0 and 100",
user="Target user when the scope is 'user'", user="Target user when the scope is 'user'",
) )
async def chattiness_command( async def chattiness_command(
@ -95,7 +125,7 @@ def register_slash_commands(bot: discord.Client) -> None:
) )
@tree.command( @tree.command(
name="memory_ignore", name=f"{name}_ignore",
description="Toggle whether the bot should ignore messages for the target", description="Toggle whether the bot should ignore messages for the target",
) )
@discord.app_commands.describe( @discord.app_commands.describe(
@ -117,7 +147,10 @@ def register_slash_commands(bot: discord.Client) -> None:
ignore_enabled=enabled, ignore_enabled=enabled,
) )
@tree.command(name="memory_summary", description="Show the stored summary for the target") @tree.command(
name=f"{name}_show_summary",
description="Show the stored summary for the target",
)
@discord.app_commands.describe( @discord.app_commands.describe(
scope="Which configuration to inspect", scope="Which configuration to inspect",
user="Target user when the scope is 'user'", user="Target user when the scope is 'user'",
@ -337,6 +370,18 @@ def handle_prompt(context: CommandContext) -> CommandResponse:
) )
def handle_set_prompt(
context: CommandContext,
*,
prompt: str,
) -> CommandResponse:
setattr(context.target, "system_prompt", prompt)
return CommandResponse(
content=f"Updated system prompt for {context.display_name}.",
)
def handle_chattiness( def handle_chattiness(
context: CommandContext, context: CommandContext,
*, *,
@ -347,20 +392,22 @@ def handle_chattiness(
if value is None: if value is None:
return CommandResponse( return CommandResponse(
content=( content=(
f"Chattiness threshold for {context.display_name}: " f"Chattiness for {context.display_name}: "
f"{getattr(model, 'chattiness_threshold', 'not set')}" f"{getattr(model, 'chattiness_threshold', 'not set')}"
) )
) )
if not 0 <= value <= 100: if not 0 <= value <= 100:
raise CommandError("Chattiness threshold must be between 0 and 100.") raise CommandError("Chattiness must be between 0 and 100.")
setattr(model, "chattiness_threshold", value) setattr(model, "chattiness_threshold", value)
return CommandResponse( return CommandResponse(
content=( content=(
f"Updated chattiness threshold for {context.display_name} " f"Updated chattiness for {context.display_name} to {value}."
f"to {value}." "\n"
"This can be treated as how much you want the bot to pipe up by itself, as a percentage, "
"where 0 is never and 100 is always."
) )
) )

View File

@ -128,12 +128,19 @@ def should_process(message: DiscordMessage) -> bool:
"update_server_summary", "update_server_summary",
], ],
) )
print("response", response)
if not response: if not response:
return False return False
if not (res := re.search(r"<number>(.*)</number>", response)): if not (res := re.search(r"<number>(.*)</number>", response)):
return False return False
try: try:
return int(res.group(1)) > message.chattiness_threshold print(
"parsed",
int(res.group(1)),
message.chattiness_threshold,
100 - message.chattiness_threshold,
)
return int(res.group(1)) > 100 - message.chattiness_threshold
except ValueError: except ValueError:
return False return False