mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
fix bugs
This commit is contained in:
parent
814090dccb
commit
57145ac7b4
@ -57,9 +57,17 @@ RUN_SCHEDULED_CALLS = f"{SCHEDULED_CALLS_ROOT}.run_scheduled_calls"
|
||||
def get_broker_url() -> str:
    """Build the Celery broker URL from the broker settings.

    Credentials are embedded only when a password is configured; without
    one the URL is just ``<protocol>://<host>``. For Redis brokers the
    database index from ``settings.REDIS_DB`` is appended as the URL path,
    unless the configured host already carries a path of its own.

    Returns:
        The broker connection URL.
    """
    protocol = settings.CELERY_BROKER_TYPE
    user = safequote(settings.CELERY_BROKER_USER)
    # The password setting may be None, so coerce it before quoting.
    password = safequote(settings.CELERY_BROKER_PASSWORD or "")
    host = settings.CELERY_BROKER_HOST

    credentials = f"{user}:{password}@" if password else ""
    url = f"{protocol}://{credentials}{host}"

    # An explicitly configured host may already end in "/<db>"; appending a
    # second path segment would produce an invalid Redis URL.
    if protocol == "redis" and "/" not in host:
        url += f"/{settings.REDIS_DB}"
    return url
|
||||
|
||||
|
||||
app = Celery(
|
||||
|
||||
@ -35,8 +35,7 @@ class MessageProcessor:
|
||||
)
|
||||
chattiness_threshold = Column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=50,
|
||||
nullable=True,
|
||||
doc="The threshold for the bot to continue the conversation, between 0 and 100.",
|
||||
)
|
||||
|
||||
|
||||
@ -138,6 +138,8 @@ class DiscordBotUser(BotUser):
|
||||
email: str,
|
||||
api_key: str | None = None,
|
||||
) -> "DiscordBotUser":
|
||||
if not discord_users:
|
||||
raise ValueError("discord_users must be provided")
|
||||
bot = super().create_with_api_key(name, email, api_key)
|
||||
bot.discord_users = discord_users
|
||||
return bot
|
||||
|
||||
@ -23,6 +23,8 @@ logger = logging.getLogger(__name__)
|
||||
class AnthropicProvider(BaseLLMProvider):
|
||||
"""Anthropic LLM provider with streaming, tool support, and extended thinking."""
|
||||
|
||||
provider = "anthropic"
|
||||
|
||||
# Models that support extended thinking
|
||||
THINKING_MODELS = {
|
||||
"claude-opus-4",
|
||||
@ -262,7 +264,7 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
Usage(
|
||||
input_tokens=usage.input_tokens,
|
||||
output_tokens=usage.output_tokens,
|
||||
total_tokens=usage.total_tokens,
|
||||
total_tokens=usage.input_tokens + usage.output_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -205,6 +205,8 @@ class LLMSettings:
|
||||
class BaseLLMProvider(ABC):
|
||||
"""Base class for LLM providers."""
|
||||
|
||||
provider: str = ""
|
||||
|
||||
def __init__(
|
||||
self, api_key: str, model: str, usage_tracker: UsageTracker | None = None
|
||||
):
|
||||
@ -234,8 +236,14 @@ class BaseLLMProvider(ABC):
|
||||
|
||||
def log_usage(self, usage: Usage):
    """Log token usage and record it with the usage tracker.

    Args:
        usage: Token counts for a single LLM call; its ``input_tokens``,
            ``output_tokens`` and ``total_tokens`` attributes are read.
    """
    # Lazy %-style arguments: the message is only formatted when debug
    # logging is enabled. No stdout printing - logging is the only output.
    logger.debug(
        "Token usage: %s input, %s output, %s total",
        usage.input_tokens,
        usage.output_tokens,
        usage.total_tokens,
    )
    # Usage is recorded under a "<provider>/<model>" key.
    # NOTE(review): assumes __init__ always leaves self.usage_tracker set to
    # a tracker instance even when None was passed - confirm.
    self.usage_tracker.record_usage(
        model=f"{self.provider}/{self.model}",
        input_tokens=usage.input_tokens,
        output_tokens=usage.output_tokens,
    )
|
||||
|
||||
def execute_tool(
|
||||
self,
|
||||
|
||||
@ -25,6 +25,8 @@ logger = logging.getLogger(__name__)
|
||||
class OpenAIProvider(BaseLLMProvider):
|
||||
"""OpenAI LLM provider with streaming and tool support."""
|
||||
|
||||
provider = "openai"
|
||||
|
||||
# Models that use max_completion_tokens instead of max_tokens
|
||||
# These are reasoning models with different parameter requirements
|
||||
NON_REASONING_MODELS = {"gpt-4o"}
|
||||
|
||||
@ -33,7 +33,7 @@ class RedisUsageTracker(UsageTracker):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
configs: dict[str, RateLimitConfig],
|
||||
configs: dict[str, RateLimitConfig] | None = None,
|
||||
default_config: RateLimitConfig | None = None,
|
||||
*,
|
||||
redis_client: RedisClientProtocol | None = None,
|
||||
@ -41,12 +41,7 @@ class RedisUsageTracker(UsageTracker):
|
||||
) -> None:
|
||||
super().__init__(configs=configs, default_config=default_config)
|
||||
if redis_client is None:
|
||||
redis_client = redis.Redis(
|
||||
host=settings.REDIS_HOST,
|
||||
port=int(settings.REDIS_PORT),
|
||||
db=int(settings.REDIS_DB),
|
||||
decode_responses=False,
|
||||
)
|
||||
redis_client = redis.Redis.from_url(settings.REDIS_URL)
|
||||
self._redis = redis_client
|
||||
prefix = key_prefix or settings.LLM_USAGE_REDIS_PREFIX
|
||||
self._key_prefix = prefix.rstrip(":")
|
||||
|
||||
@ -7,6 +7,8 @@ from datetime import datetime, timedelta, timezone
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
from memory.common import settings
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RateLimitConfig:
|
||||
@ -111,12 +113,14 @@ class UsageBreakdown:
|
||||
|
||||
def split_model_key(model: str) -> tuple[str, str]:
    """Split a ``"<provider>/<model_name>"`` key into its two parts.

    Only the first ``/`` acts as the separator, so the model name may itself
    contain slashes.

    Args:
        model: Key of the form ``"<provider>/<model_name>"``.

    Returns:
        A ``(provider, model_name)`` tuple.

    Raises:
        ValueError: If ``model`` contains no ``/`` or either part is empty.
            The offending value is included in the message.
    """
    if "/" not in model:
        raise ValueError(
            f"model must be formatted as '<provider>/<model_name>': got '{model}'"
        )

    provider, model_name = model.split("/", maxsplit=1)
    if not provider or not model_name:
        raise ValueError(
            f"model must include both provider and model name separated by '/': got '{model}'"
        )
    return provider, model_name
|
||||
|
||||
@ -126,11 +130,15 @@ class UsageTracker:
|
||||
|
||||
def __init__(
    self,
    configs: dict[str, RateLimitConfig] | None = None,
    default_config: RateLimitConfig | None = None,
) -> None:
    """Set up the tracker's rate-limit configuration and lock.

    Args:
        configs: Rate-limit configs keyed by model key. When omitted, an
            empty mapping is used.
        default_config: Config used for models without an entry in
            ``configs``. When omitted, one is built from the
            ``DEFAULT_LLM_RATE_LIMIT_*`` settings.
    """
    # Explicit None checks: a caller-supplied (possibly empty) dict is kept
    # as-is rather than being swapped for a fresh one.
    self._configs = configs if configs is not None else {}
    if default_config is None:
        default_config = RateLimitConfig(
            window=timedelta(minutes=settings.DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES),
            max_input_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS,
            max_output_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS,
        )
    self._default_config = default_config
    self._lock = Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@ -213,15 +221,14 @@ class UsageTracker:
|
||||
"""
|
||||
|
||||
split_model_key(model)
|
||||
key = model
|
||||
with self._lock:
|
||||
config = self._get_config(key)
|
||||
config = self._get_config(model)
|
||||
if config is None:
|
||||
return None
|
||||
|
||||
state = self.get_state(key)
|
||||
state = self.get_state(model)
|
||||
self._prune_expired_events(state, config, now=timestamp)
|
||||
self.save_state(key, state)
|
||||
self.save_state(model, state)
|
||||
|
||||
if config.max_total_tokens is None:
|
||||
total_remaining = None
|
||||
@ -253,8 +260,8 @@ class UsageTracker:
|
||||
|
||||
with self._lock:
|
||||
providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
|
||||
for key, state in self.iter_state_items():
|
||||
prov, model_name = split_model_key(key)
|
||||
for model, state in self.iter_state_items():
|
||||
prov, model_name = split_model_key(model)
|
||||
if provider and provider != prov:
|
||||
continue
|
||||
if model and model != model_name:
|
||||
@ -296,8 +303,8 @@ class UsageTracker:
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _get_config(self, model: str) -> RateLimitConfig | None:
    """Return the rate-limit config for ``model``, or the default config.

    Args:
        model: Model key to look up in the per-model configs.

    Returns:
        The config registered for ``model`` if there is one, otherwise the
        tracker's default config.
    """
    config = self._configs.get(model)
    # Fall back only when no config is registered at all.
    return self._default_config if config is None else config
|
||||
|
||||
def _prune_expired_events(
|
||||
self,
|
||||
|
||||
@ -31,34 +31,25 @@ def make_db_url(
|
||||
|
||||
DB_URL = os.getenv("DATABASE_URL", make_db_url())
|
||||
|
||||
# Redis settings
from urllib.parse import quote as _quote  # used only to escape the password below

REDIS_HOST = os.getenv("REDIS_HOST", "redis")
REDIS_PORT = os.getenv("REDIS_PORT", "6379")
REDIS_DB = os.getenv("REDIS_DB", "0")
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", None)
# Percent-encode the password so characters such as "@", "/" or "#" cannot be
# misread as URL structure. redis-py's from_url() unquotes it again.
# NOTE(review): confirm any other consumer of REDIS_URL also unquotes it.
_redis_auth = f":{_quote(REDIS_PASSWORD, safe='')}@" if REDIS_PASSWORD else ""
REDIS_URL = f"redis://{_redis_auth}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"

# Broker settings
CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory")
CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower()
CELERY_BROKER_USER = os.getenv("CELERY_BROKER_USER", "")
# Falls back to the Redis password, which may be None.
CELERY_BROKER_PASSWORD = os.getenv("CELERY_BROKER_PASSWORD", REDIS_PASSWORD)

# An unset or empty CELERY_BROKER_HOST falls back to the Redis host and port.
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "") or f"{REDIS_HOST}:{REDIS_PORT}"
|
||||
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
|
||||
|
||||
|
||||
# File storage settings
|
||||
FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files"))
|
||||
EBOOK_STORAGE_DIR = pathlib.Path(
|
||||
@ -149,6 +140,18 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-haiku-4-5")
|
||||
RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
|
||||
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
|
||||
|
||||
# Default LLM rate-limit settings. Defaults are given as strings, like the
# other environment-backed settings, and converted once with int().
DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES = int(
    os.getenv("DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES", "30")
)
DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS = int(
    os.getenv("DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS", "1000000")
)
DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS = int(
    os.getenv("DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS", "1000000")
)
# Key prefix for LLM usage data stored in Redis.
LLM_USAGE_REDIS_PREFIX = os.getenv("LLM_USAGE_REDIS_PREFIX", "llm_usage")
|
||||
|
||||
|
||||
# Search settings
|
||||
ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
|
||||
ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
|
||||
|
||||
@ -203,7 +203,7 @@ class MessageCollector(commands.Bot):
|
||||
async def setup_hook(self):
    """Register slash commands, prefixed with the bot's own user name.

    The user name is normalised before use because slash command names
    must be lowercase, at most 32 characters long, and limited to word
    characters and ``-``. If no usable name remains (or the bot user is
    not available yet) the ``"memory"`` prefix is used instead.
    """
    import re  # local import: only needed to normalise the prefix

    raw_name = self.user.name if self.user else ""
    # NOTE(review): the 19-character cap assumes "_show_summary" (13 chars)
    # is the longest suffix register_slash_commands appends - confirm.
    prefix = re.sub(r"[^\w-]+", "_", raw_name.lower()).strip("_-")[:19]
    register_slash_commands(self, name=prefix or "memory")
|
||||
|
||||
async def on_ready(self):
|
||||
"""Called when bot connects to Discord"""
|
||||
|
||||
@ -41,8 +41,13 @@ class CommandContext:
|
||||
CommandHandler = Callable[..., CommandResponse]
|
||||
|
||||
|
||||
def register_slash_commands(bot: discord.Client) -> None:
|
||||
"""Register the collector slash commands on the provided bot."""
|
||||
def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
|
||||
"""Register the collector slash commands on the provided bot.
|
||||
|
||||
Args:
|
||||
bot: Discord bot client
|
||||
name: Prefix for command names (e.g., "memory" creates "memory_prompt")
|
||||
"""
|
||||
|
||||
if getattr(bot, "_memory_commands_registered", False):
|
||||
return
|
||||
@ -54,12 +59,14 @@ def register_slash_commands(bot: discord.Client) -> None:
|
||||
|
||||
tree = bot.tree
|
||||
|
||||
@tree.command(name="memory_prompt", description="Show the current system prompt")
|
||||
@tree.command(
|
||||
name=f"{name}_show_prompt", description="Show the current system prompt"
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def prompt_command(
|
||||
async def show_prompt_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
user: discord.User | None = None,
|
||||
@ -72,12 +79,35 @@ def register_slash_commands(bot: discord.Client) -> None:
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name="memory_chattiness",
|
||||
description="Show or update the chattiness threshold for the target",
|
||||
name=f"{name}_set_prompt",
|
||||
description="Set the system prompt for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to modify",
|
||||
prompt="The system prompt to set",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def set_prompt_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
prompt: str,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_set_prompt,
|
||||
target_user=user,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name=f"{name}_chattiness",
|
||||
description="Show or update the chattiness for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
value="Optional new threshold value between 0 and 100",
|
||||
value="Optional new chattiness value between 0 and 100",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def chattiness_command(
|
||||
@ -95,7 +125,7 @@ def register_slash_commands(bot: discord.Client) -> None:
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name="memory_ignore",
|
||||
name=f"{name}_ignore",
|
||||
description="Toggle whether the bot should ignore messages for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
@ -117,7 +147,10 @@ def register_slash_commands(bot: discord.Client) -> None:
|
||||
ignore_enabled=enabled,
|
||||
)
|
||||
|
||||
@tree.command(name="memory_summary", description="Show the stored summary for the target")
|
||||
@tree.command(
|
||||
name=f"{name}_show_summary",
|
||||
description="Show the stored summary for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
user="Target user when the scope is 'user'",
|
||||
@ -337,6 +370,18 @@ def handle_prompt(context: CommandContext) -> CommandResponse:
|
||||
)
|
||||
|
||||
|
||||
def handle_set_prompt(
    context: CommandContext,
    *,
    prompt: str,
) -> CommandResponse:
    """Set the system prompt on the command's target.

    Args:
        context: Command context; its ``target`` receives the new prompt and
            its ``display_name`` is used in the confirmation message.
        prompt: The system prompt text to store.

    Returns:
        A response confirming the update.
    """
    # NOTE(review): only the attribute is assigned here; persisting the
    # change is presumably handled by the caller - confirm.
    setattr(context.target, "system_prompt", prompt)

    return CommandResponse(
        content=f"Updated system prompt for {context.display_name}.",
    )
|
||||
|
||||
|
||||
def handle_chattiness(
|
||||
context: CommandContext,
|
||||
*,
|
||||
@ -347,20 +392,22 @@ def handle_chattiness(
|
||||
if value is None:
|
||||
return CommandResponse(
|
||||
content=(
|
||||
f"Chattiness threshold for {context.display_name}: "
|
||||
f"Chattiness for {context.display_name}: "
|
||||
f"{getattr(model, 'chattiness_threshold', 'not set')}"
|
||||
)
|
||||
)
|
||||
|
||||
if not 0 <= value <= 100:
|
||||
raise CommandError("Chattiness threshold must be between 0 and 100.")
|
||||
raise CommandError("Chattiness must be between 0 and 100.")
|
||||
|
||||
setattr(model, "chattiness_threshold", value)
|
||||
|
||||
return CommandResponse(
|
||||
content=(
|
||||
f"Updated chattiness threshold for {context.display_name} "
|
||||
f"to {value}."
|
||||
f"Updated chattiness for {context.display_name} to {value}."
|
||||
"\n"
|
||||
"This can be treated as how much you want the bot to pipe up by itself, as a percentage, "
|
||||
"where 0 is never and 100 is always."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -128,12 +128,19 @@ def should_process(message: DiscordMessage) -> bool:
|
||||
"update_server_summary",
|
||||
],
|
||||
)
|
||||
print("response", response)
|
||||
if not response:
|
||||
return False
|
||||
if not (res := re.search(r"<number>(.*)</number>", response)):
|
||||
return False
|
||||
try:
|
||||
return int(res.group(1)) > message.chattiness_threshold
|
||||
print(
|
||||
"parsed",
|
||||
int(res.group(1)),
|
||||
message.chattiness_threshold,
|
||||
100 - message.chattiness_threshold,
|
||||
)
|
||||
return int(res.group(1)) > 100 - message.chattiness_threshold
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user