mirror of
https://github.com/mruwnik/memory.git
synced 2025-12-16 17:11:19 +01:00
Compare commits
10 Commits
8af07f0dac
...
e95a082147
| Author | SHA1 | Date | |
|---|---|---|---|
| e95a082147 | |||
| c42513100b | |||
| a5bc53326d | |||
| 131427255a | |||
|
|
ff3ca4f109 | ||
| 3b216953ab | |||
|
|
d7e403fb83 | ||
| 57145ac7b4 | |||
|
|
814090dccb | ||
|
|
9639fa3dd7 |
114
db/migrations/versions/20251101_203810_allow_no_chattiness.py
Normal file
114
db/migrations/versions/20251101_203810_allow_no_chattiness.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
"""allow no chattiness
|
||||||
|
|
||||||
|
Revision ID: 2024235e37e7
|
||||||
|
Revises: 7dc03dbf184c
|
||||||
|
Create Date: 2025-11-01 20:38:10.849651
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "2024235e37e7"
|
||||||
|
down_revision: Union[str, None] = "7dc03dbf184c"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.alter_column(
|
||||||
|
"discord_channels",
|
||||||
|
"chattiness_threshold",
|
||||||
|
existing_type=sa.INTEGER(),
|
||||||
|
nullable=True,
|
||||||
|
existing_server_default=sa.text("50"),
|
||||||
|
)
|
||||||
|
op.drop_column("discord_channels", "track_messages")
|
||||||
|
op.alter_column(
|
||||||
|
"discord_servers",
|
||||||
|
"chattiness_threshold",
|
||||||
|
existing_type=sa.INTEGER(),
|
||||||
|
nullable=True,
|
||||||
|
existing_server_default=sa.text("50"),
|
||||||
|
)
|
||||||
|
op.drop_index("discord_servers_active_idx", table_name="discord_servers")
|
||||||
|
op.create_index(
|
||||||
|
"discord_servers_active_idx",
|
||||||
|
"discord_servers",
|
||||||
|
["ignore_messages", "last_sync_at"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.drop_column("discord_servers", "track_messages")
|
||||||
|
op.alter_column(
|
||||||
|
"discord_users",
|
||||||
|
"chattiness_threshold",
|
||||||
|
existing_type=sa.INTEGER(),
|
||||||
|
nullable=True,
|
||||||
|
existing_server_default=sa.text("50"),
|
||||||
|
)
|
||||||
|
op.drop_column("discord_users", "track_messages")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"discord_users",
|
||||||
|
sa.Column(
|
||||||
|
"track_messages",
|
||||||
|
sa.BOOLEAN(),
|
||||||
|
server_default=sa.text("true"),
|
||||||
|
autoincrement=False,
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"discord_users",
|
||||||
|
"chattiness_threshold",
|
||||||
|
existing_type=sa.INTEGER(),
|
||||||
|
nullable=False,
|
||||||
|
existing_server_default=sa.text("50"),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"discord_servers",
|
||||||
|
sa.Column(
|
||||||
|
"track_messages",
|
||||||
|
sa.BOOLEAN(),
|
||||||
|
server_default=sa.text("true"),
|
||||||
|
autoincrement=False,
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.drop_index("discord_servers_active_idx", table_name="discord_servers")
|
||||||
|
op.create_index(
|
||||||
|
"discord_servers_active_idx",
|
||||||
|
"discord_servers",
|
||||||
|
["track_messages", "last_sync_at"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"discord_servers",
|
||||||
|
"chattiness_threshold",
|
||||||
|
existing_type=sa.INTEGER(),
|
||||||
|
nullable=False,
|
||||||
|
existing_server_default=sa.text("50"),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"discord_channels",
|
||||||
|
sa.Column(
|
||||||
|
"track_messages",
|
||||||
|
sa.BOOLEAN(),
|
||||||
|
server_default=sa.text("true"),
|
||||||
|
autoincrement=False,
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"discord_channels",
|
||||||
|
"chattiness_threshold",
|
||||||
|
existing_type=sa.INTEGER(),
|
||||||
|
nullable=False,
|
||||||
|
existing_server_default=sa.text("50"),
|
||||||
|
)
|
||||||
@ -1,5 +1,3 @@
|
|||||||
version: "3.9"
|
|
||||||
|
|
||||||
# --------------------------------------------------------------------- networks
|
# --------------------------------------------------------------------- networks
|
||||||
networks:
|
networks:
|
||||||
kbnet:
|
kbnet:
|
||||||
@ -26,7 +24,7 @@ x-common-env: &env
|
|||||||
REDIS_HOST: redis
|
REDIS_HOST: redis
|
||||||
REDIS_PORT: 6379
|
REDIS_PORT: 6379
|
||||||
REDIS_DB: 0
|
REDIS_DB: 0
|
||||||
CELERY_BROKER_PASSWORD: ${CELERY_BROKER_PASSWORD}
|
REDIS_PASSWORD: ${REDIS_PASSWORD}
|
||||||
QDRANT_HOST: qdrant
|
QDRANT_HOST: qdrant
|
||||||
DB_HOST: postgres
|
DB_HOST: postgres
|
||||||
DB_PORT: 5432
|
DB_PORT: 5432
|
||||||
@ -107,11 +105,11 @@ services:
|
|||||||
image: redis:7.2-alpine
|
image: redis:7.2-alpine
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
networks: [ kbnet ]
|
networks: [ kbnet ]
|
||||||
command: ["redis-server", "--save", "", "--appendonly", "no", "--requirepass", "${CELERY_BROKER_PASSWORD}"]
|
command: ["redis-server", "--save", "", "--appendonly", "no", "--requirepass", "${REDIS_PASSWORD}"]
|
||||||
volumes:
|
volumes:
|
||||||
- redis_data:/data:rw
|
- redis_data:/data:rw
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: [ "CMD", "redis-cli", "--pass", "${CELERY_BROKER_PASSWORD}", "ping" ]
|
test: [ "CMD", "redis-cli", "--pass", "${REDIS_PASSWORD}", "ping" ]
|
||||||
interval: 15s
|
interval: 15s
|
||||||
timeout: 5s
|
timeout: 5s
|
||||||
retries: 5
|
retries: 5
|
||||||
@ -175,7 +173,7 @@ services:
|
|||||||
<<: *worker-base
|
<<: *worker-base
|
||||||
environment:
|
environment:
|
||||||
<<: *worker-env
|
<<: *worker-env
|
||||||
QUEUES: "email,ebooks,discord,comic,blogs,forums,maintenance,notes,scheduler"
|
QUEUES: "backup,email,ebooks,discord,comic,blogs,forums,maintenance,notes,scheduler"
|
||||||
|
|
||||||
ingest-hub:
|
ingest-hub:
|
||||||
<<: *worker-base
|
<<: *worker-base
|
||||||
@ -196,6 +194,22 @@ services:
|
|||||||
- /var/run/supervisor
|
- /var/run/supervisor
|
||||||
deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } }
|
deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } }
|
||||||
|
|
||||||
|
# ------------------------------------------------------------ database backups
|
||||||
|
backup:
|
||||||
|
image: postgres:15 # Has pg_dump, wget, curl
|
||||||
|
networks: [kbnet]
|
||||||
|
depends_on: [postgres, qdrant]
|
||||||
|
env_file: [ .env ]
|
||||||
|
environment:
|
||||||
|
<<: *worker-env
|
||||||
|
secrets: [postgres_password]
|
||||||
|
volumes:
|
||||||
|
- ./tools/backup_databases.sh:/backup.sh:ro
|
||||||
|
entrypoint: ["/bin/bash"]
|
||||||
|
command: ["/backup.sh"]
|
||||||
|
profiles: [backup] # Only start when explicitly called
|
||||||
|
security_opt: ["no-new-privileges=true"]
|
||||||
|
|
||||||
# ------------------------------------------------------------ watchtower (auto-update)
|
# ------------------------------------------------------------ watchtower (auto-update)
|
||||||
# watchtower:
|
# watchtower:
|
||||||
# image: containrrr/watchtower
|
# image: containrrr/watchtower
|
||||||
|
|||||||
@ -16,7 +16,7 @@ RUN apt-get update && apt-get install -y \
|
|||||||
COPY requirements ./requirements/
|
COPY requirements ./requirements/
|
||||||
COPY setup.py ./
|
COPY setup.py ./
|
||||||
RUN mkdir src
|
RUN mkdir src
|
||||||
RUN pip install -e ".[common]"
|
RUN pip install -e ".[workers]"
|
||||||
|
|
||||||
# Install Python dependencies
|
# Install Python dependencies
|
||||||
COPY src/ ./src/
|
COPY src/ ./src/
|
||||||
@ -44,7 +44,7 @@ RUN git config --global user.email "${GIT_USER_EMAIL}" && \
|
|||||||
git config --global user.name "${GIT_USER_NAME}"
|
git config --global user.name "${GIT_USER_NAME}"
|
||||||
|
|
||||||
# Default queues to process
|
# Default queues to process
|
||||||
ENV QUEUES="ebooks,email,discord,comic,blogs,forums,photo_embed,maintenance"
|
ENV QUEUES="backup,ebooks,email,discord,comic,blogs,forums,photo_embed,maintenance"
|
||||||
ENV PYTHONPATH="/app"
|
ENV PYTHONPATH="/app"
|
||||||
|
|
||||||
ENTRYPOINT ["./entry.sh"]
|
ENTRYPOINT ["./entry.sh"]
|
||||||
@ -10,3 +10,4 @@ openai==2.3.0
|
|||||||
# Pin the httpx version, as newer versions break the anthropic client
|
# Pin the httpx version, as newer versions break the anthropic client
|
||||||
httpx==0.27.0
|
httpx==0.27.0
|
||||||
celery[redis,sqs]==5.3.6
|
celery[redis,sqs]==5.3.6
|
||||||
|
cryptography==43.0.0
|
||||||
2
requirements/requirements-workers.txt
Normal file
2
requirements/requirements-workers.txt
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
boto3
|
||||||
|
awscli==1.42.64
|
||||||
5
setup.py
5
setup.py
@ -18,6 +18,7 @@ parsers_requires = read_requirements("requirements-parsers.txt")
|
|||||||
api_requires = read_requirements("requirements-api.txt")
|
api_requires = read_requirements("requirements-api.txt")
|
||||||
dev_requires = read_requirements("requirements-dev.txt")
|
dev_requires = read_requirements("requirements-dev.txt")
|
||||||
ingesters_requires = read_requirements("requirements-ingesters.txt")
|
ingesters_requires = read_requirements("requirements-ingesters.txt")
|
||||||
|
workers_requires = read_requirements("requirements-workers.txt")
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="memory",
|
name="memory",
|
||||||
@ -30,10 +31,12 @@ setup(
|
|||||||
"common": common_requires + parsers_requires,
|
"common": common_requires + parsers_requires,
|
||||||
"dev": dev_requires,
|
"dev": dev_requires,
|
||||||
"ingesters": common_requires + parsers_requires + ingesters_requires,
|
"ingesters": common_requires + parsers_requires + ingesters_requires,
|
||||||
|
"workers": common_requires + parsers_requires + workers_requires,
|
||||||
"all": api_requires
|
"all": api_requires
|
||||||
+ common_requires
|
+ common_requires
|
||||||
+ dev_requires
|
+ dev_requires
|
||||||
+ parsers_requires
|
+ parsers_requires
|
||||||
+ ingesters_requires,
|
+ ingesters_requires
|
||||||
|
+ workers_requires,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@ -290,6 +290,7 @@ class ScheduledLLMCallAdmin(ModelView, model=ScheduledLLMCall):
|
|||||||
"created_at",
|
"created_at",
|
||||||
"updated_at",
|
"updated_at",
|
||||||
]
|
]
|
||||||
|
column_sortable_list = ["executed_at", "scheduled_time", "created_at", "updated_at"]
|
||||||
|
|
||||||
|
|
||||||
def setup_admin(admin: Admin):
|
def setup_admin(admin: Admin):
|
||||||
|
|||||||
@ -13,6 +13,7 @@ NOTES_ROOT = "memory.workers.tasks.notes"
|
|||||||
OBSERVATIONS_ROOT = "memory.workers.tasks.observations"
|
OBSERVATIONS_ROOT = "memory.workers.tasks.observations"
|
||||||
SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
|
SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
|
||||||
DISCORD_ROOT = "memory.workers.tasks.discord"
|
DISCORD_ROOT = "memory.workers.tasks.discord"
|
||||||
|
BACKUP_ROOT = "memory.workers.tasks.backup"
|
||||||
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
|
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
|
||||||
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
|
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
|
||||||
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
|
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
|
||||||
@ -53,13 +54,25 @@ SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive"
|
|||||||
EXECUTE_SCHEDULED_CALL = f"{SCHEDULED_CALLS_ROOT}.execute_scheduled_call"
|
EXECUTE_SCHEDULED_CALL = f"{SCHEDULED_CALLS_ROOT}.execute_scheduled_call"
|
||||||
RUN_SCHEDULED_CALLS = f"{SCHEDULED_CALLS_ROOT}.run_scheduled_calls"
|
RUN_SCHEDULED_CALLS = f"{SCHEDULED_CALLS_ROOT}.run_scheduled_calls"
|
||||||
|
|
||||||
|
# Backup tasks
|
||||||
|
BACKUP_PATH = f"{BACKUP_ROOT}.backup_path"
|
||||||
|
BACKUP_ALL = f"{BACKUP_ROOT}.backup_all"
|
||||||
|
|
||||||
|
|
||||||
def get_broker_url() -> str:
|
def get_broker_url() -> str:
|
||||||
protocol = settings.CELERY_BROKER_TYPE
|
protocol = settings.CELERY_BROKER_TYPE
|
||||||
user = safequote(settings.CELERY_BROKER_USER)
|
user = safequote(settings.CELERY_BROKER_USER)
|
||||||
password = safequote(settings.CELERY_BROKER_PASSWORD)
|
password = safequote(settings.CELERY_BROKER_PASSWORD or "")
|
||||||
host = settings.CELERY_BROKER_HOST
|
host = settings.CELERY_BROKER_HOST
|
||||||
return f"{protocol}://{user}:{password}@{host}"
|
|
||||||
|
if password:
|
||||||
|
url = f"{protocol}://{user}:{password}@{host}"
|
||||||
|
else:
|
||||||
|
url = f"{protocol}://{host}"
|
||||||
|
|
||||||
|
if protocol == "redis":
|
||||||
|
url += f"/{settings.REDIS_DB}"
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
app = Celery(
|
app = Celery(
|
||||||
@ -91,6 +104,7 @@ app.conf.update(
|
|||||||
f"{SCHEDULED_CALLS_ROOT}.*": {
|
f"{SCHEDULED_CALLS_ROOT}.*": {
|
||||||
"queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler"
|
"queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler"
|
||||||
},
|
},
|
||||||
|
f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -22,7 +22,6 @@ from memory.common.db.models.base import Base
|
|||||||
|
|
||||||
|
|
||||||
class MessageProcessor:
|
class MessageProcessor:
|
||||||
track_messages = Column(Boolean, nullable=False, server_default="true")
|
|
||||||
ignore_messages = Column(Boolean, nullable=True, default=False)
|
ignore_messages = Column(Boolean, nullable=True, default=False)
|
||||||
|
|
||||||
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
@ -35,8 +34,7 @@ class MessageProcessor:
|
|||||||
)
|
)
|
||||||
chattiness_threshold = Column(
|
chattiness_threshold = Column(
|
||||||
Integer,
|
Integer,
|
||||||
nullable=False,
|
nullable=True,
|
||||||
default=50,
|
|
||||||
doc="The threshold for the bot to continue the conversation, between 0 and 100.",
|
doc="The threshold for the bot to continue the conversation, between 0 and 100.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -90,7 +88,7 @@ class DiscordServer(Base, MessageProcessor):
|
|||||||
)
|
)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("discord_servers_active_idx", "track_messages", "last_sync_at"),
|
Index("discord_servers_active_idx", "ignore_messages", "last_sync_at"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -138,6 +138,8 @@ class DiscordBotUser(BotUser):
|
|||||||
email: str,
|
email: str,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
) -> "DiscordBotUser":
|
) -> "DiscordBotUser":
|
||||||
|
if not discord_users:
|
||||||
|
raise ValueError("discord_users must be provided")
|
||||||
bot = super().create_with_api_key(name, email, api_key)
|
bot = super().create_with_api_key(name, email, api_key)
|
||||||
bot.discord_users = discord_users
|
bot.discord_users = discord_users
|
||||||
return bot
|
return bot
|
||||||
|
|||||||
@ -5,9 +5,10 @@ Simple HTTP client that communicates with the Discord collector's API server.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import requests
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -20,12 +21,12 @@ def get_api_url() -> str:
|
|||||||
return f"http://{host}:{port}"
|
return f"http://{host}:{port}"
|
||||||
|
|
||||||
|
|
||||||
def send_dm(user_identifier: str, message: str) -> bool:
|
def send_dm(bot_id: int, user_identifier: str, message: str) -> bool:
|
||||||
"""Send a DM via the Discord collector API"""
|
"""Send a DM via the Discord collector API"""
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{get_api_url()}/send_dm",
|
f"{get_api_url()}/send_dm",
|
||||||
json={"user": user_identifier, "message": message},
|
json={"bot_id": bot_id, "user": user_identifier, "message": message},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@ -37,12 +38,33 @@ def send_dm(user_identifier: str, message: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def send_to_channel(channel_name: str, message: str) -> bool:
|
def trigger_typing_dm(bot_id: int, user_identifier: int | str) -> bool:
|
||||||
|
"""Trigger typing indicator for a DM via the Discord collector API"""
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
f"{get_api_url()}/typing/dm",
|
||||||
|
json={"bot_id": bot_id, "user": user_identifier},
|
||||||
|
timeout=10,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
return result.get("success", False)
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool:
|
||||||
"""Send a DM via the Discord collector API"""
|
"""Send a DM via the Discord collector API"""
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{get_api_url()}/send_channel",
|
f"{get_api_url()}/send_channel",
|
||||||
json={"channel_name": channel_name, "message": message},
|
json={
|
||||||
|
"bot_id": bot_id,
|
||||||
|
"channel_name": channel_name,
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@ -55,12 +77,33 @@ def send_to_channel(channel_name: str, message: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def broadcast_message(channel_name: str, message: str) -> bool:
|
def trigger_typing_channel(bot_id: int, channel_name: str) -> bool:
|
||||||
|
"""Trigger typing indicator for a channel via the Discord collector API"""
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
f"{get_api_url()}/typing/channel",
|
||||||
|
json={"bot_id": bot_id, "channel_name": channel_name},
|
||||||
|
timeout=10,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
return result.get("success", False)
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool:
|
||||||
"""Send a message to a channel via the Discord collector API"""
|
"""Send a message to a channel via the Discord collector API"""
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{get_api_url()}/send_channel",
|
f"{get_api_url()}/send_channel",
|
||||||
json={"channel_name": channel_name, "message": message},
|
json={
|
||||||
|
"bot_id": bot_id,
|
||||||
|
"channel_name": channel_name,
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@ -72,19 +115,22 @@ def broadcast_message(channel_name: str, message: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_collector_healthy() -> bool:
|
def is_collector_healthy(bot_id: int) -> bool:
|
||||||
"""Check if the Discord collector is running and healthy"""
|
"""Check if the Discord collector is running and healthy"""
|
||||||
try:
|
try:
|
||||||
response = requests.get(f"{get_api_url()}/health", timeout=5)
|
response = requests.get(f"{get_api_url()}/health", timeout=5)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result = response.json()
|
result = response.json()
|
||||||
return result.get("status") == "healthy"
|
bot_status = result.get(str(bot_id))
|
||||||
|
if not isinstance(bot_status, dict):
|
||||||
|
return False
|
||||||
|
return bool(bot_status.get("connected"))
|
||||||
|
|
||||||
except requests.RequestException:
|
except requests.RequestException:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def refresh_discord_metadata() -> dict[str, int] | None:
|
def refresh_discord_metadata() -> dict[str, Any] | None:
|
||||||
"""Refresh Discord server/channel/user metadata from Discord API"""
|
"""Refresh Discord server/channel/user metadata from Discord API"""
|
||||||
try:
|
try:
|
||||||
response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30)
|
response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30)
|
||||||
@ -96,24 +142,24 @@ def refresh_discord_metadata() -> dict[str, int] | None:
|
|||||||
|
|
||||||
|
|
||||||
# Convenience functions
|
# Convenience functions
|
||||||
def send_error_message(message: str) -> bool:
|
def send_error_message(bot_id: int, message: str) -> bool:
|
||||||
"""Send an error message to the error channel"""
|
"""Send an error message to the error channel"""
|
||||||
return broadcast_message(settings.DISCORD_ERROR_CHANNEL, message)
|
return broadcast_message(bot_id, settings.DISCORD_ERROR_CHANNEL, message)
|
||||||
|
|
||||||
|
|
||||||
def send_activity_message(message: str) -> bool:
|
def send_activity_message(bot_id: int, message: str) -> bool:
|
||||||
"""Send an activity message to the activity channel"""
|
"""Send an activity message to the activity channel"""
|
||||||
return broadcast_message(settings.DISCORD_ACTIVITY_CHANNEL, message)
|
return broadcast_message(bot_id, settings.DISCORD_ACTIVITY_CHANNEL, message)
|
||||||
|
|
||||||
|
|
||||||
def send_discovery_message(message: str) -> bool:
|
def send_discovery_message(bot_id: int, message: str) -> bool:
|
||||||
"""Send a discovery message to the discovery channel"""
|
"""Send a discovery message to the discovery channel"""
|
||||||
return broadcast_message(settings.DISCORD_DISCOVERY_CHANNEL, message)
|
return broadcast_message(bot_id, settings.DISCORD_DISCOVERY_CHANNEL, message)
|
||||||
|
|
||||||
|
|
||||||
def send_chat_message(message: str) -> bool:
|
def send_chat_message(bot_id: int, message: str) -> bool:
|
||||||
"""Send a chat message to the chat channel"""
|
"""Send a chat message to the chat channel"""
|
||||||
return broadcast_message(settings.DISCORD_CHAT_CHANNEL, message)
|
return broadcast_message(bot_id, settings.DISCORD_CHAT_CHANNEL, message)
|
||||||
|
|
||||||
|
|
||||||
def notify_task_failure(
|
def notify_task_failure(
|
||||||
@ -122,6 +168,7 @@ def notify_task_failure(
|
|||||||
task_args: tuple = (),
|
task_args: tuple = (),
|
||||||
task_kwargs: dict[str, Any] | None = None,
|
task_kwargs: dict[str, Any] | None = None,
|
||||||
traceback_str: str | None = None,
|
traceback_str: str | None = None,
|
||||||
|
bot_id: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Send a task failure notification to Discord.
|
Send a task failure notification to Discord.
|
||||||
@ -137,6 +184,15 @@ def notify_task_failure(
|
|||||||
logger.debug("Discord notifications disabled")
|
logger.debug("Discord notifications disabled")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if bot_id is None:
|
||||||
|
bot_id = settings.DISCORD_BOT_ID
|
||||||
|
|
||||||
|
if not bot_id:
|
||||||
|
logger.debug(
|
||||||
|
"No Discord bot ID provided for task failure notification; skipping"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
message = f"🚨 **Task Failed: {task_name}**\n\n"
|
message = f"🚨 **Task Failed: {task_name}**\n\n"
|
||||||
message += f"**Error:** {error_message[:500]}\n"
|
message += f"**Error:** {error_message[:500]}\n"
|
||||||
|
|
||||||
@ -150,7 +206,7 @@ def notify_task_failure(
|
|||||||
message += f"**Traceback:**\n```\n{traceback_str[-800:]}\n```"
|
message += f"**Traceback:**\n```\n{traceback_str[-800:]}\n```"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
send_error_message(message)
|
send_error_message(bot_id, message)
|
||||||
logger.info(f"Discord error notification sent for task: {task_name}")
|
logger.info(f"Discord error notification sent for task: {task_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to send Discord notification: {e}")
|
logger.error(f"Failed to send Discord notification: {e}")
|
||||||
|
|||||||
@ -24,13 +24,6 @@ from memory.common.llms.base import (
|
|||||||
)
|
)
|
||||||
from memory.common.llms.anthropic_provider import AnthropicProvider
|
from memory.common.llms.anthropic_provider import AnthropicProvider
|
||||||
from memory.common.llms.openai_provider import OpenAIProvider
|
from memory.common.llms.openai_provider import OpenAIProvider
|
||||||
from memory.common.llms.usage_tracker import (
|
|
||||||
InMemoryUsageTracker,
|
|
||||||
RateLimitConfig,
|
|
||||||
TokenAllowance,
|
|
||||||
UsageBreakdown,
|
|
||||||
UsageTracker,
|
|
||||||
)
|
|
||||||
from memory.common import tokens
|
from memory.common import tokens
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -49,11 +42,6 @@ __all__ = [
|
|||||||
"StreamEvent",
|
"StreamEvent",
|
||||||
"LLMSettings",
|
"LLMSettings",
|
||||||
"create_provider",
|
"create_provider",
|
||||||
"InMemoryUsageTracker",
|
|
||||||
"RateLimitConfig",
|
|
||||||
"TokenAllowance",
|
|
||||||
"UsageBreakdown",
|
|
||||||
"UsageTracker",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -93,28 +81,3 @@ def truncate(content: str, target_tokens: int) -> str:
|
|||||||
if len(content) > target_chars:
|
if len(content) > target_chars:
|
||||||
return content[:target_chars].rsplit(" ", 1)[0] + "..."
|
return content[:target_chars].rsplit(" ", 1)[0] + "..."
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
# bla = 1
|
|
||||||
# from memory.common.llms import *
|
|
||||||
# from memory.common.llms.tools.discord import make_discord_tools
|
|
||||||
# from memory.common.db.connection import make_session
|
|
||||||
# from memory.common.db.models import *
|
|
||||||
|
|
||||||
# model = "anthropic/claude-sonnet-4-5"
|
|
||||||
# provider = create_provider(model=model)
|
|
||||||
# with make_session() as session:
|
|
||||||
# bot = session.query(DiscordBotUser).first()
|
|
||||||
# server = session.query(DiscordServer).first()
|
|
||||||
# channel = server.channels[0]
|
|
||||||
# tools = make_discord_tools(bot, None, channel, model)
|
|
||||||
|
|
||||||
# def demo(msg: str):
|
|
||||||
# messages = [
|
|
||||||
# Message(
|
|
||||||
# role=MessageRole.USER,
|
|
||||||
# content=msg,
|
|
||||||
# )
|
|
||||||
# ]
|
|
||||||
# for m in provider.stream_with_tools(messages, tools):
|
|
||||||
# print(m)
|
|
||||||
|
|||||||
@ -23,6 +23,8 @@ logger = logging.getLogger(__name__)
|
|||||||
class AnthropicProvider(BaseLLMProvider):
|
class AnthropicProvider(BaseLLMProvider):
|
||||||
"""Anthropic LLM provider with streaming, tool support, and extended thinking."""
|
"""Anthropic LLM provider with streaming, tool support, and extended thinking."""
|
||||||
|
|
||||||
|
provider = "anthropic"
|
||||||
|
|
||||||
# Models that support extended thinking
|
# Models that support extended thinking
|
||||||
THINKING_MODELS = {
|
THINKING_MODELS = {
|
||||||
"claude-opus-4",
|
"claude-opus-4",
|
||||||
@ -262,7 +264,7 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
Usage(
|
Usage(
|
||||||
input_tokens=usage.input_tokens,
|
input_tokens=usage.input_tokens,
|
||||||
output_tokens=usage.output_tokens,
|
output_tokens=usage.output_tokens,
|
||||||
total_tokens=usage.total_tokens,
|
total_tokens=usage.input_tokens + usage.output_tokens,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from PIL import Image
|
|||||||
|
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
|
from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
|
||||||
|
from memory.common.llms.usage import UsageTracker, RedisUsageTracker
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -204,7 +205,11 @@ class LLMSettings:
|
|||||||
class BaseLLMProvider(ABC):
|
class BaseLLMProvider(ABC):
|
||||||
"""Base class for LLM providers."""
|
"""Base class for LLM providers."""
|
||||||
|
|
||||||
def __init__(self, api_key: str, model: str):
|
provider: str = ""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, api_key: str, model: str, usage_tracker: UsageTracker | None = None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the LLM provider.
|
Initialize the LLM provider.
|
||||||
|
|
||||||
@ -215,6 +220,7 @@ class BaseLLMProvider(ABC):
|
|||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.model = model
|
self.model = model
|
||||||
self._client: Any = None
|
self._client: Any = None
|
||||||
|
self.usage_tracker: UsageTracker = usage_tracker or RedisUsageTracker()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _initialize_client(self) -> Any:
|
def _initialize_client(self) -> Any:
|
||||||
@ -230,8 +236,14 @@ class BaseLLMProvider(ABC):
|
|||||||
|
|
||||||
def log_usage(self, usage: Usage):
|
def log_usage(self, usage: Usage):
|
||||||
"""Log usage data."""
|
"""Log usage data."""
|
||||||
logger.debug(f"Token usage: {usage.to_dict()}")
|
logger.debug(
|
||||||
print(f"Token usage: {usage.to_dict()}")
|
f"Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.total_tokens} total"
|
||||||
|
)
|
||||||
|
self.usage_tracker.record_usage(
|
||||||
|
model=f"{self.provider}/{self.model}",
|
||||||
|
input_tokens=usage.input_tokens,
|
||||||
|
output_tokens=usage.output_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
def execute_tool(
|
def execute_tool(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -25,6 +25,8 @@ logger = logging.getLogger(__name__)
|
|||||||
class OpenAIProvider(BaseLLMProvider):
|
class OpenAIProvider(BaseLLMProvider):
|
||||||
"""OpenAI LLM provider with streaming and tool support."""
|
"""OpenAI LLM provider with streaming and tool support."""
|
||||||
|
|
||||||
|
provider = "openai"
|
||||||
|
|
||||||
# Models that use max_completion_tokens instead of max_tokens
|
# Models that use max_completion_tokens instead of max_tokens
|
||||||
# These are reasoning models with different parameter requirements
|
# These are reasoning models with different parameter requirements
|
||||||
NON_REASONING_MODELS = {"gpt-4o"}
|
NON_REASONING_MODELS = {"gpt-4o"}
|
||||||
|
|||||||
17
src/memory/common/llms/usage/__init__.py
Normal file
17
src/memory/common/llms/usage/__init__.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from memory.common.llms.usage.redis_usage_tracker import RedisUsageTracker
|
||||||
|
from memory.common.llms.usage.usage_tracker import (
|
||||||
|
InMemoryUsageTracker,
|
||||||
|
RateLimitConfig,
|
||||||
|
TokenAllowance,
|
||||||
|
UsageBreakdown,
|
||||||
|
UsageTracker,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"InMemoryUsageTracker",
|
||||||
|
"RateLimitConfig",
|
||||||
|
"RedisUsageTracker",
|
||||||
|
"TokenAllowance",
|
||||||
|
"UsageBreakdown",
|
||||||
|
"UsageTracker",
|
||||||
|
]
|
||||||
83
src/memory/common/llms/usage/redis_usage_tracker.py
Normal file
83
src/memory/common/llms/usage/redis_usage_tracker.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
"""Redis-backed usage tracker implementation."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, Iterable, Protocol
|
||||||
|
|
||||||
|
import redis
|
||||||
|
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.common.llms.usage.usage_tracker import (
|
||||||
|
RateLimitConfig,
|
||||||
|
UsageState,
|
||||||
|
UsageTracker,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisClientProtocol(Protocol):
|
||||||
|
def get(self, key: str) -> Any: # pragma: no cover - Protocol definition
|
||||||
|
...
|
||||||
|
|
||||||
|
def set(
|
||||||
|
self, key: str, value: Any
|
||||||
|
) -> Any: # pragma: no cover - Protocol definition
|
||||||
|
...
|
||||||
|
|
||||||
|
def scan_iter(
|
||||||
|
self, match: str
|
||||||
|
) -> Iterable[Any]: # pragma: no cover - Protocol definition
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class RedisUsageTracker(UsageTracker):
|
||||||
|
"""Tracks LLM usage for providers and models using Redis for persistence."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
configs: dict[str, RateLimitConfig] | None = None,
|
||||||
|
default_config: RateLimitConfig | None = None,
|
||||||
|
*,
|
||||||
|
redis_client: RedisClientProtocol | None = None,
|
||||||
|
key_prefix: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(configs=configs, default_config=default_config)
|
||||||
|
if redis_client is None:
|
||||||
|
redis_client = redis.Redis.from_url(settings.REDIS_URL)
|
||||||
|
self._redis = redis_client
|
||||||
|
prefix = key_prefix or settings.LLM_USAGE_REDIS_PREFIX
|
||||||
|
self._key_prefix = prefix.rstrip(":")
|
||||||
|
|
||||||
|
def get_state(self, model: str) -> UsageState:
|
||||||
|
redis_key = self._format_key(model)
|
||||||
|
payload = self._redis.get(redis_key)
|
||||||
|
if not payload:
|
||||||
|
return UsageState()
|
||||||
|
if isinstance(payload, bytes):
|
||||||
|
payload = payload.decode()
|
||||||
|
return UsageState.from_payload(json.loads(payload))
|
||||||
|
|
||||||
|
def iter_state_items(self) -> Iterable[tuple[str, UsageState]]:
|
||||||
|
pattern = f"{self._key_prefix}:*"
|
||||||
|
for redis_key in self._redis.scan_iter(match=pattern):
|
||||||
|
key = self._ensure_str(redis_key)
|
||||||
|
payload = self._redis.get(key)
|
||||||
|
if not payload:
|
||||||
|
continue
|
||||||
|
if isinstance(payload, bytes):
|
||||||
|
payload = payload.decode()
|
||||||
|
state = UsageState.from_payload(json.loads(payload))
|
||||||
|
yield key[len(self._key_prefix) + 1 :], state
|
||||||
|
|
||||||
|
def save_state(self, model: str, state: UsageState) -> None:
|
||||||
|
redis_key = self._format_key(model)
|
||||||
|
self._redis.set(
|
||||||
|
redis_key, json.dumps(state.to_payload(), separators=(",", ":"))
|
||||||
|
)
|
||||||
|
|
||||||
|
def _format_key(self, model: str) -> str:
|
||||||
|
return f"{self._key_prefix}:{model}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ensure_str(value: Any) -> str:
|
||||||
|
if isinstance(value, bytes):
|
||||||
|
return value.decode()
|
||||||
|
return str(value)
|
||||||
@ -5,6 +5,9 @@ from collections.abc import Iterable
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from memory.common import settings
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@ -45,6 +48,40 @@ class UsageState:
|
|||||||
lifetime_input_tokens: int = 0
|
lifetime_input_tokens: int = 0
|
||||||
lifetime_output_tokens: int = 0
|
lifetime_output_tokens: int = 0
|
||||||
|
|
||||||
|
def to_payload(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"events": [
|
||||||
|
{
|
||||||
|
"timestamp": event.timestamp.isoformat(),
|
||||||
|
"input_tokens": event.input_tokens,
|
||||||
|
"output_tokens": event.output_tokens,
|
||||||
|
}
|
||||||
|
for event in self.events
|
||||||
|
],
|
||||||
|
"window_input_tokens": self.window_input_tokens,
|
||||||
|
"window_output_tokens": self.window_output_tokens,
|
||||||
|
"lifetime_input_tokens": self.lifetime_input_tokens,
|
||||||
|
"lifetime_output_tokens": self.lifetime_output_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_payload(cls, payload: dict[str, Any]) -> "UsageState":
|
||||||
|
events = deque(
|
||||||
|
UsageEvent(
|
||||||
|
timestamp=datetime.fromisoformat(event["timestamp"]),
|
||||||
|
input_tokens=event["input_tokens"],
|
||||||
|
output_tokens=event["output_tokens"],
|
||||||
|
)
|
||||||
|
for event in payload.get("events", [])
|
||||||
|
)
|
||||||
|
return cls(
|
||||||
|
events=events,
|
||||||
|
window_input_tokens=payload.get("window_input_tokens", 0),
|
||||||
|
window_output_tokens=payload.get("window_output_tokens", 0),
|
||||||
|
lifetime_input_tokens=payload.get("lifetime_input_tokens", 0),
|
||||||
|
lifetime_output_tokens=payload.get("lifetime_output_tokens", 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TokenAllowance:
|
class TokenAllowance:
|
||||||
@ -77,13 +114,13 @@ class UsageBreakdown:
|
|||||||
def split_model_key(model: str) -> tuple[str, str]:
|
def split_model_key(model: str) -> tuple[str, str]:
|
||||||
if "/" not in model:
|
if "/" not in model:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"model must be formatted as '<provider>/<model_name>'"
|
f"model must be formatted as '<provider>/<model_name>': got '{model}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
provider, model_name = model.split("/", maxsplit=1)
|
provider, model_name = model.split("/", maxsplit=1)
|
||||||
if not provider or not model_name:
|
if not provider or not model_name:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"model must include both provider and model name separated by '/'"
|
f"model must include both provider and model name separated by '/': got '{model}'"
|
||||||
)
|
)
|
||||||
return provider, model_name
|
return provider, model_name
|
||||||
|
|
||||||
@ -93,11 +130,15 @@ class UsageTracker:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
configs: dict[str, RateLimitConfig],
|
configs: dict[str, RateLimitConfig] | None = None,
|
||||||
default_config: RateLimitConfig | None = None,
|
default_config: RateLimitConfig | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._configs = configs
|
self._configs = configs or {}
|
||||||
self._default_config = default_config
|
self._default_config = default_config or RateLimitConfig(
|
||||||
|
window=timedelta(minutes=settings.DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES),
|
||||||
|
max_input_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS,
|
||||||
|
max_output_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS,
|
||||||
|
)
|
||||||
self._lock = Lock()
|
self._lock = Lock()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@ -180,15 +221,14 @@ class UsageTracker:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
split_model_key(model)
|
split_model_key(model)
|
||||||
key = model
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
config = self._get_config(key)
|
config = self._get_config(model)
|
||||||
if config is None:
|
if config is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
state = self.get_state(key)
|
state = self.get_state(model)
|
||||||
self._prune_expired_events(state, config, now=timestamp)
|
self._prune_expired_events(state, config, now=timestamp)
|
||||||
self.save_state(key, state)
|
self.save_state(model, state)
|
||||||
|
|
||||||
if config.max_total_tokens is None:
|
if config.max_total_tokens is None:
|
||||||
total_remaining = None
|
total_remaining = None
|
||||||
@ -205,9 +245,7 @@ class UsageTracker:
|
|||||||
if config.max_output_tokens is None:
|
if config.max_output_tokens is None:
|
||||||
output_remaining = None
|
output_remaining = None
|
||||||
else:
|
else:
|
||||||
output_remaining = (
|
output_remaining = config.max_output_tokens - state.window_output_tokens
|
||||||
config.max_output_tokens - state.window_output_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
return TokenAllowance(
|
return TokenAllowance(
|
||||||
input_tokens=clamp_non_negative(input_remaining),
|
input_tokens=clamp_non_negative(input_remaining),
|
||||||
@ -222,8 +260,8 @@ class UsageTracker:
|
|||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
|
providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
|
||||||
for key, state in self.iter_state_items():
|
for model, state in self.iter_state_items():
|
||||||
prov, model_name = split_model_key(key)
|
prov, model_name = split_model_key(model)
|
||||||
if provider and provider != prov:
|
if provider and provider != prov:
|
||||||
continue
|
continue
|
||||||
if model and model != model_name:
|
if model and model != model_name:
|
||||||
@ -265,8 +303,8 @@ class UsageTracker:
|
|||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Internal helpers
|
# Internal helpers
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
def _get_config(self, key: str) -> RateLimitConfig | None:
|
def _get_config(self, model: str) -> RateLimitConfig | None:
|
||||||
return self._configs.get(key) or self._default_config
|
return self._configs.get(model) or self._default_config
|
||||||
|
|
||||||
def _prune_expired_events(
|
def _prune_expired_events(
|
||||||
self,
|
self,
|
||||||
@ -313,4 +351,3 @@ def clamp_non_negative(value: int | None) -> int | None:
|
|||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
return 0 if value < 0 else value
|
return 0 if value < 0 else value
|
||||||
|
|
||||||
@ -31,33 +31,25 @@ def make_db_url(
|
|||||||
|
|
||||||
DB_URL = os.getenv("DATABASE_URL", make_db_url())
|
DB_URL = os.getenv("DATABASE_URL", make_db_url())
|
||||||
|
|
||||||
|
# Redis settings
|
||||||
|
REDIS_HOST = os.getenv("REDIS_HOST", "redis")
|
||||||
|
REDIS_PORT = os.getenv("REDIS_PORT", "6379")
|
||||||
|
REDIS_DB = os.getenv("REDIS_DB", "0")
|
||||||
|
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", None)
|
||||||
|
if REDIS_PASSWORD:
|
||||||
|
REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
|
||||||
|
else:
|
||||||
|
REDIS_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
|
||||||
|
|
||||||
# Broker settings
|
# Broker settings
|
||||||
CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory")
|
CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory")
|
||||||
CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower()
|
CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower()
|
||||||
REDIS_HOST = os.getenv("REDIS_HOST", "redis")
|
CELERY_BROKER_USER = os.getenv("CELERY_BROKER_USER", "")
|
||||||
REDIS_PORT = os.getenv("REDIS_PORT", "6379")
|
CELERY_BROKER_PASSWORD = os.getenv("CELERY_BROKER_PASSWORD", REDIS_PASSWORD)
|
||||||
REDIS_DB = os.getenv("REDIS_DB", "0")
|
|
||||||
CELERY_BROKER_USER = os.getenv(
|
|
||||||
"CELERY_BROKER_USER", "kb" if CELERY_BROKER_TYPE == "amqp" else ""
|
|
||||||
)
|
|
||||||
CELERY_BROKER_PASSWORD = os.getenv(
|
|
||||||
"CELERY_BROKER_PASSWORD", "" if CELERY_BROKER_TYPE == "redis" else "kb"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "")
|
|
||||||
if not CELERY_BROKER_HOST:
|
|
||||||
if CELERY_BROKER_TYPE == "amqp":
|
|
||||||
RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
|
|
||||||
RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672")
|
|
||||||
CELERY_BROKER_HOST = f"{RABBITMQ_HOST}:{RABBITMQ_PORT}//"
|
|
||||||
elif CELERY_BROKER_TYPE == "redis":
|
|
||||||
CELERY_BROKER_HOST = f"{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
|
|
||||||
|
|
||||||
|
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "") or f"{REDIS_HOST}:{REDIS_PORT}"
|
||||||
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
|
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
|
||||||
|
|
||||||
|
|
||||||
# File storage settings
|
# File storage settings
|
||||||
FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files"))
|
FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files"))
|
||||||
EBOOK_STORAGE_DIR = pathlib.Path(
|
EBOOK_STORAGE_DIR = pathlib.Path(
|
||||||
@ -81,9 +73,14 @@ WEBPAGE_STORAGE_DIR = pathlib.Path(
|
|||||||
NOTES_STORAGE_DIR = pathlib.Path(
|
NOTES_STORAGE_DIR = pathlib.Path(
|
||||||
os.getenv("NOTES_STORAGE_DIR", FILE_STORAGE_DIR / "notes")
|
os.getenv("NOTES_STORAGE_DIR", FILE_STORAGE_DIR / "notes")
|
||||||
)
|
)
|
||||||
|
PRIVATE_DIRS = [
|
||||||
|
EMAIL_STORAGE_DIR,
|
||||||
|
NOTES_STORAGE_DIR,
|
||||||
|
PHOTO_STORAGE_DIR,
|
||||||
|
CHUNK_STORAGE_DIR,
|
||||||
|
]
|
||||||
|
|
||||||
storage_dirs = [
|
storage_dirs = [
|
||||||
FILE_STORAGE_DIR,
|
|
||||||
EBOOK_STORAGE_DIR,
|
EBOOK_STORAGE_DIR,
|
||||||
EMAIL_STORAGE_DIR,
|
EMAIL_STORAGE_DIR,
|
||||||
CHUNK_STORAGE_DIR,
|
CHUNK_STORAGE_DIR,
|
||||||
@ -148,6 +145,18 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-haiku-4-5")
|
|||||||
RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
|
RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
|
||||||
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
|
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
|
||||||
|
|
||||||
|
DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES = int(
|
||||||
|
os.getenv("DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES", 30)
|
||||||
|
)
|
||||||
|
DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS = int(
|
||||||
|
os.getenv("DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS", 1_000_000)
|
||||||
|
)
|
||||||
|
DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS = int(
|
||||||
|
os.getenv("DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS", 1_000_000)
|
||||||
|
)
|
||||||
|
LLM_USAGE_REDIS_PREFIX = os.getenv("LLM_USAGE_REDIS_PREFIX", "llm_usage")
|
||||||
|
|
||||||
|
|
||||||
# Search settings
|
# Search settings
|
||||||
ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
|
ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
|
||||||
ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
|
ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
|
||||||
@ -193,3 +202,14 @@ DISCORD_COLLECT_BOTS = boolean_env("DISCORD_COLLECT_BOTS", True)
|
|||||||
DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8003))
|
DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8003))
|
||||||
DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "0.0.0.0")
|
DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "0.0.0.0")
|
||||||
DISCORD_CONTEXT_WINDOW = int(os.getenv("DISCORD_CONTEXT_WINDOW", 10))
|
DISCORD_CONTEXT_WINDOW = int(os.getenv("DISCORD_CONTEXT_WINDOW", 10))
|
||||||
|
|
||||||
|
|
||||||
|
# S3 Backup settings
|
||||||
|
S3_BACKUP_BUCKET = os.getenv("S3_BACKUP_BUCKET", "equistamp-memory-backup")
|
||||||
|
S3_BACKUP_PREFIX = os.getenv("S3_BACKUP_PREFIX", "Daniel")
|
||||||
|
S3_BACKUP_REGION = os.getenv("S3_BACKUP_REGION", "eu-central-1")
|
||||||
|
BACKUP_ENCRYPTION_KEY = os.getenv("BACKUP_ENCRYPTION_KEY", "")
|
||||||
|
S3_BACKUP_ENABLED = boolean_env("S3_BACKUP_ENABLED", bool(BACKUP_ENCRYPTION_KEY))
|
||||||
|
S3_BACKUP_INTERVAL = int(
|
||||||
|
os.getenv("S3_BACKUP_INTERVAL", 60 * 60 * 24)
|
||||||
|
) # Daily by default
|
||||||
|
|||||||
@ -7,17 +7,18 @@ providing HTTP endpoints for sending Discord messages.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
import traceback
|
import traceback
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import uvicorn
|
|
||||||
|
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from memory.discord.collector import MessageCollector
|
|
||||||
from memory.common.db.models.users import BotUser
|
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models.users import DiscordBotUser
|
||||||
|
from memory.discord.collector import MessageCollector
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -34,6 +35,16 @@ class SendChannelRequest(BaseModel):
|
|||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class TypingDMRequest(BaseModel):
|
||||||
|
bot_id: int
|
||||||
|
user: int | str
|
||||||
|
|
||||||
|
|
||||||
|
class TypingChannelRequest(BaseModel):
|
||||||
|
bot_id: int
|
||||||
|
channel_name: str
|
||||||
|
|
||||||
|
|
||||||
class Collector:
|
class Collector:
|
||||||
collector: MessageCollector
|
collector: MessageCollector
|
||||||
collector_task: asyncio.Task
|
collector_task: asyncio.Task
|
||||||
@ -41,37 +52,25 @@ class Collector:
|
|||||||
bot_token: str
|
bot_token: str
|
||||||
bot_name: str
|
bot_name: str
|
||||||
|
|
||||||
def __init__(self, collector: MessageCollector, bot: BotUser):
|
def __init__(self, collector: MessageCollector, bot: DiscordBotUser):
|
||||||
self.collector = collector
|
self.collector = collector
|
||||||
self.collector_task = asyncio.create_task(collector.start(bot.api_key))
|
self.collector_task = asyncio.create_task(collector.start(str(bot.api_key)))
|
||||||
self.bot_id = bot.id
|
self.bot_id = cast(int, bot.id)
|
||||||
self.bot_token = bot.api_key
|
self.bot_token = str(bot.api_key)
|
||||||
self.bot_name = bot.name
|
self.bot_name = str(bot.name)
|
||||||
|
|
||||||
|
|
||||||
# Application state
|
|
||||||
class AppState:
|
|
||||||
def __init__(self):
|
|
||||||
self.collector: MessageCollector | None = None
|
|
||||||
self.collector_task: asyncio.Task | None = None
|
|
||||||
|
|
||||||
|
|
||||||
app_state = AppState()
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""Manage Discord collector lifecycle"""
|
"""Manage Discord collector lifecycle"""
|
||||||
if not settings.DISCORD_BOT_TOKEN:
|
|
||||||
logger.error("DISCORD_BOT_TOKEN not configured")
|
|
||||||
return
|
|
||||||
|
|
||||||
def make_collector(bot: BotUser):
|
def make_collector(bot: DiscordBotUser):
|
||||||
collector = MessageCollector()
|
collector = MessageCollector()
|
||||||
return Collector(collector=collector, bot=bot)
|
return Collector(collector=collector, bot=bot)
|
||||||
|
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
app.bots = {bot.id: make_collector(bot) for bot in session.query(BotUser).all()}
|
bots = session.query(DiscordBotUser).all()
|
||||||
|
app.bots = {bot.id: make_collector(bot) for bot in bots}
|
||||||
|
|
||||||
logger.info(f"Discord collectors started for {len(app.bots)} bots")
|
logger.info(f"Discord collectors started for {len(app.bots)} bots")
|
||||||
|
|
||||||
@ -120,6 +119,32 @@ async def send_dm_endpoint(request: SendDMRequest):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/typing/dm")
|
||||||
|
async def trigger_dm_typing(request: TypingDMRequest):
|
||||||
|
"""Trigger a typing indicator for a DM via the collector"""
|
||||||
|
collector = app.bots.get(request.bot_id)
|
||||||
|
if not collector:
|
||||||
|
raise HTTPException(status_code=404, detail="Bot not found")
|
||||||
|
|
||||||
|
try:
|
||||||
|
success = await collector.collector.trigger_typing_dm(request.user)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to trigger DM typing: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Failed to trigger typing for {request.user}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"user": request.user,
|
||||||
|
"message": f"Typing triggered for {request.user}",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/send_channel")
|
@app.post("/send_channel")
|
||||||
async def send_channel_endpoint(request: SendChannelRequest):
|
async def send_channel_endpoint(request: SendChannelRequest):
|
||||||
"""Send a message to a channel via the collector's Discord client"""
|
"""Send a message to a channel via the collector's Discord client"""
|
||||||
@ -131,6 +156,9 @@ async def send_channel_endpoint(request: SendChannelRequest):
|
|||||||
success = await collector.collector.send_to_channel(
|
success = await collector.collector.send_to_channel(
|
||||||
request.channel_name, request.message
|
request.channel_name, request.message
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to send channel message: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
return {
|
return {
|
||||||
@ -138,16 +166,38 @@ async def send_channel_endpoint(request: SendChannelRequest):
|
|||||||
"message": f"Message sent to channel {request.channel_name}",
|
"message": f"Message sent to channel {request.channel_name}",
|
||||||
"channel": request.channel_name,
|
"channel": request.channel_name,
|
||||||
}
|
}
|
||||||
else:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Failed to send message to channel {request.channel_name}",
|
detail=f"Failed to send message to channel {request.channel_name}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/typing/channel")
|
||||||
|
async def trigger_channel_typing(request: TypingChannelRequest):
|
||||||
|
"""Trigger a typing indicator for a channel via the collector"""
|
||||||
|
collector = app.bots.get(request.bot_id)
|
||||||
|
if not collector:
|
||||||
|
raise HTTPException(status_code=404, detail="Bot not found")
|
||||||
|
|
||||||
|
try:
|
||||||
|
success = await collector.collector.trigger_typing_channel(request.channel_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to send channel message: {e}")
|
logger.error(f"Failed to trigger channel typing: {e}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Failed to trigger typing for channel {request.channel_name}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"channel": request.channel_name,
|
||||||
|
"message": f"Typing triggered for channel {request.channel_name}",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
@ -155,9 +205,8 @@ async def health_check():
|
|||||||
if not app.bots:
|
if not app.bots:
|
||||||
raise HTTPException(status_code=503, detail="Discord collector not running")
|
raise HTTPException(status_code=503, detail="Discord collector not running")
|
||||||
|
|
||||||
collector = app_state.collector
|
|
||||||
return {
|
return {
|
||||||
collector.bot_name: {
|
bot.bot_name: {
|
||||||
"status": "healthy",
|
"status": "healthy",
|
||||||
"connected": not bot.collector.is_closed(),
|
"connected": not bot.collector.is_closed(),
|
||||||
"user": str(bot.collector.user) if bot.collector.user else None,
|
"user": str(bot.collector.user) if bot.collector.user else None,
|
||||||
|
|||||||
@ -203,7 +203,7 @@ class MessageCollector(commands.Bot):
|
|||||||
async def setup_hook(self):
|
async def setup_hook(self):
|
||||||
"""Register slash commands when the bot is ready."""
|
"""Register slash commands when the bot is ready."""
|
||||||
|
|
||||||
register_slash_commands(self)
|
register_slash_commands(self, name=self.user.name)
|
||||||
|
|
||||||
async def on_ready(self):
|
async def on_ready(self):
|
||||||
"""Called when bot connects to Discord"""
|
"""Called when bot connects to Discord"""
|
||||||
@ -381,6 +381,27 @@ class MessageCollector(commands.Bot):
|
|||||||
logger.error(f"Failed to send DM to {user_identifier}: {e}")
|
logger.error(f"Failed to send DM to {user_identifier}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def trigger_typing_dm(self, user_identifier: int | str) -> bool:
|
||||||
|
"""Trigger typing indicator in a DM"""
|
||||||
|
try:
|
||||||
|
user = await self.get_user(user_identifier)
|
||||||
|
if not user:
|
||||||
|
logger.error(f"User {user_identifier} not found")
|
||||||
|
return False
|
||||||
|
|
||||||
|
channel = user.dm_channel or await user.create_dm()
|
||||||
|
if not channel:
|
||||||
|
logger.error(f"DM channel not available for {user_identifier}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async with channel.typing():
|
||||||
|
pass
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
async def send_to_channel(self, channel_name: str, message: str) -> bool:
|
async def send_to_channel(self, channel_name: str, message: str) -> bool:
|
||||||
"""Send a message to a channel by name across all guilds"""
|
"""Send a message to a channel by name across all guilds"""
|
||||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||||
@ -400,23 +421,21 @@ class MessageCollector(commands.Bot):
|
|||||||
logger.error(f"Failed to send message to channel {channel_name}: {e}")
|
logger.error(f"Failed to send message to channel {channel_name}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def trigger_typing_channel(self, channel_name: str) -> bool:
|
||||||
async def run_collector():
|
"""Trigger typing indicator in a channel"""
|
||||||
"""Run the Discord message collector"""
|
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||||
if not settings.DISCORD_BOT_TOKEN:
|
return False
|
||||||
logger.error("DISCORD_BOT_TOKEN not configured")
|
|
||||||
return
|
|
||||||
|
|
||||||
collector = MessageCollector()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await collector.start(settings.DISCORD_BOT_TOKEN)
|
channel = await self.get_channel_by_name(channel_name)
|
||||||
|
if not channel:
|
||||||
|
logger.error(f"Channel {channel_name} not found")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async with channel.typing():
|
||||||
|
pass
|
||||||
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Discord collector failed: {e}")
|
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}")
|
||||||
raise
|
return False
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
asyncio.run(run_collector())
|
|
||||||
|
|||||||
@ -41,8 +41,13 @@ class CommandContext:
|
|||||||
CommandHandler = Callable[..., CommandResponse]
|
CommandHandler = Callable[..., CommandResponse]
|
||||||
|
|
||||||
|
|
||||||
def register_slash_commands(bot: discord.Client) -> None:
|
def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
|
||||||
"""Register the collector slash commands on the provided bot."""
|
"""Register the collector slash commands on the provided bot.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bot: Discord bot client
|
||||||
|
name: Prefix for command names (e.g., "memory" creates "memory_prompt")
|
||||||
|
"""
|
||||||
|
|
||||||
if getattr(bot, "_memory_commands_registered", False):
|
if getattr(bot, "_memory_commands_registered", False):
|
||||||
return
|
return
|
||||||
@ -54,12 +59,14 @@ def register_slash_commands(bot: discord.Client) -> None:
|
|||||||
|
|
||||||
tree = bot.tree
|
tree = bot.tree
|
||||||
|
|
||||||
@tree.command(name="memory_prompt", description="Show the current system prompt")
|
@tree.command(
|
||||||
|
name=f"{name}_show_prompt", description="Show the current system prompt"
|
||||||
|
)
|
||||||
@discord.app_commands.describe(
|
@discord.app_commands.describe(
|
||||||
scope="Which configuration to inspect",
|
scope="Which configuration to inspect",
|
||||||
user="Target user when the scope is 'user'",
|
user="Target user when the scope is 'user'",
|
||||||
)
|
)
|
||||||
async def prompt_command(
|
async def show_prompt_command(
|
||||||
interaction: discord.Interaction,
|
interaction: discord.Interaction,
|
||||||
scope: ScopeLiteral,
|
scope: ScopeLiteral,
|
||||||
user: discord.User | None = None,
|
user: discord.User | None = None,
|
||||||
@ -72,12 +79,35 @@ def register_slash_commands(bot: discord.Client) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@tree.command(
|
@tree.command(
|
||||||
name="memory_chattiness",
|
name=f"{name}_set_prompt",
|
||||||
description="Show or update the chattiness threshold for the target",
|
description="Set the system prompt for the target",
|
||||||
|
)
|
||||||
|
@discord.app_commands.describe(
|
||||||
|
scope="Which configuration to modify",
|
||||||
|
prompt="The system prompt to set",
|
||||||
|
user="Target user when the scope is 'user'",
|
||||||
|
)
|
||||||
|
async def set_prompt_command(
|
||||||
|
interaction: discord.Interaction,
|
||||||
|
scope: ScopeLiteral,
|
||||||
|
prompt: str,
|
||||||
|
user: discord.User | None = None,
|
||||||
|
) -> None:
|
||||||
|
await _run_interaction_command(
|
||||||
|
interaction,
|
||||||
|
scope=scope,
|
||||||
|
handler=handle_set_prompt,
|
||||||
|
target_user=user,
|
||||||
|
prompt=prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
@tree.command(
|
||||||
|
name=f"{name}_chattiness",
|
||||||
|
description="Show or update the chattiness for the target",
|
||||||
)
|
)
|
||||||
@discord.app_commands.describe(
|
@discord.app_commands.describe(
|
||||||
scope="Which configuration to inspect",
|
scope="Which configuration to inspect",
|
||||||
value="Optional new threshold value between 0 and 100",
|
value="Optional new chattiness value between 0 and 100",
|
||||||
user="Target user when the scope is 'user'",
|
user="Target user when the scope is 'user'",
|
||||||
)
|
)
|
||||||
async def chattiness_command(
|
async def chattiness_command(
|
||||||
@ -95,7 +125,7 @@ def register_slash_commands(bot: discord.Client) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@tree.command(
|
@tree.command(
|
||||||
name="memory_ignore",
|
name=f"{name}_ignore",
|
||||||
description="Toggle whether the bot should ignore messages for the target",
|
description="Toggle whether the bot should ignore messages for the target",
|
||||||
)
|
)
|
||||||
@discord.app_commands.describe(
|
@discord.app_commands.describe(
|
||||||
@ -117,7 +147,10 @@ def register_slash_commands(bot: discord.Client) -> None:
|
|||||||
ignore_enabled=enabled,
|
ignore_enabled=enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
@tree.command(name="memory_summary", description="Show the stored summary for the target")
|
@tree.command(
|
||||||
|
name=f"{name}_show_summary",
|
||||||
|
description="Show the stored summary for the target",
|
||||||
|
)
|
||||||
@discord.app_commands.describe(
|
@discord.app_commands.describe(
|
||||||
scope="Which configuration to inspect",
|
scope="Which configuration to inspect",
|
||||||
user="Target user when the scope is 'user'",
|
user="Target user when the scope is 'user'",
|
||||||
@ -337,6 +370,18 @@ def handle_prompt(context: CommandContext) -> CommandResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_set_prompt(
|
||||||
|
context: CommandContext,
|
||||||
|
*,
|
||||||
|
prompt: str,
|
||||||
|
) -> CommandResponse:
|
||||||
|
setattr(context.target, "system_prompt", prompt)
|
||||||
|
|
||||||
|
return CommandResponse(
|
||||||
|
content=f"Updated system prompt for {context.display_name}.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def handle_chattiness(
|
def handle_chattiness(
|
||||||
context: CommandContext,
|
context: CommandContext,
|
||||||
*,
|
*,
|
||||||
@ -347,20 +392,22 @@ def handle_chattiness(
|
|||||||
if value is None:
|
if value is None:
|
||||||
return CommandResponse(
|
return CommandResponse(
|
||||||
content=(
|
content=(
|
||||||
f"Chattiness threshold for {context.display_name}: "
|
f"Chattiness for {context.display_name}: "
|
||||||
f"{getattr(model, 'chattiness_threshold', 'not set')}"
|
f"{getattr(model, 'chattiness_threshold', 'not set')}"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not 0 <= value <= 100:
|
if not 0 <= value <= 100:
|
||||||
raise CommandError("Chattiness threshold must be between 0 and 100.")
|
raise CommandError("Chattiness must be between 0 and 100.")
|
||||||
|
|
||||||
setattr(model, "chattiness_threshold", value)
|
setattr(model, "chattiness_threshold", value)
|
||||||
|
|
||||||
return CommandResponse(
|
return CommandResponse(
|
||||||
content=(
|
content=(
|
||||||
f"Updated chattiness threshold for {context.display_name} "
|
f"Updated chattiness for {context.display_name} to {value}."
|
||||||
f"to {value}."
|
"\n"
|
||||||
|
"This can be treated as how much you want the bot to pipe up by itself, as a percentage, "
|
||||||
|
"where 0 is never and 100 is always."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from memory.common.celery_app import (
|
|||||||
TRACK_GIT_CHANGES,
|
TRACK_GIT_CHANGES,
|
||||||
SYNC_LESSWRONG,
|
SYNC_LESSWRONG,
|
||||||
RUN_SCHEDULED_CALLS,
|
RUN_SCHEDULED_CALLS,
|
||||||
|
BACKUP_ALL,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -48,4 +49,8 @@ app.conf.beat_schedule = {
|
|||||||
"task": RUN_SCHEDULED_CALLS,
|
"task": RUN_SCHEDULED_CALLS,
|
||||||
"schedule": settings.SCHEDULED_CALL_RUN_INTERVAL,
|
"schedule": settings.SCHEDULED_CALL_RUN_INTERVAL,
|
||||||
},
|
},
|
||||||
|
"backup-all": {
|
||||||
|
"task": BACKUP_ALL,
|
||||||
|
"schedule": settings.S3_BACKUP_INTERVAL,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,11 +3,12 @@ Import sub-modules so Celery can register their @app.task decorators.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from memory.workers.tasks import (
|
from memory.workers.tasks import (
|
||||||
email,
|
backup,
|
||||||
comic,
|
|
||||||
blogs,
|
blogs,
|
||||||
|
comic,
|
||||||
discord,
|
discord,
|
||||||
ebook,
|
ebook,
|
||||||
|
email,
|
||||||
forums,
|
forums,
|
||||||
maintenance,
|
maintenance,
|
||||||
notes,
|
notes,
|
||||||
@ -15,8 +16,8 @@ from memory.workers.tasks import (
|
|||||||
scheduled_calls,
|
scheduled_calls,
|
||||||
) # noqa
|
) # noqa
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"backup",
|
||||||
"email",
|
"email",
|
||||||
"comic",
|
"comic",
|
||||||
"blogs",
|
"blogs",
|
||||||
|
|||||||
152
src/memory/workers/tasks/backup.py
Normal file
152
src/memory/workers/tasks/backup.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
"""S3 backup tasks for memory files."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
import tarfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.common.celery_app import app, BACKUP_PATH, BACKUP_ALL
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cipher() -> Fernet:
|
||||||
|
"""Create Fernet cipher from password in settings."""
|
||||||
|
if not settings.BACKUP_ENCRYPTION_KEY:
|
||||||
|
raise ValueError("BACKUP_ENCRYPTION_KEY not set in environment")
|
||||||
|
|
||||||
|
# Derive key from password using SHA256
|
||||||
|
key_bytes = hashlib.sha256(settings.BACKUP_ENCRYPTION_KEY.encode()).digest()
|
||||||
|
key = base64.urlsafe_b64encode(key_bytes)
|
||||||
|
return Fernet(key)
|
||||||
|
|
||||||
|
|
||||||
|
def create_tarball(directory: Path) -> bytes:
|
||||||
|
"""Create a gzipped tarball of a directory in memory."""
|
||||||
|
if not directory.exists():
|
||||||
|
logger.warning(f"Directory does not exist: {directory}")
|
||||||
|
return b""
|
||||||
|
|
||||||
|
tar_buffer = io.BytesIO()
|
||||||
|
with tarfile.open(fileobj=tar_buffer, mode="w:gz") as tar:
|
||||||
|
tar.add(directory, arcname=directory.name)
|
||||||
|
|
||||||
|
tar_buffer.seek(0)
|
||||||
|
return tar_buffer.read()
|
||||||
|
|
||||||
|
|
||||||
|
def sync_unencrypted_directory(path: Path) -> dict:
|
||||||
|
"""Sync an unencrypted directory to S3 using aws s3 sync."""
|
||||||
|
if not path.exists():
|
||||||
|
logger.warning(f"Directory does not exist: {path}")
|
||||||
|
return {"synced": False, "reason": "directory_not_found"}
|
||||||
|
|
||||||
|
s3_uri = f"s3://{settings.S3_BACKUP_BUCKET}/{settings.S3_BACKUP_PREFIX}/{path.name}"
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
"aws",
|
||||||
|
"s3",
|
||||||
|
"sync",
|
||||||
|
str(path),
|
||||||
|
s3_uri,
|
||||||
|
"--delete",
|
||||||
|
"--region",
|
||||||
|
settings.S3_BACKUP_REGION,
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
logger.info(f"Synced {path} to {s3_uri}")
|
||||||
|
logger.debug(f"Output: {result.stdout}")
|
||||||
|
return {"synced": True, "directory": path, "s3_uri": s3_uri}
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
logger.error(f"Failed to sync {path}: {e.stderr}")
|
||||||
|
return {"synced": False, "directory": path, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
def backup_encrypted_directory(path: Path) -> dict:
|
||||||
|
"""Create encrypted tarball of directory and upload to S3."""
|
||||||
|
if not path.exists():
|
||||||
|
logger.warning(f"Directory does not exist: {path}")
|
||||||
|
return {"uploaded": False, "reason": "directory_not_found"}
|
||||||
|
|
||||||
|
# Create tarball
|
||||||
|
logger.info(f"Creating tarball of {path}...")
|
||||||
|
tarball_bytes = create_tarball(path)
|
||||||
|
|
||||||
|
if not tarball_bytes:
|
||||||
|
logger.warning(f"Empty tarball for {path}, skipping")
|
||||||
|
return {"uploaded": False, "reason": "empty_directory"}
|
||||||
|
|
||||||
|
# Encrypt
|
||||||
|
logger.info(f"Encrypting {path} ({len(tarball_bytes)} bytes)...")
|
||||||
|
cipher = get_cipher()
|
||||||
|
encrypted_bytes = cipher.encrypt(tarball_bytes)
|
||||||
|
|
||||||
|
# Upload to S3
|
||||||
|
s3_client = boto3.client("s3", region_name=settings.S3_BACKUP_REGION)
|
||||||
|
s3_key = f"{settings.S3_BACKUP_PREFIX}/{path.name}.tar.gz.enc"
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(
|
||||||
|
f"Uploading encrypted {path} to s3://{settings.S3_BACKUP_BUCKET}/{s3_key}"
|
||||||
|
)
|
||||||
|
s3_client.put_object(
|
||||||
|
Bucket=settings.S3_BACKUP_BUCKET,
|
||||||
|
Key=s3_key,
|
||||||
|
Body=encrypted_bytes,
|
||||||
|
ServerSideEncryption="AES256",
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"uploaded": True,
|
||||||
|
"directory": path,
|
||||||
|
"size_bytes": len(encrypted_bytes),
|
||||||
|
"s3_key": s3_key,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to upload {path}: {e}")
|
||||||
|
return {"uploaded": False, "directory": path, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(name=BACKUP_PATH)
|
||||||
|
def backup_to_s3(path: Path | str):
|
||||||
|
"""Backup a specific directory to S3."""
|
||||||
|
path = Path(path)
|
||||||
|
|
||||||
|
if not path.exists():
|
||||||
|
logger.warning(f"Directory does not exist: {path}")
|
||||||
|
return {"uploaded": False, "reason": "directory_not_found"}
|
||||||
|
|
||||||
|
if path in settings.PRIVATE_DIRS:
|
||||||
|
return backup_encrypted_directory(path)
|
||||||
|
return sync_unencrypted_directory(path)
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(name=BACKUP_ALL)
|
||||||
|
def backup_all_to_s3():
|
||||||
|
"""Main backup task that syncs unencrypted dirs and uploads encrypted dirs."""
|
||||||
|
if not settings.S3_BACKUP_ENABLED:
|
||||||
|
logger.info("S3 backup is disabled")
|
||||||
|
return {"status": "disabled"}
|
||||||
|
|
||||||
|
logger.info("Starting S3 backup...")
|
||||||
|
|
||||||
|
for dir_name in settings.storage_dirs:
|
||||||
|
backup_to_s3.delay((settings.FILE_STORAGE_DIR / dir_name).as_posix())
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"message": f"Started backup for {len(settings.storage_dirs)} directories",
|
||||||
|
}
|
||||||
@ -7,7 +7,7 @@ import logging
|
|||||||
import re
|
import re
|
||||||
import textwrap
|
import textwrap
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
from sqlalchemy import exc as sqlalchemy_exc
|
from sqlalchemy import exc as sqlalchemy_exc
|
||||||
from sqlalchemy.orm import Session, scoped_session
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
@ -56,8 +56,15 @@ def call_llm(
|
|||||||
message: DiscordMessage,
|
message: DiscordMessage,
|
||||||
model: str,
|
model: str,
|
||||||
msgs: list[str] = [],
|
msgs: list[str] = [],
|
||||||
allowed_tools: list[str] = [],
|
allowed_tools: list[str] | None = None,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
|
provider = create_provider(model=model)
|
||||||
|
if provider.usage_tracker.is_rate_limited(model):
|
||||||
|
logger.error(
|
||||||
|
f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
tools = make_discord_tools(
|
tools = make_discord_tools(
|
||||||
message.recipient_user.system_user,
|
message.recipient_user.system_user,
|
||||||
message.from_user,
|
message.from_user,
|
||||||
@ -67,13 +74,13 @@ def call_llm(
|
|||||||
tools = {
|
tools = {
|
||||||
name: tool
|
name: tool
|
||||||
for name, tool in tools.items()
|
for name, tool in tools.items()
|
||||||
if message.tool_allowed(name) and name in allowed_tools
|
if message.tool_allowed(name)
|
||||||
|
and (allowed_tools is None or name in allowed_tools)
|
||||||
}
|
}
|
||||||
system_prompt = message.system_prompt or ""
|
system_prompt = message.system_prompt or ""
|
||||||
system_prompt += comm_channel_prompt(
|
system_prompt += comm_channel_prompt(
|
||||||
session, message.recipient_user, message.channel
|
session, message.recipient_user, message.channel
|
||||||
)
|
)
|
||||||
provider = create_provider(model=model)
|
|
||||||
messages = previous_messages(
|
messages = previous_messages(
|
||||||
session,
|
session,
|
||||||
message.recipient_user and message.recipient_user.id,
|
message.recipient_user and message.recipient_user.id,
|
||||||
@ -127,10 +134,32 @@ def should_process(message: DiscordMessage) -> bool:
|
|||||||
if not (res := re.search(r"<number>(.*)</number>", response)):
|
if not (res := re.search(r"<number>(.*)</number>", response)):
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
return int(res.group(1)) > message.chattiness_threshold
|
if int(res.group(1)) < 100 - message.chattiness_threshold:
|
||||||
|
return False
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if not (bot_id := _resolve_bot_id(message)):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if message.channel and message.channel.server:
|
||||||
|
discord.trigger_typing_channel(bot_id, message.channel.name)
|
||||||
|
else:
|
||||||
|
discord.trigger_typing_dm(bot_id, cast(int | str, message.from_id))
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_bot_id(discord_message: DiscordMessage) -> int | None:
|
||||||
|
recipient = discord_message.recipient_user
|
||||||
|
if not recipient:
|
||||||
|
return None
|
||||||
|
|
||||||
|
system_user = recipient.system_user
|
||||||
|
if not system_user:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return getattr(system_user, "id", None)
|
||||||
|
|
||||||
|
|
||||||
@app.task(name=PROCESS_DISCORD_MESSAGE)
|
@app.task(name=PROCESS_DISCORD_MESSAGE)
|
||||||
@safe_task_execution
|
@safe_task_execution
|
||||||
@ -152,14 +181,33 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
|||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = call_llm(session, discord_message, settings.DISCORD_MODEL)
|
bot_id = _resolve_bot_id(discord_message)
|
||||||
|
if not bot_id:
|
||||||
|
logger.warning(
|
||||||
|
"No associated Discord bot user for message %s; skipping send",
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "processed",
|
||||||
|
"message_id": message_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = call_llm(session, discord_message, settings.DISCORD_MODEL)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to generate Discord response")
|
||||||
|
|
||||||
|
print("response:", response)
|
||||||
if not response:
|
if not response:
|
||||||
pass
|
return {
|
||||||
elif discord_message.channel.server:
|
"status": "processed",
|
||||||
discord.send_to_channel(discord_message.channel.name, response)
|
"message_id": message_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if discord_message.channel.server:
|
||||||
|
discord.send_to_channel(bot_id, discord_message.channel.name, response)
|
||||||
else:
|
else:
|
||||||
discord.send_dm(discord_message.from_user.username, response)
|
discord.send_dm(bot_id, discord_message.from_user.username, response)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "processed",
|
"status": "processed",
|
||||||
|
|||||||
@ -37,12 +37,22 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
|
|||||||
if len(message) > 1900: # Leave some buffer
|
if len(message) > 1900: # Leave some buffer
|
||||||
message = message[:1900] + "\n\n... (response truncated)"
|
message = message[:1900] + "\n\n... (response truncated)"
|
||||||
|
|
||||||
|
bot_id_value = scheduled_call.user_id
|
||||||
|
if bot_id_value is None:
|
||||||
|
logger.warning(
|
||||||
|
"Scheduled call %s has no associated bot user; skipping Discord send",
|
||||||
|
scheduled_call.id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
bot_id = cast(int, bot_id_value)
|
||||||
|
|
||||||
if discord_user := scheduled_call.discord_user:
|
if discord_user := scheduled_call.discord_user:
|
||||||
logger.info(f"Sending DM to {discord_user.username}: {message}")
|
logger.info(f"Sending DM to {discord_user.username}: {message}")
|
||||||
discord.send_dm(discord_user.username, message)
|
discord.send_dm(bot_id, discord_user.username, message)
|
||||||
elif discord_channel := scheduled_call.discord_channel:
|
elif discord_channel := scheduled_call.discord_channel:
|
||||||
logger.info(f"Broadcasting message to {discord_channel.name}: {message}")
|
logger.info(f"Broadcasting message to {discord_channel.name}: {message}")
|
||||||
discord.broadcast_message(discord_channel.name, message)
|
discord.broadcast_message(bot_id, discord_channel.name, message)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"No Discord user or channel found for scheduled call {scheduled_call.id}"
|
f"No Discord user or channel found for scheduled call {scheduled_call.id}"
|
||||||
|
|||||||
@ -1,7 +1,24 @@
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
try:
|
||||||
|
import redis # noqa: F401 # pragma: no cover - optional test dependency
|
||||||
|
except ModuleNotFoundError: # pragma: no cover - import guard for test envs
|
||||||
|
import sys
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
class _RedisStub(SimpleNamespace):
|
||||||
|
class Redis: # type: ignore[no-redef]
|
||||||
|
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
"The 'redis' package is required to use RedisUsageTracker"
|
||||||
|
)
|
||||||
|
|
||||||
|
sys.modules.setdefault("redis", _RedisStub())
|
||||||
|
|
||||||
|
from memory.common.llms.redis_usage_tracker import RedisUsageTracker
|
||||||
from memory.common.llms.usage_tracker import (
|
from memory.common.llms.usage_tracker import (
|
||||||
InMemoryUsageTracker,
|
InMemoryUsageTracker,
|
||||||
RateLimitConfig,
|
RateLimitConfig,
|
||||||
@ -9,6 +26,24 @@ from memory.common.llms.usage_tracker import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeRedis:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._store: dict[str, str] = {}
|
||||||
|
|
||||||
|
def get(self, key: str) -> str | None:
|
||||||
|
return self._store.get(key)
|
||||||
|
|
||||||
|
def set(self, key: str, value: str) -> None:
|
||||||
|
self._store[key] = value
|
||||||
|
|
||||||
|
def scan_iter(self, match: str) -> Iterable[str]:
|
||||||
|
from fnmatch import fnmatch
|
||||||
|
|
||||||
|
for key in list(self._store.keys()):
|
||||||
|
if fnmatch(key, match):
|
||||||
|
yield key
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def tracker() -> InMemoryUsageTracker:
|
def tracker() -> InMemoryUsageTracker:
|
||||||
config = RateLimitConfig(
|
config = RateLimitConfig(
|
||||||
@ -25,6 +60,23 @@ def tracker() -> InMemoryUsageTracker:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def redis_tracker() -> RedisUsageTracker:
|
||||||
|
config = RateLimitConfig(
|
||||||
|
window=timedelta(minutes=1),
|
||||||
|
max_input_tokens=1_000,
|
||||||
|
max_output_tokens=2_000,
|
||||||
|
max_total_tokens=2_500,
|
||||||
|
)
|
||||||
|
return RedisUsageTracker(
|
||||||
|
{
|
||||||
|
"anthropic/claude-3": config,
|
||||||
|
"anthropic/haiku": config,
|
||||||
|
},
|
||||||
|
redis_client=FakeRedis(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"window, kwargs",
|
"window, kwargs",
|
||||||
[
|
[
|
||||||
@ -139,6 +191,22 @@ def test_is_rate_limited_when_only_output_exceeds_limit() -> None:
|
|||||||
assert tracker.is_rate_limited("openai/gpt-4o")
|
assert tracker.is_rate_limited("openai/gpt-4o")
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_usage_tracker_persists_state(redis_tracker: RedisUsageTracker) -> None:
|
||||||
|
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||||
|
redis_tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
|
||||||
|
redis_tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)
|
||||||
|
|
||||||
|
allowance = redis_tracker.get_available_tokens("anthropic/claude-3", timestamp=now)
|
||||||
|
assert allowance is not None
|
||||||
|
assert allowance.input_tokens == 900
|
||||||
|
|
||||||
|
breakdown = redis_tracker.get_usage_breakdown()
|
||||||
|
assert breakdown["anthropic"]["claude-3"].window_output_tokens == 200
|
||||||
|
|
||||||
|
items = dict(redis_tracker.iter_state_items())
|
||||||
|
assert set(items.keys()) == {"anthropic/claude-3", "anthropic/haiku"}
|
||||||
|
|
||||||
|
|
||||||
def test_usage_tracker_base_not_instantiable() -> None:
|
def test_usage_tracker_base_not_instantiable() -> None:
|
||||||
class DummyTracker(UsageTracker):
|
class DummyTracker(UsageTracker):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -4,6 +4,8 @@ import requests
|
|||||||
|
|
||||||
from memory.common import discord
|
from memory.common import discord
|
||||||
|
|
||||||
|
BOT_ID = 42
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_api_url():
|
def mock_api_url():
|
||||||
@ -29,12 +31,12 @@ def test_send_dm_success(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.send_dm("user123", "Hello!")
|
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_post.assert_called_once_with(
|
mock_post.assert_called_once_with(
|
||||||
"http://localhost:8000/send_dm",
|
"http://localhost:8000/send_dm",
|
||||||
json={"user": "user123", "message": "Hello!"},
|
json={"bot_id": BOT_ID, "user": "user123", "message": "Hello!"},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -47,7 +49,7 @@ def test_send_dm_api_failure(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.send_dm("user123", "Hello!")
|
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -57,7 +59,7 @@ def test_send_dm_request_exception(mock_post, mock_api_url):
|
|||||||
"""Test DM sending when request raises exception"""
|
"""Test DM sending when request raises exception"""
|
||||||
mock_post.side_effect = requests.RequestException("Network error")
|
mock_post.side_effect = requests.RequestException("Network error")
|
||||||
|
|
||||||
result = discord.send_dm("user123", "Hello!")
|
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -69,7 +71,7 @@ def test_send_dm_http_error(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.send_dm("user123", "Hello!")
|
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -82,12 +84,16 @@ def test_broadcast_message_success(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.broadcast_message("general", "Announcement!")
|
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_post.assert_called_once_with(
|
mock_post.assert_called_once_with(
|
||||||
"http://localhost:8000/send_channel",
|
"http://localhost:8000/send_channel",
|
||||||
json={"channel_name": "general", "message": "Announcement!"},
|
json={
|
||||||
|
"bot_id": BOT_ID,
|
||||||
|
"channel_name": "general",
|
||||||
|
"message": "Announcement!",
|
||||||
|
},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -100,7 +106,7 @@ def test_broadcast_message_failure(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.broadcast_message("general", "Announcement!")
|
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -110,7 +116,7 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
|
|||||||
"""Test channel message broadcast with exception"""
|
"""Test channel message broadcast with exception"""
|
||||||
mock_post.side_effect = requests.Timeout("Request timeout")
|
mock_post.side_effect = requests.Timeout("Request timeout")
|
||||||
|
|
||||||
result = discord.broadcast_message("general", "Announcement!")
|
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -119,11 +125,11 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
|
|||||||
def test_is_collector_healthy_true(mock_get, mock_api_url):
|
def test_is_collector_healthy_true(mock_get, mock_api_url):
|
||||||
"""Test health check when collector is healthy"""
|
"""Test health check when collector is healthy"""
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.json.return_value = {"status": "healthy"}
|
mock_response.json.return_value = {str(BOT_ID): {"connected": True}}
|
||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
result = discord.is_collector_healthy()
|
result = discord.is_collector_healthy(BOT_ID)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
||||||
@ -133,11 +139,11 @@ def test_is_collector_healthy_true(mock_get, mock_api_url):
|
|||||||
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
|
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
|
||||||
"""Test health check when collector returns unhealthy status"""
|
"""Test health check when collector returns unhealthy status"""
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.json.return_value = {"status": "unhealthy"}
|
mock_response.json.return_value = {str(BOT_ID): {"connected": False}}
|
||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
result = discord.is_collector_healthy()
|
result = discord.is_collector_healthy(BOT_ID)
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -147,7 +153,7 @@ def test_is_collector_healthy_exception(mock_get, mock_api_url):
|
|||||||
"""Test health check when request fails"""
|
"""Test health check when request fails"""
|
||||||
mock_get.side_effect = requests.ConnectionError("Connection refused")
|
mock_get.side_effect = requests.ConnectionError("Connection refused")
|
||||||
|
|
||||||
result = discord.is_collector_healthy()
|
result = discord.is_collector_healthy(BOT_ID)
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -200,10 +206,10 @@ def test_send_error_message(mock_broadcast):
|
|||||||
"""Test sending error message to error channel"""
|
"""Test sending error message to error channel"""
|
||||||
mock_broadcast.return_value = True
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
result = discord.send_error_message("Something broke")
|
result = discord.send_error_message(BOT_ID, "Something broke")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_broadcast.assert_called_once_with("errors", "Something broke")
|
mock_broadcast.assert_called_once_with(BOT_ID, "errors", "Something broke")
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.broadcast_message")
|
@patch("memory.common.discord.broadcast_message")
|
||||||
@ -212,10 +218,12 @@ def test_send_activity_message(mock_broadcast):
|
|||||||
"""Test sending activity message to activity channel"""
|
"""Test sending activity message to activity channel"""
|
||||||
mock_broadcast.return_value = True
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
result = discord.send_activity_message("User logged in")
|
result = discord.send_activity_message(BOT_ID, "User logged in")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_broadcast.assert_called_once_with("activity", "User logged in")
|
mock_broadcast.assert_called_once_with(
|
||||||
|
BOT_ID, "activity", "User logged in"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.broadcast_message")
|
@patch("memory.common.discord.broadcast_message")
|
||||||
@ -224,10 +232,12 @@ def test_send_discovery_message(mock_broadcast):
|
|||||||
"""Test sending discovery message to discovery channel"""
|
"""Test sending discovery message to discovery channel"""
|
||||||
mock_broadcast.return_value = True
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
result = discord.send_discovery_message("Found interesting pattern")
|
result = discord.send_discovery_message(BOT_ID, "Found interesting pattern")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
|
mock_broadcast.assert_called_once_with(
|
||||||
|
BOT_ID, "discoveries", "Found interesting pattern"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.broadcast_message")
|
@patch("memory.common.discord.broadcast_message")
|
||||||
@ -236,20 +246,23 @@ def test_send_chat_message(mock_broadcast):
|
|||||||
"""Test sending chat message to chat channel"""
|
"""Test sending chat message to chat channel"""
|
||||||
mock_broadcast.return_value = True
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
result = discord.send_chat_message("Hello from bot")
|
result = discord.send_chat_message(BOT_ID, "Hello from bot")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
|
mock_broadcast.assert_called_once_with(BOT_ID, "chat", "Hello from bot")
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.send_error_message")
|
@patch("memory.common.discord.send_error_message")
|
||||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||||
def test_notify_task_failure_basic(mock_send_error):
|
def test_notify_task_failure_basic(mock_send_error):
|
||||||
"""Test basic task failure notification"""
|
"""Test basic task failure notification"""
|
||||||
discord.notify_task_failure("test_task", "Something went wrong")
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Something went wrong", bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
mock_send_error.assert_called_once()
|
mock_send_error.assert_called_once()
|
||||||
message = mock_send_error.call_args[0][0]
|
assert mock_send_error.call_args[0][0] == BOT_ID
|
||||||
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
assert "🚨 **Task Failed: test_task**" in message
|
assert "🚨 **Task Failed: test_task**" in message
|
||||||
assert "**Error:** Something went wrong" in message
|
assert "**Error:** Something went wrong" in message
|
||||||
@ -264,9 +277,10 @@ def test_notify_task_failure_with_args(mock_send_error):
|
|||||||
"Error occurred",
|
"Error occurred",
|
||||||
task_args=("arg1", 42),
|
task_args=("arg1", 42),
|
||||||
task_kwargs={"key": "value", "number": 123},
|
task_kwargs={"key": "value", "number": 123},
|
||||||
|
bot_id=BOT_ID,
|
||||||
)
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
assert "**Args:** `('arg1', 42)" in message
|
assert "**Args:** `('arg1', 42)" in message
|
||||||
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
|
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
|
||||||
@ -278,9 +292,11 @@ def test_notify_task_failure_with_traceback(mock_send_error):
|
|||||||
"""Test task failure notification with traceback"""
|
"""Test task failure notification with traceback"""
|
||||||
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
|
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Error occurred", traceback_str=traceback, bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
assert "**Traceback:**" in message
|
assert "**Traceback:**" in message
|
||||||
assert "Exception: test" in message
|
assert "Exception: test" in message
|
||||||
@ -292,9 +308,9 @@ def test_notify_task_failure_truncates_long_error(mock_send_error):
|
|||||||
"""Test that long error messages are truncated"""
|
"""Test that long error messages are truncated"""
|
||||||
long_error = "x" * 600
|
long_error = "x" * 600
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", long_error)
|
discord.notify_task_failure("test_task", long_error, bot_id=BOT_ID)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
# Error should be truncated to 500 chars - check that the full 600 char string is not there
|
# Error should be truncated to 500 chars - check that the full 600 char string is not there
|
||||||
assert "**Error:** " + long_error[:500] in message
|
assert "**Error:** " + long_error[:500] in message
|
||||||
@ -309,9 +325,11 @@ def test_notify_task_failure_truncates_long_traceback(mock_send_error):
|
|||||||
"""Test that long tracebacks are truncated"""
|
"""Test that long tracebacks are truncated"""
|
||||||
long_traceback = "x" * 1000
|
long_traceback = "x" * 1000
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Error", traceback_str=long_traceback, bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
# Traceback should show last 800 chars
|
# Traceback should show last 800 chars
|
||||||
assert long_traceback[-800:] in message
|
assert long_traceback[-800:] in message
|
||||||
@ -326,9 +344,11 @@ def test_notify_task_failure_truncates_long_args(mock_send_error):
|
|||||||
"""Test that long task arguments are truncated"""
|
"""Test that long task arguments are truncated"""
|
||||||
long_args = ("x" * 300,)
|
long_args = ("x" * 300,)
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", "Error", task_args=long_args)
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Error", task_args=long_args, bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
# Args should be truncated to 200 chars
|
# Args should be truncated to 200 chars
|
||||||
assert (
|
assert (
|
||||||
@ -342,9 +362,11 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
|||||||
"""Test that long task kwargs are truncated"""
|
"""Test that long task kwargs are truncated"""
|
||||||
long_kwargs = {"key": "x" * 300}
|
long_kwargs = {"key": "x" * 300}
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Error", task_kwargs=long_kwargs, bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
# Kwargs should be truncated to 200 chars
|
# Kwargs should be truncated to 200 chars
|
||||||
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
|
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
|
||||||
@ -354,7 +376,7 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
|||||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||||
def test_notify_task_failure_disabled(mock_send_error):
|
def test_notify_task_failure_disabled(mock_send_error):
|
||||||
"""Test that notifications are not sent when disabled"""
|
"""Test that notifications are not sent when disabled"""
|
||||||
discord.notify_task_failure("test_task", "Error occurred")
|
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
|
||||||
|
|
||||||
mock_send_error.assert_not_called()
|
mock_send_error.assert_not_called()
|
||||||
|
|
||||||
@ -366,7 +388,7 @@ def test_notify_task_failure_send_error_exception(mock_send_error):
|
|||||||
mock_send_error.side_effect = Exception("Failed to send")
|
mock_send_error.side_effect = Exception("Failed to send")
|
||||||
|
|
||||||
# Should not raise
|
# Should not raise
|
||||||
discord.notify_task_failure("test_task", "Error occurred")
|
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
|
||||||
|
|
||||||
mock_send_error.assert_called_once()
|
mock_send_error.assert_called_once()
|
||||||
|
|
||||||
@ -386,8 +408,8 @@ def test_convenience_functions_use_correct_channels(
|
|||||||
):
|
):
|
||||||
"""Test that convenience functions use the correct channel settings"""
|
"""Test that convenience functions use the correct channel settings"""
|
||||||
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
|
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
|
||||||
function(message)
|
function(BOT_ID, message)
|
||||||
mock_broadcast.assert_called_once_with("test-channel", message)
|
mock_broadcast.assert_called_once_with(BOT_ID, "test-channel", message)
|
||||||
|
|
||||||
|
|
||||||
@patch("requests.post")
|
@patch("requests.post")
|
||||||
@ -399,11 +421,13 @@ def test_send_dm_with_special_characters(mock_post, mock_api_url):
|
|||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
message_with_special_chars = "Hello! 🎉 <@123> #general"
|
message_with_special_chars = "Hello! 🎉 <@123> #general"
|
||||||
result = discord.send_dm("user123", message_with_special_chars)
|
result = discord.send_dm(BOT_ID, "user123", message_with_special_chars)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
call_args = mock_post.call_args
|
call_args = mock_post.call_args
|
||||||
assert call_args[1]["json"]["message"] == message_with_special_chars
|
json_payload = call_args[1]["json"]
|
||||||
|
assert json_payload["message"] == message_with_special_chars
|
||||||
|
assert json_payload["bot_id"] == BOT_ID
|
||||||
|
|
||||||
|
|
||||||
@patch("requests.post")
|
@patch("requests.post")
|
||||||
@ -415,11 +439,13 @@ def test_broadcast_message_with_long_message(mock_post, mock_api_url):
|
|||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
long_message = "A" * 2000
|
long_message = "A" * 2000
|
||||||
result = discord.broadcast_message("general", long_message)
|
result = discord.broadcast_message(BOT_ID, "general", long_message)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
call_args = mock_post.call_args
|
call_args = mock_post.call_args
|
||||||
assert call_args[1]["json"]["message"] == long_message
|
json_payload = call_args[1]["json"]
|
||||||
|
assert json_payload["message"] == long_message
|
||||||
|
assert json_payload["bot_id"] == BOT_ID
|
||||||
|
|
||||||
|
|
||||||
@patch("requests.get")
|
@patch("requests.get")
|
||||||
@ -430,6 +456,6 @@ def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
result = discord.is_collector_healthy()
|
result = discord.is_collector_healthy(BOT_ID)
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|||||||
@ -4,6 +4,8 @@ import requests
|
|||||||
|
|
||||||
from memory.common import discord
|
from memory.common import discord
|
||||||
|
|
||||||
|
BOT_ID = 42
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_api_url():
|
def mock_api_url():
|
||||||
@ -29,12 +31,12 @@ def test_send_dm_success(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.send_dm("user123", "Hello!")
|
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_post.assert_called_once_with(
|
mock_post.assert_called_once_with(
|
||||||
"http://localhost:8000/send_dm",
|
"http://localhost:8000/send_dm",
|
||||||
json={"user": "user123", "message": "Hello!"},
|
json={"bot_id": BOT_ID, "user": "user123", "message": "Hello!"},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -47,7 +49,7 @@ def test_send_dm_api_failure(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.send_dm("user123", "Hello!")
|
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -57,7 +59,7 @@ def test_send_dm_request_exception(mock_post, mock_api_url):
|
|||||||
"""Test DM sending when request raises exception"""
|
"""Test DM sending when request raises exception"""
|
||||||
mock_post.side_effect = requests.RequestException("Network error")
|
mock_post.side_effect = requests.RequestException("Network error")
|
||||||
|
|
||||||
result = discord.send_dm("user123", "Hello!")
|
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -69,7 +71,7 @@ def test_send_dm_http_error(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.send_dm("user123", "Hello!")
|
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -82,12 +84,16 @@ def test_broadcast_message_success(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.broadcast_message("general", "Announcement!")
|
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_post.assert_called_once_with(
|
mock_post.assert_called_once_with(
|
||||||
"http://localhost:8000/send_channel",
|
"http://localhost:8000/send_channel",
|
||||||
json={"channel_name": "general", "message": "Announcement!"},
|
json={
|
||||||
|
"bot_id": BOT_ID,
|
||||||
|
"channel_name": "general",
|
||||||
|
"message": "Announcement!",
|
||||||
|
},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -100,7 +106,7 @@ def test_broadcast_message_failure(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.broadcast_message("general", "Announcement!")
|
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -110,7 +116,7 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
|
|||||||
"""Test channel message broadcast with exception"""
|
"""Test channel message broadcast with exception"""
|
||||||
mock_post.side_effect = requests.Timeout("Request timeout")
|
mock_post.side_effect = requests.Timeout("Request timeout")
|
||||||
|
|
||||||
result = discord.broadcast_message("general", "Announcement!")
|
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -119,11 +125,11 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
|
|||||||
def test_is_collector_healthy_true(mock_get, mock_api_url):
|
def test_is_collector_healthy_true(mock_get, mock_api_url):
|
||||||
"""Test health check when collector is healthy"""
|
"""Test health check when collector is healthy"""
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.json.return_value = {"status": "healthy"}
|
mock_response.json.return_value = {str(BOT_ID): {"connected": True}}
|
||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
result = discord.is_collector_healthy()
|
result = discord.is_collector_healthy(BOT_ID)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
||||||
@ -133,11 +139,11 @@ def test_is_collector_healthy_true(mock_get, mock_api_url):
|
|||||||
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
|
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
|
||||||
"""Test health check when collector returns unhealthy status"""
|
"""Test health check when collector returns unhealthy status"""
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.json.return_value = {"status": "unhealthy"}
|
mock_response.json.return_value = {str(BOT_ID): {"connected": False}}
|
||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
result = discord.is_collector_healthy()
|
result = discord.is_collector_healthy(BOT_ID)
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -147,7 +153,7 @@ def test_is_collector_healthy_exception(mock_get, mock_api_url):
|
|||||||
"""Test health check when request fails"""
|
"""Test health check when request fails"""
|
||||||
mock_get.side_effect = requests.ConnectionError("Connection refused")
|
mock_get.side_effect = requests.ConnectionError("Connection refused")
|
||||||
|
|
||||||
result = discord.is_collector_healthy()
|
result = discord.is_collector_healthy(BOT_ID)
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -200,10 +206,10 @@ def test_send_error_message(mock_broadcast):
|
|||||||
"""Test sending error message to error channel"""
|
"""Test sending error message to error channel"""
|
||||||
mock_broadcast.return_value = True
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
result = discord.send_error_message("Something broke")
|
result = discord.send_error_message(BOT_ID, "Something broke")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_broadcast.assert_called_once_with("errors", "Something broke")
|
mock_broadcast.assert_called_once_with(BOT_ID, "errors", "Something broke")
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.broadcast_message")
|
@patch("memory.common.discord.broadcast_message")
|
||||||
@ -212,10 +218,12 @@ def test_send_activity_message(mock_broadcast):
|
|||||||
"""Test sending activity message to activity channel"""
|
"""Test sending activity message to activity channel"""
|
||||||
mock_broadcast.return_value = True
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
result = discord.send_activity_message("User logged in")
|
result = discord.send_activity_message(BOT_ID, "User logged in")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_broadcast.assert_called_once_with("activity", "User logged in")
|
mock_broadcast.assert_called_once_with(
|
||||||
|
BOT_ID, "activity", "User logged in"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.broadcast_message")
|
@patch("memory.common.discord.broadcast_message")
|
||||||
@ -224,10 +232,12 @@ def test_send_discovery_message(mock_broadcast):
|
|||||||
"""Test sending discovery message to discovery channel"""
|
"""Test sending discovery message to discovery channel"""
|
||||||
mock_broadcast.return_value = True
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
result = discord.send_discovery_message("Found interesting pattern")
|
result = discord.send_discovery_message(BOT_ID, "Found interesting pattern")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
|
mock_broadcast.assert_called_once_with(
|
||||||
|
BOT_ID, "discoveries", "Found interesting pattern"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.broadcast_message")
|
@patch("memory.common.discord.broadcast_message")
|
||||||
@ -236,20 +246,23 @@ def test_send_chat_message(mock_broadcast):
|
|||||||
"""Test sending chat message to chat channel"""
|
"""Test sending chat message to chat channel"""
|
||||||
mock_broadcast.return_value = True
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
result = discord.send_chat_message("Hello from bot")
|
result = discord.send_chat_message(BOT_ID, "Hello from bot")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
|
mock_broadcast.assert_called_once_with(BOT_ID, "chat", "Hello from bot")
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.send_error_message")
|
@patch("memory.common.discord.send_error_message")
|
||||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||||
def test_notify_task_failure_basic(mock_send_error):
|
def test_notify_task_failure_basic(mock_send_error):
|
||||||
"""Test basic task failure notification"""
|
"""Test basic task failure notification"""
|
||||||
discord.notify_task_failure("test_task", "Something went wrong")
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Something went wrong", bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
mock_send_error.assert_called_once()
|
mock_send_error.assert_called_once()
|
||||||
message = mock_send_error.call_args[0][0]
|
assert mock_send_error.call_args[0][0] == BOT_ID
|
||||||
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
assert "🚨 **Task Failed: test_task**" in message
|
assert "🚨 **Task Failed: test_task**" in message
|
||||||
assert "**Error:** Something went wrong" in message
|
assert "**Error:** Something went wrong" in message
|
||||||
@ -264,9 +277,10 @@ def test_notify_task_failure_with_args(mock_send_error):
|
|||||||
"Error occurred",
|
"Error occurred",
|
||||||
task_args=("arg1", 42),
|
task_args=("arg1", 42),
|
||||||
task_kwargs={"key": "value", "number": 123},
|
task_kwargs={"key": "value", "number": 123},
|
||||||
|
bot_id=BOT_ID,
|
||||||
)
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
assert "**Args:** `('arg1', 42)" in message
|
assert "**Args:** `('arg1', 42)" in message
|
||||||
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
|
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
|
||||||
@ -278,9 +292,11 @@ def test_notify_task_failure_with_traceback(mock_send_error):
|
|||||||
"""Test task failure notification with traceback"""
|
"""Test task failure notification with traceback"""
|
||||||
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
|
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Error occurred", traceback_str=traceback, bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
assert "**Traceback:**" in message
|
assert "**Traceback:**" in message
|
||||||
assert "Exception: test" in message
|
assert "Exception: test" in message
|
||||||
@ -292,9 +308,9 @@ def test_notify_task_failure_truncates_long_error(mock_send_error):
|
|||||||
"""Test that long error messages are truncated"""
|
"""Test that long error messages are truncated"""
|
||||||
long_error = "x" * 600
|
long_error = "x" * 600
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", long_error)
|
discord.notify_task_failure("test_task", long_error, bot_id=BOT_ID)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
# Error should be truncated to 500 chars - check that the full 600 char string is not there
|
# Error should be truncated to 500 chars - check that the full 600 char string is not there
|
||||||
assert "**Error:** " + long_error[:500] in message
|
assert "**Error:** " + long_error[:500] in message
|
||||||
@ -309,9 +325,11 @@ def test_notify_task_failure_truncates_long_traceback(mock_send_error):
|
|||||||
"""Test that long tracebacks are truncated"""
|
"""Test that long tracebacks are truncated"""
|
||||||
long_traceback = "x" * 1000
|
long_traceback = "x" * 1000
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Error", traceback_str=long_traceback, bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
# Traceback should show last 800 chars
|
# Traceback should show last 800 chars
|
||||||
assert long_traceback[-800:] in message
|
assert long_traceback[-800:] in message
|
||||||
@ -326,9 +344,11 @@ def test_notify_task_failure_truncates_long_args(mock_send_error):
|
|||||||
"""Test that long task arguments are truncated"""
|
"""Test that long task arguments are truncated"""
|
||||||
long_args = ("x" * 300,)
|
long_args = ("x" * 300,)
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", "Error", task_args=long_args)
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Error", task_args=long_args, bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
# Args should be truncated to 200 chars
|
# Args should be truncated to 200 chars
|
||||||
assert (
|
assert (
|
||||||
@ -342,9 +362,11 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
|||||||
"""Test that long task kwargs are truncated"""
|
"""Test that long task kwargs are truncated"""
|
||||||
long_kwargs = {"key": "x" * 300}
|
long_kwargs = {"key": "x" * 300}
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Error", task_kwargs=long_kwargs, bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
# Kwargs should be truncated to 200 chars
|
# Kwargs should be truncated to 200 chars
|
||||||
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
|
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
|
||||||
@ -354,7 +376,7 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
|||||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||||
def test_notify_task_failure_disabled(mock_send_error):
|
def test_notify_task_failure_disabled(mock_send_error):
|
||||||
"""Test that notifications are not sent when disabled"""
|
"""Test that notifications are not sent when disabled"""
|
||||||
discord.notify_task_failure("test_task", "Error occurred")
|
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
|
||||||
|
|
||||||
mock_send_error.assert_not_called()
|
mock_send_error.assert_not_called()
|
||||||
|
|
||||||
@ -366,7 +388,7 @@ def test_notify_task_failure_send_error_exception(mock_send_error):
|
|||||||
mock_send_error.side_effect = Exception("Failed to send")
|
mock_send_error.side_effect = Exception("Failed to send")
|
||||||
|
|
||||||
# Should not raise
|
# Should not raise
|
||||||
discord.notify_task_failure("test_task", "Error occurred")
|
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
|
||||||
|
|
||||||
mock_send_error.assert_called_once()
|
mock_send_error.assert_called_once()
|
||||||
|
|
||||||
@ -386,8 +408,8 @@ def test_convenience_functions_use_correct_channels(
|
|||||||
):
|
):
|
||||||
"""Test that convenience functions use the correct channel settings"""
|
"""Test that convenience functions use the correct channel settings"""
|
||||||
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
|
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
|
||||||
function(message)
|
function(BOT_ID, message)
|
||||||
mock_broadcast.assert_called_once_with("test-channel", message)
|
mock_broadcast.assert_called_once_with(BOT_ID, "test-channel", message)
|
||||||
|
|
||||||
|
|
||||||
@patch("requests.post")
|
@patch("requests.post")
|
||||||
@ -399,11 +421,13 @@ def test_send_dm_with_special_characters(mock_post, mock_api_url):
|
|||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
message_with_special_chars = "Hello! 🎉 <@123> #general"
|
message_with_special_chars = "Hello! 🎉 <@123> #general"
|
||||||
result = discord.send_dm("user123", message_with_special_chars)
|
result = discord.send_dm(BOT_ID, "user123", message_with_special_chars)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
call_args = mock_post.call_args
|
call_args = mock_post.call_args
|
||||||
assert call_args[1]["json"]["message"] == message_with_special_chars
|
json_payload = call_args[1]["json"]
|
||||||
|
assert json_payload["message"] == message_with_special_chars
|
||||||
|
assert json_payload["bot_id"] == BOT_ID
|
||||||
|
|
||||||
|
|
||||||
@patch("requests.post")
|
@patch("requests.post")
|
||||||
@ -415,11 +439,13 @@ def test_broadcast_message_with_long_message(mock_post, mock_api_url):
|
|||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
long_message = "A" * 2000
|
long_message = "A" * 2000
|
||||||
result = discord.broadcast_message("general", long_message)
|
result = discord.broadcast_message(BOT_ID, "general", long_message)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
call_args = mock_post.call_args
|
call_args = mock_post.call_args
|
||||||
assert call_args[1]["json"]["message"] == long_message
|
json_payload = call_args[1]["json"]
|
||||||
|
assert json_payload["message"] == long_message
|
||||||
|
assert json_payload["bot_id"] == BOT_ID
|
||||||
|
|
||||||
|
|
||||||
@patch("requests.get")
|
@patch("requests.get")
|
||||||
@ -430,6 +456,6 @@ def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
result = discord.is_collector_healthy()
|
result = discord.is_collector_healthy(BOT_ID)
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|||||||
393
tests/memory/workers/tasks/test_backup_tasks.py
Normal file
393
tests/memory/workers/tasks/test_backup_tasks.py
Normal file
@ -0,0 +1,393 @@
|
|||||||
|
import io
|
||||||
|
import subprocess
|
||||||
|
import tarfile
|
||||||
|
from unittest.mock import Mock, patch, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.workers.tasks import backup
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_files():
|
||||||
|
"""Create sample files in memory_files structure."""
|
||||||
|
base = settings.FILE_STORAGE_DIR
|
||||||
|
|
||||||
|
dirs_with_files = {
|
||||||
|
"emails": ["email1.txt", "email2.txt"],
|
||||||
|
"notes": ["note1.md", "note2.md"],
|
||||||
|
"photos": ["photo1.jpg"],
|
||||||
|
"comics": ["comic1.png", "comic2.png"],
|
||||||
|
"ebooks": ["book1.epub"],
|
||||||
|
"webpages": ["page1.html"],
|
||||||
|
}
|
||||||
|
|
||||||
|
for dir_name, filenames in dirs_with_files.items():
|
||||||
|
dir_path = base / dir_name
|
||||||
|
dir_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
for filename in filenames:
|
||||||
|
file_path = dir_path / filename
|
||||||
|
content = f"Content of {dir_name}/{filename}\n" * 100
|
||||||
|
file_path.write_text(content)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_s3_client():
|
||||||
|
"""Mock boto3 S3 client."""
|
||||||
|
with patch("boto3.client") as mock_client:
|
||||||
|
s3_mock = MagicMock()
|
||||||
|
mock_client.return_value = s3_mock
|
||||||
|
yield s3_mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def backup_settings():
|
||||||
|
"""Mock backup settings."""
|
||||||
|
with (
|
||||||
|
patch.object(settings, "S3_BACKUP_ENABLED", True),
|
||||||
|
patch.object(settings, "BACKUP_ENCRYPTION_KEY", "test-password-123"),
|
||||||
|
patch.object(settings, "S3_BACKUP_BUCKET", "test-bucket"),
|
||||||
|
patch.object(settings, "S3_BACKUP_PREFIX", "test-prefix"),
|
||||||
|
patch.object(settings, "S3_BACKUP_REGION", "us-east-1"),
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def get_test_path():
|
||||||
|
"""Helper to construct test paths."""
|
||||||
|
return lambda dir_name: settings.FILE_STORAGE_DIR / dir_name
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"data,key",
|
||||||
|
[
|
||||||
|
(b"This is a test message", "my-secret-key"),
|
||||||
|
(b"\x00\x01\x02\xff" * 10000, "another-key"),
|
||||||
|
(b"x" * 1000000, "large-data-key"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_encrypt_decrypt_roundtrip(data, key):
|
||||||
|
"""Test encryption and decryption produces original data."""
|
||||||
|
with patch.object(settings, "BACKUP_ENCRYPTION_KEY", key):
|
||||||
|
cipher = backup.get_cipher()
|
||||||
|
encrypted = cipher.encrypt(data)
|
||||||
|
decrypted = cipher.decrypt(encrypted)
|
||||||
|
|
||||||
|
assert decrypted == data
|
||||||
|
assert encrypted != data
|
||||||
|
|
||||||
|
|
||||||
|
def test_encrypt_decrypt_tarball(sample_files):
|
||||||
|
"""Test full tarball creation, encryption, and decryption."""
|
||||||
|
emails_dir = settings.FILE_STORAGE_DIR / "emails"
|
||||||
|
|
||||||
|
# Create tarball
|
||||||
|
tarball_bytes = backup.create_tarball(emails_dir)
|
||||||
|
assert len(tarball_bytes) > 0
|
||||||
|
|
||||||
|
# Encrypt
|
||||||
|
with patch.object(settings, "BACKUP_ENCRYPTION_KEY", "tarball-key"):
|
||||||
|
cipher = backup.get_cipher()
|
||||||
|
encrypted = cipher.encrypt(tarball_bytes)
|
||||||
|
|
||||||
|
# Decrypt
|
||||||
|
decrypted = cipher.decrypt(encrypted)
|
||||||
|
|
||||||
|
assert decrypted == tarball_bytes
|
||||||
|
|
||||||
|
# Verify tarball can be extracted
|
||||||
|
tar_buffer = io.BytesIO(decrypted)
|
||||||
|
with tarfile.open(fileobj=tar_buffer, mode="r:gz") as tar:
|
||||||
|
members = tar.getmembers()
|
||||||
|
assert len(members) >= 2 # At least 2 email files
|
||||||
|
|
||||||
|
# Extract and verify content
|
||||||
|
for member in members:
|
||||||
|
if member.isfile():
|
||||||
|
extracted = tar.extractfile(member)
|
||||||
|
assert extracted is not None
|
||||||
|
content = extracted.read().decode()
|
||||||
|
assert "Content of emails/" in content
|
||||||
|
|
||||||
|
|
||||||
|
def test_different_keys_produce_different_ciphertext():
|
||||||
|
"""Test that different encryption keys produce different ciphertext."""
|
||||||
|
data = b"Same data encrypted with different keys"
|
||||||
|
|
||||||
|
with patch.object(settings, "BACKUP_ENCRYPTION_KEY", "key1"):
|
||||||
|
cipher1 = backup.get_cipher()
|
||||||
|
encrypted1 = cipher1.encrypt(data)
|
||||||
|
|
||||||
|
with patch.object(settings, "BACKUP_ENCRYPTION_KEY", "key2"):
|
||||||
|
cipher2 = backup.get_cipher()
|
||||||
|
encrypted2 = cipher2.encrypt(data)
|
||||||
|
|
||||||
|
assert encrypted1 != encrypted2
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_encryption_key_raises_error():
|
||||||
|
"""Test that missing encryption key raises ValueError."""
|
||||||
|
with patch.object(settings, "BACKUP_ENCRYPTION_KEY", ""):
|
||||||
|
with pytest.raises(ValueError, match="BACKUP_ENCRYPTION_KEY not set"):
|
||||||
|
backup.get_cipher()
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_tarball_with_files(sample_files):
|
||||||
|
"""Test creating tarball from directory with files."""
|
||||||
|
notes_dir = settings.FILE_STORAGE_DIR / "notes"
|
||||||
|
tarball_bytes = backup.create_tarball(notes_dir)
|
||||||
|
|
||||||
|
assert len(tarball_bytes) > 0
|
||||||
|
|
||||||
|
# Verify it's a valid gzipped tarball
|
||||||
|
tar_buffer = io.BytesIO(tarball_bytes)
|
||||||
|
with tarfile.open(fileobj=tar_buffer, mode="r:gz") as tar:
|
||||||
|
members = tar.getmembers()
|
||||||
|
filenames = [m.name for m in members if m.isfile()]
|
||||||
|
assert len(filenames) >= 2
|
||||||
|
assert any("note1.md" in f for f in filenames)
|
||||||
|
assert any("note2.md" in f for f in filenames)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_tarball_nonexistent_directory():
|
||||||
|
"""Test creating tarball from nonexistent directory."""
|
||||||
|
nonexistent = settings.FILE_STORAGE_DIR / "does_not_exist"
|
||||||
|
tarball_bytes = backup.create_tarball(nonexistent)
|
||||||
|
|
||||||
|
assert tarball_bytes == b""
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_tarball_empty_directory():
|
||||||
|
"""Test creating tarball from empty directory."""
|
||||||
|
empty_dir = settings.FILE_STORAGE_DIR / "empty"
|
||||||
|
empty_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
tarball_bytes = backup.create_tarball(empty_dir)
|
||||||
|
|
||||||
|
# Should create tarball with just the directory entry
|
||||||
|
assert len(tarball_bytes) > 0
|
||||||
|
tar_buffer = io.BytesIO(tarball_bytes)
|
||||||
|
with tarfile.open(fileobj=tar_buffer, mode="r:gz") as tar:
|
||||||
|
members = tar.getmembers()
|
||||||
|
assert len(members) >= 1
|
||||||
|
assert members[0].isdir()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_unencrypted_success(sample_files, backup_settings):
|
||||||
|
"""Test successful sync of unencrypted directory."""
|
||||||
|
with patch("subprocess.run") as mock_run:
|
||||||
|
mock_run.return_value = Mock(stdout="Synced files", returncode=0)
|
||||||
|
|
||||||
|
comics_path = settings.FILE_STORAGE_DIR / "comics"
|
||||||
|
result = backup.sync_unencrypted_directory(comics_path)
|
||||||
|
|
||||||
|
assert result["synced"] is True
|
||||||
|
assert result["directory"] == comics_path
|
||||||
|
assert "s3_uri" in result
|
||||||
|
assert "test-bucket" in result["s3_uri"]
|
||||||
|
assert "test-prefix/comics" in result["s3_uri"]
|
||||||
|
|
||||||
|
# Verify aws s3 sync was called correctly
|
||||||
|
mock_run.assert_called_once()
|
||||||
|
call_args = mock_run.call_args[0][0]
|
||||||
|
assert call_args[0] == "aws"
|
||||||
|
assert call_args[1] == "s3"
|
||||||
|
assert call_args[2] == "sync"
|
||||||
|
assert "--delete" in call_args
|
||||||
|
assert "--region" in call_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_unencrypted_nonexistent_directory(backup_settings):
|
||||||
|
"""Test syncing nonexistent directory."""
|
||||||
|
nonexistent_path = settings.FILE_STORAGE_DIR / "does_not_exist"
|
||||||
|
result = backup.sync_unencrypted_directory(nonexistent_path)
|
||||||
|
|
||||||
|
assert result["synced"] is False
|
||||||
|
assert result["reason"] == "directory_not_found"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_unencrypted_aws_cli_failure(sample_files, backup_settings):
|
||||||
|
"""Test handling of AWS CLI failure."""
|
||||||
|
with patch("subprocess.run") as mock_run:
|
||||||
|
mock_run.side_effect = subprocess.CalledProcessError(
|
||||||
|
1, "aws", stderr="AWS CLI error"
|
||||||
|
)
|
||||||
|
|
||||||
|
comics_path = settings.FILE_STORAGE_DIR / "comics"
|
||||||
|
result = backup.sync_unencrypted_directory(comics_path)
|
||||||
|
|
||||||
|
assert result["synced"] is False
|
||||||
|
assert "error" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_backup_encrypted_success(
|
||||||
|
sample_files, mock_s3_client, backup_settings, get_test_path
|
||||||
|
):
|
||||||
|
"""Test successful encrypted backup."""
|
||||||
|
result = backup.backup_encrypted_directory(get_test_path("emails"))
|
||||||
|
|
||||||
|
assert result["uploaded"] is True
|
||||||
|
assert result["size_bytes"] > 0
|
||||||
|
assert result["s3_key"].endswith("emails.tar.gz.enc")
|
||||||
|
|
||||||
|
call_kwargs = mock_s3_client.put_object.call_args[1]
|
||||||
|
assert call_kwargs["Bucket"] == "test-bucket"
|
||||||
|
assert call_kwargs["ServerSideEncryption"] == "AES256"
|
||||||
|
|
||||||
|
|
||||||
|
def test_backup_encrypted_nonexistent_directory(
|
||||||
|
mock_s3_client, backup_settings, get_test_path
|
||||||
|
):
|
||||||
|
"""Test backing up nonexistent directory."""
|
||||||
|
result = backup.backup_encrypted_directory(get_test_path("does_not_exist"))
|
||||||
|
|
||||||
|
assert result["uploaded"] is False
|
||||||
|
assert result["reason"] == "directory_not_found"
|
||||||
|
mock_s3_client.put_object.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_backup_encrypted_empty_directory(
|
||||||
|
mock_s3_client, backup_settings, get_test_path
|
||||||
|
):
|
||||||
|
"""Test backing up empty directory."""
|
||||||
|
empty_dir = get_test_path("empty_encrypted")
|
||||||
|
empty_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
result = backup.backup_encrypted_directory(empty_dir)
|
||||||
|
assert "uploaded" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_backup_encrypted_s3_failure(
|
||||||
|
sample_files, mock_s3_client, backup_settings, get_test_path
|
||||||
|
):
|
||||||
|
"""Test handling of S3 upload failure."""
|
||||||
|
mock_s3_client.put_object.side_effect = ClientError(
|
||||||
|
{"Error": {"Code": "AccessDenied", "Message": "Access Denied"}}, "PutObject"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = backup.backup_encrypted_directory(get_test_path("notes"))
|
||||||
|
assert result["uploaded"] is False
|
||||||
|
assert "error" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_backup_encrypted_data_integrity(
|
||||||
|
sample_files, mock_s3_client, backup_settings, get_test_path
|
||||||
|
):
|
||||||
|
"""Test that encrypted backup maintains data integrity through full cycle."""
|
||||||
|
result = backup.backup_encrypted_directory(get_test_path("notes"))
|
||||||
|
assert result["uploaded"] is True
|
||||||
|
|
||||||
|
# Decrypt uploaded data
|
||||||
|
cipher = backup.get_cipher()
|
||||||
|
encrypted_data = mock_s3_client.put_object.call_args[1]["Body"]
|
||||||
|
decrypted_tarball = cipher.decrypt(encrypted_data)
|
||||||
|
|
||||||
|
# Verify content
|
||||||
|
tar_buffer = io.BytesIO(decrypted_tarball)
|
||||||
|
with tarfile.open(fileobj=tar_buffer, mode="r:gz") as tar:
|
||||||
|
note1_found = False
|
||||||
|
for member in tar.getmembers():
|
||||||
|
if member.name.endswith("note1.md") and member.isfile():
|
||||||
|
content = tar.extractfile(member).read().decode()
|
||||||
|
assert "Content of notes/note1.md" in content
|
||||||
|
note1_found = True
|
||||||
|
assert note1_found, "note1.md not found in tarball"
|
||||||
|
|
||||||
|
|
||||||
|
def test_backup_disabled():
|
||||||
|
"""Test that backup returns early when disabled."""
|
||||||
|
with patch.object(settings, "S3_BACKUP_ENABLED", False):
|
||||||
|
result = backup.backup_all_to_s3()
|
||||||
|
|
||||||
|
assert result["status"] == "disabled"
|
||||||
|
|
||||||
|
|
||||||
|
def test_backup_full_execution(sample_files, mock_s3_client, backup_settings):
|
||||||
|
"""Test full backup execution dispatches tasks for all directories."""
|
||||||
|
with patch.object(backup, "backup_to_s3") as mock_task:
|
||||||
|
mock_task.delay = Mock()
|
||||||
|
|
||||||
|
result = backup.backup_all_to_s3()
|
||||||
|
|
||||||
|
assert result["status"] == "success"
|
||||||
|
assert "message" in result
|
||||||
|
|
||||||
|
# Verify task was queued for each storage directory
|
||||||
|
assert mock_task.delay.call_count == len(settings.storage_dirs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_backup_handles_partial_failures(
|
||||||
|
sample_files, mock_s3_client, backup_settings, get_test_path
|
||||||
|
):
|
||||||
|
"""Test that backup continues even if some directories fail."""
|
||||||
|
with patch("subprocess.run") as mock_run:
|
||||||
|
mock_run.side_effect = subprocess.CalledProcessError(
|
||||||
|
1, "aws", stderr="Sync failed"
|
||||||
|
)
|
||||||
|
result = backup.sync_unencrypted_directory(get_test_path("comics"))
|
||||||
|
|
||||||
|
assert result["synced"] is False
|
||||||
|
assert "error" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_same_key_different_runs_different_ciphertext():
|
||||||
|
"""Test that Fernet produces different ciphertext each run (due to nonce)."""
|
||||||
|
data = b"Consistent data"
|
||||||
|
|
||||||
|
with patch.object(settings, "BACKUP_ENCRYPTION_KEY", "same-key"):
|
||||||
|
cipher = backup.get_cipher()
|
||||||
|
encrypted1 = cipher.encrypt(data)
|
||||||
|
encrypted2 = cipher.encrypt(data)
|
||||||
|
|
||||||
|
# Should be different due to random nonce, but both should decrypt to same value
|
||||||
|
assert encrypted1 != encrypted2
|
||||||
|
|
||||||
|
decrypted1 = cipher.decrypt(encrypted1)
|
||||||
|
decrypted2 = cipher.decrypt(encrypted2)
|
||||||
|
assert decrypted1 == decrypted2 == data
|
||||||
|
|
||||||
|
|
||||||
|
def test_key_derivation_consistency():
|
||||||
|
"""Test that same password produces same encryption key."""
|
||||||
|
password = "test-password"
|
||||||
|
|
||||||
|
with patch.object(settings, "BACKUP_ENCRYPTION_KEY", password):
|
||||||
|
cipher1 = backup.get_cipher()
|
||||||
|
cipher2 = backup.get_cipher()
|
||||||
|
|
||||||
|
# Both should be able to decrypt each other's ciphertext
|
||||||
|
data = b"Test data"
|
||||||
|
encrypted = cipher1.encrypt(data)
|
||||||
|
decrypted = cipher2.decrypt(encrypted)
|
||||||
|
assert decrypted == data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"dir_name,is_private",
|
||||||
|
[
|
||||||
|
("emails", True),
|
||||||
|
("notes", True),
|
||||||
|
("photos", True),
|
||||||
|
("comics", False),
|
||||||
|
("ebooks", False),
|
||||||
|
("webpages", False),
|
||||||
|
("lesswrong", False),
|
||||||
|
("chunks", False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_directory_encryption_classification(dir_name, is_private, backup_settings):
|
||||||
|
"""Test that directories are correctly classified as encrypted or not."""
|
||||||
|
# Create a mock PRIVATE_DIRS list
|
||||||
|
private_dirs = ["emails", "notes", "photos"]
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
settings, "PRIVATE_DIRS", [settings.FILE_STORAGE_DIR / d for d in private_dirs]
|
||||||
|
):
|
||||||
|
test_path = settings.FILE_STORAGE_DIR / dir_name
|
||||||
|
is_in_private = test_path in settings.PRIVATE_DIRS
|
||||||
|
|
||||||
|
assert is_in_private == is_private
|
||||||
@ -3,6 +3,7 @@ from datetime import datetime, timezone
|
|||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from memory.common.db.models import (
|
from memory.common.db.models import (
|
||||||
|
DiscordBotUser,
|
||||||
DiscordMessage,
|
DiscordMessage,
|
||||||
DiscordUser,
|
DiscordUser,
|
||||||
DiscordServer,
|
DiscordServer,
|
||||||
@ -12,12 +13,25 @@ from memory.workers.tasks import discord
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_discord_user(db_session):
|
def discord_bot_user(db_session):
|
||||||
|
bot = DiscordBotUser.create_with_api_key(
|
||||||
|
discord_users=[],
|
||||||
|
name="Test Bot",
|
||||||
|
email="bot@example.com",
|
||||||
|
)
|
||||||
|
db_session.add(bot)
|
||||||
|
db_session.commit()
|
||||||
|
return bot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_discord_user(db_session, discord_bot_user):
|
||||||
"""Create a Discord user for testing."""
|
"""Create a Discord user for testing."""
|
||||||
user = DiscordUser(
|
user = DiscordUser(
|
||||||
id=123456789,
|
id=123456789,
|
||||||
username="testuser",
|
username="testuser",
|
||||||
ignore_messages=False,
|
ignore_messages=False,
|
||||||
|
system_user_id=discord_bot_user.id,
|
||||||
)
|
)
|
||||||
db_session.add(user)
|
db_session.add(user)
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|||||||
@ -3,17 +3,23 @@ from datetime import datetime, timezone, timedelta
|
|||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from memory.common.db.models import ScheduledLLMCall, HumanUser, DiscordUser, DiscordChannel, DiscordServer
|
from memory.common.db.models import (
|
||||||
|
ScheduledLLMCall,
|
||||||
|
DiscordBotUser,
|
||||||
|
DiscordUser,
|
||||||
|
DiscordChannel,
|
||||||
|
DiscordServer,
|
||||||
|
)
|
||||||
from memory.workers.tasks import scheduled_calls
|
from memory.workers.tasks import scheduled_calls
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_user(db_session):
|
def sample_user(db_session):
|
||||||
"""Create a sample user for testing."""
|
"""Create a sample user for testing."""
|
||||||
user = HumanUser.create_with_password(
|
user = DiscordBotUser.create_with_api_key(
|
||||||
name="testuser",
|
discord_users=[],
|
||||||
email="test@example.com",
|
name="testbot",
|
||||||
password="password",
|
email="bot@example.com",
|
||||||
)
|
)
|
||||||
db_session.add(user)
|
db_session.add(user)
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
@ -124,6 +130,7 @@ def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
|
|||||||
scheduled_calls._send_to_discord(pending_scheduled_call, response)
|
scheduled_calls._send_to_discord(pending_scheduled_call, response)
|
||||||
|
|
||||||
mock_send_dm.assert_called_once_with(
|
mock_send_dm.assert_called_once_with(
|
||||||
|
pending_scheduled_call.user_id,
|
||||||
"testuser", # username, not ID
|
"testuser", # username, not ID
|
||||||
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
|
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
|
||||||
)
|
)
|
||||||
@ -137,6 +144,7 @@ def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
|
|||||||
scheduled_calls._send_to_discord(completed_scheduled_call, response)
|
scheduled_calls._send_to_discord(completed_scheduled_call, response)
|
||||||
|
|
||||||
mock_broadcast.assert_called_once_with(
|
mock_broadcast.assert_called_once_with(
|
||||||
|
completed_scheduled_call.user_id,
|
||||||
"test-channel", # channel name, not ID
|
"test-channel", # channel name, not ID
|
||||||
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
|
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
|
||||||
)
|
)
|
||||||
@ -151,7 +159,8 @@ def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled
|
|||||||
|
|
||||||
# Verify the message was truncated
|
# Verify the message was truncated
|
||||||
args, kwargs = mock_send_dm.call_args
|
args, kwargs = mock_send_dm.call_args
|
||||||
message = args[1]
|
assert args[0] == pending_scheduled_call.user_id
|
||||||
|
message = args[2]
|
||||||
assert len(message) <= 1950 # Should be truncated
|
assert len(message) <= 1950 # Should be truncated
|
||||||
assert message.endswith("... (response truncated)")
|
assert message.endswith("... (response truncated)")
|
||||||
|
|
||||||
@ -164,7 +173,8 @@ def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_c
|
|||||||
scheduled_calls._send_to_discord(pending_scheduled_call, normal_response)
|
scheduled_calls._send_to_discord(pending_scheduled_call, normal_response)
|
||||||
|
|
||||||
args, kwargs = mock_send_dm.call_args
|
args, kwargs = mock_send_dm.call_args
|
||||||
message = args[1]
|
assert args[0] == pending_scheduled_call.user_id
|
||||||
|
message = args[2]
|
||||||
assert not message.endswith("... (response truncated)")
|
assert not message.endswith("... (response truncated)")
|
||||||
assert "This is a normal length response." in message
|
assert "This is a normal length response." in message
|
||||||
|
|
||||||
@ -574,6 +584,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
|
|||||||
mock_discord_user.username = "testuser"
|
mock_discord_user.username = "testuser"
|
||||||
|
|
||||||
mock_call = Mock()
|
mock_call = Mock()
|
||||||
|
mock_call.user_id = 987
|
||||||
mock_call.topic = topic
|
mock_call.topic = topic
|
||||||
mock_call.model = model
|
mock_call.model = model
|
||||||
mock_call.discord_user = mock_discord_user
|
mock_call.discord_user = mock_discord_user
|
||||||
@ -583,7 +594,8 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
|
|||||||
|
|
||||||
# Get the actual message that was sent
|
# Get the actual message that was sent
|
||||||
args, kwargs = mock_send_dm.call_args
|
args, kwargs = mock_send_dm.call_args
|
||||||
actual_message = args[1]
|
assert args[0] == mock_call.user_id
|
||||||
|
actual_message = args[2]
|
||||||
|
|
||||||
# Verify all expected parts are in the message
|
# Verify all expected parts are in the message
|
||||||
for expected_part in expected_in_message:
|
for expected_part in expected_in_message:
|
||||||
|
|||||||
182
tools/backup_databases.sh
Normal file
182
tools/backup_databases.sh
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Backup Postgres and Qdrant databases to S3
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# Install AWS CLI if not present (postgres:15 image doesn't include it)
|
||||||
|
if ! command -v aws >/dev/null 2>&1; then
|
||||||
|
echo "Installing AWS CLI, wget, and jq..."
|
||||||
|
apt-get update -qq && apt-get install -y -qq awscli wget jq >/dev/null 2>&1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Configuration - read from environment or use defaults
|
||||||
|
BUCKET="${S3_BACKUP_BUCKET:-equistamp-memory-backup}"
|
||||||
|
PREFIX="${S3_BACKUP_PREFIX:-Daniel}/databases"
|
||||||
|
REGION="${S3_BACKUP_REGION:-eu-central-1}"
|
||||||
|
PASSWORD="${BACKUP_ENCRYPTION_KEY:?BACKUP_ENCRYPTION_KEY not set}"
|
||||||
|
MAX_BACKUPS="${MAX_BACKUPS:-30}" # Keep last N backups
|
||||||
|
|
||||||
|
# Service names (docker-compose network)
|
||||||
|
POSTGRES_HOST="${POSTGRES_HOST:-postgres}"
|
||||||
|
POSTGRES_USER="${POSTGRES_USER:-kb}"
|
||||||
|
POSTGRES_DB="${POSTGRES_DB:-kb}"
|
||||||
|
QDRANT_URL="${QDRANT_URL:-http://qdrant:6333}"
|
||||||
|
|
||||||
|
# Timestamp for backups
|
||||||
|
DATE=$(date +%Y%m%d-%H%M%S)
|
||||||
|
DATE_SIMPLE=$(date +%Y%m%d)
|
||||||
|
|
||||||
|
log() {
|
||||||
|
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
error() {
|
||||||
|
echo "[$(date '+%Y-%m-%d %H:%M:%S')] ERROR: $*" >&2
|
||||||
|
}
|
||||||
|
|
||||||
|
# Clean old backups - keep only last N
|
||||||
|
cleanup_old_backups() {
|
||||||
|
local prefix=$1
|
||||||
|
local pattern=$2 # e.g., "postgres-" or "qdrant-"
|
||||||
|
|
||||||
|
log "Checking for old ${pattern} backups to clean up..."
|
||||||
|
|
||||||
|
# List all backups matching pattern, sorted by date (oldest first)
|
||||||
|
local backups
|
||||||
|
backups=$(aws s3 ls "s3://${BUCKET}/${prefix}/" --region "${REGION}" | \
|
||||||
|
grep "${pattern}" | \
|
||||||
|
awk '{print $4}' | \
|
||||||
|
sort)
|
||||||
|
|
||||||
|
local count=$(echo "$backups" | wc -l)
|
||||||
|
|
||||||
|
if [ "$count" -le "$MAX_BACKUPS" ]; then
|
||||||
|
log "Found ${count} ${pattern} backups (max: ${MAX_BACKUPS}), no cleanup needed"
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
local to_delete=$((count - MAX_BACKUPS))
|
||||||
|
log "Found ${count} ${pattern} backups, deleting ${to_delete} oldest..."
|
||||||
|
|
||||||
|
echo "$backups" | head -n "$to_delete" | while read -r file; do
|
||||||
|
if [ -n "$file" ]; then
|
||||||
|
log "Deleting old backup: ${file}"
|
||||||
|
aws s3 rm "s3://${BUCKET}/${prefix}/${file}" --region "${REGION}"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
# Backup Postgres
|
||||||
|
backup_postgres() {
|
||||||
|
log "Starting Postgres backup..."
|
||||||
|
|
||||||
|
local output_path="s3://${BUCKET}/${PREFIX}/postgres-${DATE_SIMPLE}.sql.gz.enc"
|
||||||
|
|
||||||
|
# Use pg_dump directly with service name (no docker exec needed)
|
||||||
|
export PGPASSWORD=$(cat "${POSTGRES_PASSWORD_FILE}")
|
||||||
|
if pg_dump -h "${POSTGRES_HOST}" -U "${POSTGRES_USER}" "${POSTGRES_DB}" 2>/dev/null | \
|
||||||
|
gzip | \
|
||||||
|
openssl enc -aes-256-cbc -salt -pbkdf2 -pass "pass:${PASSWORD}" | \
|
||||||
|
aws s3 cp - "${output_path}" --region "${REGION}"; then
|
||||||
|
log "Postgres backup completed: ${output_path}"
|
||||||
|
unset PGPASSWORD
|
||||||
|
cleanup_old_backups "${PREFIX}" "postgres-"
|
||||||
|
return 0
|
||||||
|
else
|
||||||
|
error "Postgres backup failed"
|
||||||
|
unset PGPASSWORD
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Backup Qdrant
|
||||||
|
backup_qdrant() {
|
||||||
|
log "Starting Qdrant backup..."
|
||||||
|
|
||||||
|
# Create snapshot via HTTP API (no docker exec needed)
|
||||||
|
local snapshot_response
|
||||||
|
if ! snapshot_response=$(wget -q -O - --post-data='{}' \
|
||||||
|
--header='Content-Type: application/json' \
|
||||||
|
"${QDRANT_URL}/snapshots" 2>/dev/null); then
|
||||||
|
error "Failed to create Qdrant snapshot"
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
local snapshot_name
|
||||||
|
# Parse snapshot name - wget/busybox may not have jq, so use grep/sed
|
||||||
|
if command -v jq >/dev/null 2>&1; then
|
||||||
|
snapshot_name=$(echo "${snapshot_response}" | jq -r '.result.name // empty')
|
||||||
|
else
|
||||||
|
# Fallback: parse JSON without jq (fragile but works for simple case)
|
||||||
|
snapshot_name=$(echo "${snapshot_response}" | grep -o '"name":"[^"]*"' | cut -d'"' -f4)
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -z "${snapshot_name}" ]; then
|
||||||
|
error "Could not extract snapshot name from response: ${snapshot_response}"
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
log "Created Qdrant snapshot: ${snapshot_name}"
|
||||||
|
|
||||||
|
# Download snapshot and upload to S3
|
||||||
|
local output_path="s3://${BUCKET}/${PREFIX}/qdrant-${DATE_SIMPLE}.snapshot.enc"
|
||||||
|
|
||||||
|
if wget -q -O - "${QDRANT_URL}/snapshots/${snapshot_name}" | \
|
||||||
|
openssl enc -aes-256-cbc -salt -pbkdf2 -pass "pass:${PASSWORD}" | \
|
||||||
|
aws s3 cp - "${output_path}" --region "${REGION}"; then
|
||||||
|
log "Qdrant backup completed: ${output_path}"
|
||||||
|
|
||||||
|
# Delete the snapshot from Qdrant
|
||||||
|
if wget -q -O - --method=DELETE \
|
||||||
|
"${QDRANT_URL}/snapshots/${snapshot_name}" >/dev/null 2>&1; then
|
||||||
|
log "Deleted Qdrant snapshot: ${snapshot_name}"
|
||||||
|
else
|
||||||
|
error "Failed to delete Qdrant snapshot: ${snapshot_name}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
cleanup_old_backups "${PREFIX}" "qdrant-"
|
||||||
|
return 0
|
||||||
|
else
|
||||||
|
error "Qdrant backup failed"
|
||||||
|
|
||||||
|
# Try to clean up snapshot
|
||||||
|
wget -q -O - --method=DELETE \
|
||||||
|
"${QDRANT_URL}/snapshots/${snapshot_name}" >/dev/null 2>&1 || true
|
||||||
|
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Main execution
|
||||||
|
main() {
|
||||||
|
log "Database backup started"
|
||||||
|
|
||||||
|
local postgres_result=0
|
||||||
|
local qdrant_result=0
|
||||||
|
|
||||||
|
# Backup Postgres
|
||||||
|
if ! backup_postgres; then
|
||||||
|
postgres_result=1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Backup Qdrant
|
||||||
|
if ! backup_qdrant; then
|
||||||
|
qdrant_result=1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
if [ $postgres_result -eq 0 ] && [ $qdrant_result -eq 0 ]; then
|
||||||
|
log "All database backups completed successfully"
|
||||||
|
return 0
|
||||||
|
elif [ $postgres_result -ne 0 ] && [ $qdrant_result -ne 0 ]; then
|
||||||
|
error "All database backups failed"
|
||||||
|
return 1
|
||||||
|
else
|
||||||
|
error "Some database backups failed (Postgres: ${postgres_result}, Qdrant: ${qdrant_result})"
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Run main function
|
||||||
|
main
|
||||||
|
|
||||||
@ -51,6 +51,8 @@ from memory.common.celery_app import (
|
|||||||
UPDATE_METADATA_FOR_SOURCE_ITEMS,
|
UPDATE_METADATA_FOR_SOURCE_ITEMS,
|
||||||
SETUP_GIT_NOTES,
|
SETUP_GIT_NOTES,
|
||||||
TRACK_GIT_CHANGES,
|
TRACK_GIT_CHANGES,
|
||||||
|
BACKUP_TO_S3_DIRECTORY,
|
||||||
|
BACKUP_ALL,
|
||||||
app,
|
app,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -97,6 +99,10 @@ TASK_MAPPINGS = {
|
|||||||
"setup_git_notes": SETUP_GIT_NOTES,
|
"setup_git_notes": SETUP_GIT_NOTES,
|
||||||
"track_git_changes": TRACK_GIT_CHANGES,
|
"track_git_changes": TRACK_GIT_CHANGES,
|
||||||
},
|
},
|
||||||
|
"backup": {
|
||||||
|
"backup_to_s3_directory": BACKUP_TO_S3_DIRECTORY,
|
||||||
|
"backup_all": BACKUP_ALL,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
QUEUE_MAPPINGS = {
|
QUEUE_MAPPINGS = {
|
||||||
"email": "email",
|
"email": "email",
|
||||||
@ -177,6 +183,28 @@ def execute_task(ctx, category: str, task_name: str, **kwargs):
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.group()
|
||||||
|
@click.pass_context
|
||||||
|
def backup(ctx):
|
||||||
|
"""Backup-related tasks."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@backup.command("all")
|
||||||
|
@click.pass_context
|
||||||
|
def backup_all(ctx):
|
||||||
|
"""Backup all directories."""
|
||||||
|
execute_task(ctx, "backup", "backup_all")
|
||||||
|
|
||||||
|
|
||||||
|
@backup.command("path")
|
||||||
|
@click.option("--path", required=True, help="Path to backup")
|
||||||
|
@click.pass_context
|
||||||
|
def backup_to_s3_directory(ctx, path):
|
||||||
|
"""Backup a specific path."""
|
||||||
|
execute_task(ctx, "backup", "backup_to_s3_directory", path=path)
|
||||||
|
|
||||||
|
|
||||||
@cli.group()
|
@cli.group()
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def email(ctx):
|
def email(ctx):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user