mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
fix bugs
This commit is contained in:
parent
814090dccb
commit
57145ac7b4
@ -57,9 +57,17 @@ RUN_SCHEDULED_CALLS = f"{SCHEDULED_CALLS_ROOT}.run_scheduled_calls"
|
|||||||
def get_broker_url() -> str:
|
def get_broker_url() -> str:
|
||||||
protocol = settings.CELERY_BROKER_TYPE
|
protocol = settings.CELERY_BROKER_TYPE
|
||||||
user = safequote(settings.CELERY_BROKER_USER)
|
user = safequote(settings.CELERY_BROKER_USER)
|
||||||
password = safequote(settings.CELERY_BROKER_PASSWORD)
|
password = safequote(settings.CELERY_BROKER_PASSWORD or "")
|
||||||
host = settings.CELERY_BROKER_HOST
|
host = settings.CELERY_BROKER_HOST
|
||||||
return f"{protocol}://{user}:{password}@{host}"
|
|
||||||
|
if password:
|
||||||
|
url = f"{protocol}://{user}:{password}@{host}"
|
||||||
|
else:
|
||||||
|
url = f"{protocol}://{host}"
|
||||||
|
|
||||||
|
if protocol == "redis":
|
||||||
|
url += f"/{settings.REDIS_DB}"
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
app = Celery(
|
app = Celery(
|
||||||
|
|||||||
@ -35,8 +35,7 @@ class MessageProcessor:
|
|||||||
)
|
)
|
||||||
chattiness_threshold = Column(
|
chattiness_threshold = Column(
|
||||||
Integer,
|
Integer,
|
||||||
nullable=False,
|
nullable=True,
|
||||||
default=50,
|
|
||||||
doc="The threshold for the bot to continue the conversation, between 0 and 100.",
|
doc="The threshold for the bot to continue the conversation, between 0 and 100.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -138,6 +138,8 @@ class DiscordBotUser(BotUser):
|
|||||||
email: str,
|
email: str,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
) -> "DiscordBotUser":
|
) -> "DiscordBotUser":
|
||||||
|
if not discord_users:
|
||||||
|
raise ValueError("discord_users must be provided")
|
||||||
bot = super().create_with_api_key(name, email, api_key)
|
bot = super().create_with_api_key(name, email, api_key)
|
||||||
bot.discord_users = discord_users
|
bot.discord_users = discord_users
|
||||||
return bot
|
return bot
|
||||||
|
|||||||
@ -23,6 +23,8 @@ logger = logging.getLogger(__name__)
|
|||||||
class AnthropicProvider(BaseLLMProvider):
|
class AnthropicProvider(BaseLLMProvider):
|
||||||
"""Anthropic LLM provider with streaming, tool support, and extended thinking."""
|
"""Anthropic LLM provider with streaming, tool support, and extended thinking."""
|
||||||
|
|
||||||
|
provider = "anthropic"
|
||||||
|
|
||||||
# Models that support extended thinking
|
# Models that support extended thinking
|
||||||
THINKING_MODELS = {
|
THINKING_MODELS = {
|
||||||
"claude-opus-4",
|
"claude-opus-4",
|
||||||
@ -262,7 +264,7 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
Usage(
|
Usage(
|
||||||
input_tokens=usage.input_tokens,
|
input_tokens=usage.input_tokens,
|
||||||
output_tokens=usage.output_tokens,
|
output_tokens=usage.output_tokens,
|
||||||
total_tokens=usage.total_tokens,
|
total_tokens=usage.input_tokens + usage.output_tokens,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -205,6 +205,8 @@ class LLMSettings:
|
|||||||
class BaseLLMProvider(ABC):
|
class BaseLLMProvider(ABC):
|
||||||
"""Base class for LLM providers."""
|
"""Base class for LLM providers."""
|
||||||
|
|
||||||
|
provider: str = ""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, api_key: str, model: str, usage_tracker: UsageTracker | None = None
|
self, api_key: str, model: str, usage_tracker: UsageTracker | None = None
|
||||||
):
|
):
|
||||||
@ -234,8 +236,14 @@ class BaseLLMProvider(ABC):
|
|||||||
|
|
||||||
def log_usage(self, usage: Usage):
|
def log_usage(self, usage: Usage):
|
||||||
"""Log usage data."""
|
"""Log usage data."""
|
||||||
logger.debug(f"Token usage: {usage.to_dict()}")
|
logger.debug(
|
||||||
print(f"Token usage: {usage.to_dict()}")
|
f"Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.total_tokens} total"
|
||||||
|
)
|
||||||
|
self.usage_tracker.record_usage(
|
||||||
|
model=f"{self.provider}/{self.model}",
|
||||||
|
input_tokens=usage.input_tokens,
|
||||||
|
output_tokens=usage.output_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
def execute_tool(
|
def execute_tool(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -25,6 +25,8 @@ logger = logging.getLogger(__name__)
|
|||||||
class OpenAIProvider(BaseLLMProvider):
|
class OpenAIProvider(BaseLLMProvider):
|
||||||
"""OpenAI LLM provider with streaming and tool support."""
|
"""OpenAI LLM provider with streaming and tool support."""
|
||||||
|
|
||||||
|
provider = "openai"
|
||||||
|
|
||||||
# Models that use max_completion_tokens instead of max_tokens
|
# Models that use max_completion_tokens instead of max_tokens
|
||||||
# These are reasoning models with different parameter requirements
|
# These are reasoning models with different parameter requirements
|
||||||
NON_REASONING_MODELS = {"gpt-4o"}
|
NON_REASONING_MODELS = {"gpt-4o"}
|
||||||
|
|||||||
@ -33,7 +33,7 @@ class RedisUsageTracker(UsageTracker):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
configs: dict[str, RateLimitConfig],
|
configs: dict[str, RateLimitConfig] | None = None,
|
||||||
default_config: RateLimitConfig | None = None,
|
default_config: RateLimitConfig | None = None,
|
||||||
*,
|
*,
|
||||||
redis_client: RedisClientProtocol | None = None,
|
redis_client: RedisClientProtocol | None = None,
|
||||||
@ -41,12 +41,7 @@ class RedisUsageTracker(UsageTracker):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(configs=configs, default_config=default_config)
|
super().__init__(configs=configs, default_config=default_config)
|
||||||
if redis_client is None:
|
if redis_client is None:
|
||||||
redis_client = redis.Redis(
|
redis_client = redis.Redis.from_url(settings.REDIS_URL)
|
||||||
host=settings.REDIS_HOST,
|
|
||||||
port=int(settings.REDIS_PORT),
|
|
||||||
db=int(settings.REDIS_DB),
|
|
||||||
decode_responses=False,
|
|
||||||
)
|
|
||||||
self._redis = redis_client
|
self._redis = redis_client
|
||||||
prefix = key_prefix or settings.LLM_USAGE_REDIS_PREFIX
|
prefix = key_prefix or settings.LLM_USAGE_REDIS_PREFIX
|
||||||
self._key_prefix = prefix.rstrip(":")
|
self._key_prefix = prefix.rstrip(":")
|
||||||
|
|||||||
@ -7,6 +7,8 @@ from datetime import datetime, timedelta, timezone
|
|||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from memory.common import settings
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class RateLimitConfig:
|
class RateLimitConfig:
|
||||||
@ -111,12 +113,14 @@ class UsageBreakdown:
|
|||||||
|
|
||||||
def split_model_key(model: str) -> tuple[str, str]:
|
def split_model_key(model: str) -> tuple[str, str]:
|
||||||
if "/" not in model:
|
if "/" not in model:
|
||||||
raise ValueError("model must be formatted as '<provider>/<model_name>'")
|
raise ValueError(
|
||||||
|
f"model must be formatted as '<provider>/<model_name>': got '{model}'"
|
||||||
|
)
|
||||||
|
|
||||||
provider, model_name = model.split("/", maxsplit=1)
|
provider, model_name = model.split("/", maxsplit=1)
|
||||||
if not provider or not model_name:
|
if not provider or not model_name:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"model must include both provider and model name separated by '/'"
|
f"model must include both provider and model name separated by '/': got '{model}'"
|
||||||
)
|
)
|
||||||
return provider, model_name
|
return provider, model_name
|
||||||
|
|
||||||
@ -126,11 +130,15 @@ class UsageTracker:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
configs: dict[str, RateLimitConfig],
|
configs: dict[str, RateLimitConfig] | None = None,
|
||||||
default_config: RateLimitConfig | None = None,
|
default_config: RateLimitConfig | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._configs = configs
|
self._configs = configs or {}
|
||||||
self._default_config = default_config
|
self._default_config = default_config or RateLimitConfig(
|
||||||
|
window=timedelta(minutes=settings.DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES),
|
||||||
|
max_input_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS,
|
||||||
|
max_output_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS,
|
||||||
|
)
|
||||||
self._lock = Lock()
|
self._lock = Lock()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@ -213,15 +221,14 @@ class UsageTracker:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
split_model_key(model)
|
split_model_key(model)
|
||||||
key = model
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
config = self._get_config(key)
|
config = self._get_config(model)
|
||||||
if config is None:
|
if config is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
state = self.get_state(key)
|
state = self.get_state(model)
|
||||||
self._prune_expired_events(state, config, now=timestamp)
|
self._prune_expired_events(state, config, now=timestamp)
|
||||||
self.save_state(key, state)
|
self.save_state(model, state)
|
||||||
|
|
||||||
if config.max_total_tokens is None:
|
if config.max_total_tokens is None:
|
||||||
total_remaining = None
|
total_remaining = None
|
||||||
@ -253,8 +260,8 @@ class UsageTracker:
|
|||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
|
providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
|
||||||
for key, state in self.iter_state_items():
|
for model, state in self.iter_state_items():
|
||||||
prov, model_name = split_model_key(key)
|
prov, model_name = split_model_key(model)
|
||||||
if provider and provider != prov:
|
if provider and provider != prov:
|
||||||
continue
|
continue
|
||||||
if model and model != model_name:
|
if model and model != model_name:
|
||||||
@ -296,8 +303,8 @@ class UsageTracker:
|
|||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Internal helpers
|
# Internal helpers
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
def _get_config(self, key: str) -> RateLimitConfig | None:
|
def _get_config(self, model: str) -> RateLimitConfig | None:
|
||||||
return self._configs.get(key) or self._default_config
|
return self._configs.get(model) or self._default_config
|
||||||
|
|
||||||
def _prune_expired_events(
|
def _prune_expired_events(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -31,34 +31,25 @@ def make_db_url(
|
|||||||
|
|
||||||
DB_URL = os.getenv("DATABASE_URL", make_db_url())
|
DB_URL = os.getenv("DATABASE_URL", make_db_url())
|
||||||
|
|
||||||
|
# Redis settings
|
||||||
|
REDIS_HOST = os.getenv("REDIS_HOST", "redis")
|
||||||
|
REDIS_PORT = os.getenv("REDIS_PORT", "6379")
|
||||||
|
REDIS_DB = os.getenv("REDIS_DB", "0")
|
||||||
|
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", None)
|
||||||
|
if REDIS_PASSWORD:
|
||||||
|
REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
|
||||||
|
else:
|
||||||
|
REDIS_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
|
||||||
|
|
||||||
# Broker settings
|
# Broker settings
|
||||||
CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory")
|
CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory")
|
||||||
CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower()
|
CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower()
|
||||||
REDIS_HOST = os.getenv("REDIS_HOST", "redis")
|
CELERY_BROKER_USER = os.getenv("CELERY_BROKER_USER", "")
|
||||||
REDIS_PORT = os.getenv("REDIS_PORT", "6379")
|
CELERY_BROKER_PASSWORD = os.getenv("CELERY_BROKER_PASSWORD", REDIS_PASSWORD)
|
||||||
REDIS_DB = os.getenv("REDIS_DB", "0")
|
|
||||||
LLM_USAGE_REDIS_PREFIX = os.getenv("LLM_USAGE_REDIS_PREFIX", "llm_usage")
|
|
||||||
CELERY_BROKER_USER = os.getenv(
|
|
||||||
"CELERY_BROKER_USER", "kb" if CELERY_BROKER_TYPE == "amqp" else ""
|
|
||||||
)
|
|
||||||
CELERY_BROKER_PASSWORD = os.getenv(
|
|
||||||
"CELERY_BROKER_PASSWORD", "" if CELERY_BROKER_TYPE == "redis" else "kb"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "")
|
|
||||||
if not CELERY_BROKER_HOST:
|
|
||||||
if CELERY_BROKER_TYPE == "amqp":
|
|
||||||
RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
|
|
||||||
RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672")
|
|
||||||
CELERY_BROKER_HOST = f"{RABBITMQ_HOST}:{RABBITMQ_PORT}//"
|
|
||||||
elif CELERY_BROKER_TYPE == "redis":
|
|
||||||
CELERY_BROKER_HOST = f"{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
|
|
||||||
|
|
||||||
|
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "") or f"{REDIS_HOST}:{REDIS_PORT}"
|
||||||
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
|
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
|
||||||
|
|
||||||
|
|
||||||
# File storage settings
|
# File storage settings
|
||||||
FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files"))
|
FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files"))
|
||||||
EBOOK_STORAGE_DIR = pathlib.Path(
|
EBOOK_STORAGE_DIR = pathlib.Path(
|
||||||
@ -149,6 +140,18 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-haiku-4-5")
|
|||||||
RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
|
RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
|
||||||
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
|
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
|
||||||
|
|
||||||
|
DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES = int(
|
||||||
|
os.getenv("DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES", 30)
|
||||||
|
)
|
||||||
|
DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS = int(
|
||||||
|
os.getenv("DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS", 1_000_000)
|
||||||
|
)
|
||||||
|
DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS = int(
|
||||||
|
os.getenv("DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS", 1_000_000)
|
||||||
|
)
|
||||||
|
LLM_USAGE_REDIS_PREFIX = os.getenv("LLM_USAGE_REDIS_PREFIX", "llm_usage")
|
||||||
|
|
||||||
|
|
||||||
# Search settings
|
# Search settings
|
||||||
ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
|
ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
|
||||||
ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
|
ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
|
||||||
|
|||||||
@ -203,7 +203,7 @@ class MessageCollector(commands.Bot):
|
|||||||
async def setup_hook(self):
|
async def setup_hook(self):
|
||||||
"""Register slash commands when the bot is ready."""
|
"""Register slash commands when the bot is ready."""
|
||||||
|
|
||||||
register_slash_commands(self)
|
register_slash_commands(self, name=self.user.name)
|
||||||
|
|
||||||
async def on_ready(self):
|
async def on_ready(self):
|
||||||
"""Called when bot connects to Discord"""
|
"""Called when bot connects to Discord"""
|
||||||
|
|||||||
@ -41,8 +41,13 @@ class CommandContext:
|
|||||||
CommandHandler = Callable[..., CommandResponse]
|
CommandHandler = Callable[..., CommandResponse]
|
||||||
|
|
||||||
|
|
||||||
def register_slash_commands(bot: discord.Client) -> None:
|
def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
|
||||||
"""Register the collector slash commands on the provided bot."""
|
"""Register the collector slash commands on the provided bot.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bot: Discord bot client
|
||||||
|
name: Prefix for command names (e.g., "memory" creates "memory_prompt")
|
||||||
|
"""
|
||||||
|
|
||||||
if getattr(bot, "_memory_commands_registered", False):
|
if getattr(bot, "_memory_commands_registered", False):
|
||||||
return
|
return
|
||||||
@ -54,12 +59,14 @@ def register_slash_commands(bot: discord.Client) -> None:
|
|||||||
|
|
||||||
tree = bot.tree
|
tree = bot.tree
|
||||||
|
|
||||||
@tree.command(name="memory_prompt", description="Show the current system prompt")
|
@tree.command(
|
||||||
|
name=f"{name}_show_prompt", description="Show the current system prompt"
|
||||||
|
)
|
||||||
@discord.app_commands.describe(
|
@discord.app_commands.describe(
|
||||||
scope="Which configuration to inspect",
|
scope="Which configuration to inspect",
|
||||||
user="Target user when the scope is 'user'",
|
user="Target user when the scope is 'user'",
|
||||||
)
|
)
|
||||||
async def prompt_command(
|
async def show_prompt_command(
|
||||||
interaction: discord.Interaction,
|
interaction: discord.Interaction,
|
||||||
scope: ScopeLiteral,
|
scope: ScopeLiteral,
|
||||||
user: discord.User | None = None,
|
user: discord.User | None = None,
|
||||||
@ -72,12 +79,35 @@ def register_slash_commands(bot: discord.Client) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@tree.command(
|
@tree.command(
|
||||||
name="memory_chattiness",
|
name=f"{name}_set_prompt",
|
||||||
description="Show or update the chattiness threshold for the target",
|
description="Set the system prompt for the target",
|
||||||
|
)
|
||||||
|
@discord.app_commands.describe(
|
||||||
|
scope="Which configuration to modify",
|
||||||
|
prompt="The system prompt to set",
|
||||||
|
user="Target user when the scope is 'user'",
|
||||||
|
)
|
||||||
|
async def set_prompt_command(
|
||||||
|
interaction: discord.Interaction,
|
||||||
|
scope: ScopeLiteral,
|
||||||
|
prompt: str,
|
||||||
|
user: discord.User | None = None,
|
||||||
|
) -> None:
|
||||||
|
await _run_interaction_command(
|
||||||
|
interaction,
|
||||||
|
scope=scope,
|
||||||
|
handler=handle_set_prompt,
|
||||||
|
target_user=user,
|
||||||
|
prompt=prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
@tree.command(
|
||||||
|
name=f"{name}_chattiness",
|
||||||
|
description="Show or update the chattiness for the target",
|
||||||
)
|
)
|
||||||
@discord.app_commands.describe(
|
@discord.app_commands.describe(
|
||||||
scope="Which configuration to inspect",
|
scope="Which configuration to inspect",
|
||||||
value="Optional new threshold value between 0 and 100",
|
value="Optional new chattiness value between 0 and 100",
|
||||||
user="Target user when the scope is 'user'",
|
user="Target user when the scope is 'user'",
|
||||||
)
|
)
|
||||||
async def chattiness_command(
|
async def chattiness_command(
|
||||||
@ -95,7 +125,7 @@ def register_slash_commands(bot: discord.Client) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@tree.command(
|
@tree.command(
|
||||||
name="memory_ignore",
|
name=f"{name}_ignore",
|
||||||
description="Toggle whether the bot should ignore messages for the target",
|
description="Toggle whether the bot should ignore messages for the target",
|
||||||
)
|
)
|
||||||
@discord.app_commands.describe(
|
@discord.app_commands.describe(
|
||||||
@ -117,7 +147,10 @@ def register_slash_commands(bot: discord.Client) -> None:
|
|||||||
ignore_enabled=enabled,
|
ignore_enabled=enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
@tree.command(name="memory_summary", description="Show the stored summary for the target")
|
@tree.command(
|
||||||
|
name=f"{name}_show_summary",
|
||||||
|
description="Show the stored summary for the target",
|
||||||
|
)
|
||||||
@discord.app_commands.describe(
|
@discord.app_commands.describe(
|
||||||
scope="Which configuration to inspect",
|
scope="Which configuration to inspect",
|
||||||
user="Target user when the scope is 'user'",
|
user="Target user when the scope is 'user'",
|
||||||
@ -337,6 +370,18 @@ def handle_prompt(context: CommandContext) -> CommandResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_set_prompt(
|
||||||
|
context: CommandContext,
|
||||||
|
*,
|
||||||
|
prompt: str,
|
||||||
|
) -> CommandResponse:
|
||||||
|
setattr(context.target, "system_prompt", prompt)
|
||||||
|
|
||||||
|
return CommandResponse(
|
||||||
|
content=f"Updated system prompt for {context.display_name}.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def handle_chattiness(
|
def handle_chattiness(
|
||||||
context: CommandContext,
|
context: CommandContext,
|
||||||
*,
|
*,
|
||||||
@ -347,20 +392,22 @@ def handle_chattiness(
|
|||||||
if value is None:
|
if value is None:
|
||||||
return CommandResponse(
|
return CommandResponse(
|
||||||
content=(
|
content=(
|
||||||
f"Chattiness threshold for {context.display_name}: "
|
f"Chattiness for {context.display_name}: "
|
||||||
f"{getattr(model, 'chattiness_threshold', 'not set')}"
|
f"{getattr(model, 'chattiness_threshold', 'not set')}"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not 0 <= value <= 100:
|
if not 0 <= value <= 100:
|
||||||
raise CommandError("Chattiness threshold must be between 0 and 100.")
|
raise CommandError("Chattiness must be between 0 and 100.")
|
||||||
|
|
||||||
setattr(model, "chattiness_threshold", value)
|
setattr(model, "chattiness_threshold", value)
|
||||||
|
|
||||||
return CommandResponse(
|
return CommandResponse(
|
||||||
content=(
|
content=(
|
||||||
f"Updated chattiness threshold for {context.display_name} "
|
f"Updated chattiness for {context.display_name} to {value}."
|
||||||
f"to {value}."
|
"\n"
|
||||||
|
"This can be treated as how much you want the bot to pipe up by itself, as a percentage, "
|
||||||
|
"where 0 is never and 100 is always."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -128,12 +128,19 @@ def should_process(message: DiscordMessage) -> bool:
|
|||||||
"update_server_summary",
|
"update_server_summary",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
print("response", response)
|
||||||
if not response:
|
if not response:
|
||||||
return False
|
return False
|
||||||
if not (res := re.search(r"<number>(.*)</number>", response)):
|
if not (res := re.search(r"<number>(.*)</number>", response)):
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
return int(res.group(1)) > message.chattiness_threshold
|
print(
|
||||||
|
"parsed",
|
||||||
|
int(res.group(1)),
|
||||||
|
message.chattiness_threshold,
|
||||||
|
100 - message.chattiness_threshold,
|
||||||
|
)
|
||||||
|
return int(res.group(1)) > 100 - message.chattiness_threshold
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user