mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 08:14:05 +01:00
save images
This commit is contained in:
parent
e95a082147
commit
a4f42e656a
@ -0,0 +1,29 @@
|
|||||||
|
"""store discord images
|
||||||
|
|
||||||
|
Revision ID: 1954477b25f4
|
||||||
|
Revises: 2024235e37e7
|
||||||
|
Create Date: 2025-11-02 10:14:48.334934
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "1954477b25f4"
|
||||||
|
down_revision: Union[str, None] = "2024235e37e7"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"discord_message", sa.Column("images", sa.ARRAY(sa.Text()), nullable=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("discord_message", "images")
|
||||||
@ -5,15 +5,16 @@ SQLAdmin views for the knowledge base database models.
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from sqladmin import Admin, ModelView
|
from sqladmin import Admin, ModelView
|
||||||
|
|
||||||
from memory.common.db.models import (
|
from memory.common.db.models import (
|
||||||
AgentObservation,
|
AgentObservation,
|
||||||
ArticleFeed,
|
ArticleFeed,
|
||||||
BlogPost,
|
BlogPost,
|
||||||
Book,
|
Book,
|
||||||
BookSection,
|
BookSection,
|
||||||
ScheduledLLMCall,
|
|
||||||
Chunk,
|
Chunk,
|
||||||
Comic,
|
Comic,
|
||||||
|
DiscordMessage,
|
||||||
EmailAccount,
|
EmailAccount,
|
||||||
EmailAttachment,
|
EmailAttachment,
|
||||||
ForumPost,
|
ForumPost,
|
||||||
@ -21,6 +22,7 @@ from memory.common.db.models import (
|
|||||||
MiscDoc,
|
MiscDoc,
|
||||||
Note,
|
Note,
|
||||||
Photo,
|
Photo,
|
||||||
|
ScheduledLLMCall,
|
||||||
SourceItem,
|
SourceItem,
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
@ -153,6 +155,17 @@ class BookAdmin(ModelView, model=Book):
|
|||||||
column_searchable_list = ["title", "author", "id"]
|
column_searchable_list = ["title", "author", "id"]
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordMessageAdmin(ModelView, model=DiscordMessage):
|
||||||
|
column_list = [
|
||||||
|
"id",
|
||||||
|
"content",
|
||||||
|
"images",
|
||||||
|
"sent_at",
|
||||||
|
]
|
||||||
|
column_searchable_list = ["content", "id", "images"]
|
||||||
|
column_sortable_list = ["sent_at"]
|
||||||
|
|
||||||
|
|
||||||
class ArticleFeedAdmin(ModelView, model=ArticleFeed):
|
class ArticleFeedAdmin(ModelView, model=ArticleFeed):
|
||||||
column_list = [
|
column_list = [
|
||||||
"id",
|
"id",
|
||||||
@ -310,6 +323,7 @@ def setup_admin(admin: Admin):
|
|||||||
admin.add_view(ForumPostAdmin)
|
admin.add_view(ForumPostAdmin)
|
||||||
admin.add_view(ComicAdmin)
|
admin.add_view(ComicAdmin)
|
||||||
admin.add_view(PhotoAdmin)
|
admin.add_view(PhotoAdmin)
|
||||||
|
admin.add_view(DiscordMessageAdmin)
|
||||||
admin.add_view(UserAdmin)
|
admin.add_view(UserAdmin)
|
||||||
admin.add_view(DiscordUserAdmin)
|
admin.add_view(DiscordUserAdmin)
|
||||||
admin.add_view(DiscordServerAdmin)
|
admin.add_view(DiscordServerAdmin)
|
||||||
|
|||||||
@ -301,6 +301,7 @@ class DiscordMessage(SourceItem):
|
|||||||
BigInteger, nullable=True
|
BigInteger, nullable=True
|
||||||
) # Discord thread snowflake ID if in thread
|
) # Discord thread snowflake ID if in thread
|
||||||
edited_at = Column(DateTime(timezone=True), nullable=True)
|
edited_at = Column(DateTime(timezone=True), nullable=True)
|
||||||
|
images = Column(ARRAY(Text), nullable=True) # List of image URLs
|
||||||
|
|
||||||
channel = relationship("DiscordChannel", foreign_keys=[channel_id])
|
channel = relationship("DiscordChannel", foreign_keys=[channel_id])
|
||||||
server = relationship("DiscordServer", foreign_keys=[server_id])
|
server = relationship("DiscordServer", foreign_keys=[server_id])
|
||||||
@ -358,6 +359,20 @@ class DiscordMessage(SourceItem):
|
|||||||
def title(self) -> str:
|
def title(self) -> str:
|
||||||
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
|
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
|
||||||
|
|
||||||
|
def as_content(self):
|
||||||
|
"""Return message content ready for LLM (text + images from disk)."""
|
||||||
|
content = {"text": self.title, "images": []}
|
||||||
|
for path in cast(list[str] | None, self.images) or []:
|
||||||
|
try:
|
||||||
|
full_path = settings.FILE_STORAGE_DIR / path
|
||||||
|
if full_path.exists():
|
||||||
|
image = Image.open(full_path)
|
||||||
|
content["images"].append(image)
|
||||||
|
except Exception:
|
||||||
|
pass # Skip failed image loads
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
__mapper_args__ = {
|
__mapper_args__ = {
|
||||||
"polymorphic_identity": "discord_message",
|
"polymorphic_identity": "discord_message",
|
||||||
}
|
}
|
||||||
|
|||||||
@ -171,13 +171,17 @@ class Message:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def user(
|
def user(
|
||||||
text: str | None = None, tool_result: ToolResultContent | None = None
|
text: str | None = None,
|
||||||
|
images: list[Image.Image] | None = None,
|
||||||
|
tool_result: ToolResultContent | None = None,
|
||||||
) -> "Message":
|
) -> "Message":
|
||||||
parts = []
|
parts = []
|
||||||
if text:
|
if text:
|
||||||
parts.append(TextContent(text=text))
|
parts.append(TextContent(text=text))
|
||||||
if tool_result:
|
if tool_result:
|
||||||
parts.append(tool_result)
|
parts.append(tool_result)
|
||||||
|
for image in images or []:
|
||||||
|
parts.append(ImageContent(image=image))
|
||||||
return Message(role=MessageRole.USER, content=parts)
|
return Message(role=MessageRole.USER, content=parts)
|
||||||
|
|
||||||
|
|
||||||
@ -625,8 +629,16 @@ class BaseLLMProvider(ABC):
|
|||||||
tool_calls=tool_calls or None,
|
tool_calls=tool_calls or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_messages(self, messages) -> list[Message]:
|
def as_messages(self, messages: list[dict[str, Any] | str]) -> list[Message]:
|
||||||
return [Message.user(text=msg) for msg in messages]
|
def make_message(msg: dict[str, Any] | str) -> Message:
|
||||||
|
if isinstance(msg, str):
|
||||||
|
return Message.user(text=msg)
|
||||||
|
elif isinstance(msg, dict):
|
||||||
|
return Message.user(text=msg["text"], images=msg.get("images"))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown message type: {type(msg)}")
|
||||||
|
|
||||||
|
return [make_message(msg) for msg in messages]
|
||||||
|
|
||||||
|
|
||||||
def create_provider(
|
def create_provider(
|
||||||
|
|||||||
@ -73,6 +73,9 @@ WEBPAGE_STORAGE_DIR = pathlib.Path(
|
|||||||
NOTES_STORAGE_DIR = pathlib.Path(
|
NOTES_STORAGE_DIR = pathlib.Path(
|
||||||
os.getenv("NOTES_STORAGE_DIR", FILE_STORAGE_DIR / "notes")
|
os.getenv("NOTES_STORAGE_DIR", FILE_STORAGE_DIR / "notes")
|
||||||
)
|
)
|
||||||
|
DISCORD_STORAGE_DIR = pathlib.Path(
|
||||||
|
os.getenv("DISCORD_STORAGE_DIR", FILE_STORAGE_DIR / "discord")
|
||||||
|
)
|
||||||
PRIVATE_DIRS = [
|
PRIVATE_DIRS = [
|
||||||
EMAIL_STORAGE_DIR,
|
EMAIL_STORAGE_DIR,
|
||||||
NOTES_STORAGE_DIR,
|
NOTES_STORAGE_DIR,
|
||||||
@ -88,6 +91,7 @@ storage_dirs = [
|
|||||||
PHOTO_STORAGE_DIR,
|
PHOTO_STORAGE_DIR,
|
||||||
WEBPAGE_STORAGE_DIR,
|
WEBPAGE_STORAGE_DIR,
|
||||||
NOTES_STORAGE_DIR,
|
NOTES_STORAGE_DIR,
|
||||||
|
DISCORD_STORAGE_DIR,
|
||||||
]
|
]
|
||||||
for dir in storage_dirs:
|
for dir in storage_dirs:
|
||||||
dir.mkdir(parents=True, exist_ok=True)
|
dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|||||||
@ -24,6 +24,32 @@ from memory.workers.tasks.discord import add_discord_message, edit_discord_messa
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def process_mentions(session: Session | scoped_session, message: str) -> str:
|
||||||
|
"""Convert username mentions (<@username>) to ID mentions (<@123456>)"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
def replace_mention(match):
|
||||||
|
mention_content = match.group(1)
|
||||||
|
# If it's already numeric, leave it alone
|
||||||
|
if mention_content.isdigit():
|
||||||
|
return match.group(0)
|
||||||
|
|
||||||
|
# Look up username in database
|
||||||
|
user = (
|
||||||
|
session.query(DiscordUser)
|
||||||
|
.filter(DiscordUser.username == mention_content)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if user:
|
||||||
|
return f"<@{user.id}>"
|
||||||
|
|
||||||
|
# If user not found, return original
|
||||||
|
return match.group(0)
|
||||||
|
|
||||||
|
return re.sub(r"<@([^>]+)>", replace_mention, message)
|
||||||
|
|
||||||
|
|
||||||
# Pure functions for Discord entity creation/updates
|
# Pure functions for Discord entity creation/updates
|
||||||
def create_or_update_server(
|
def create_or_update_server(
|
||||||
session: Session | scoped_session, guild: discord.Guild | None
|
session: Session | scoped_session, guild: discord.Guild | None
|
||||||
@ -233,6 +259,13 @@ class MessageCollector(commands.Bot):
|
|||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
# Extract image URLs from attachments
|
||||||
|
image_urls = [
|
||||||
|
att.url
|
||||||
|
for att in message.attachments
|
||||||
|
if att.content_type and att.content_type.startswith("image/")
|
||||||
|
]
|
||||||
|
|
||||||
# Queue the message for processing
|
# Queue the message for processing
|
||||||
add_discord_message.delay(
|
add_discord_message.delay(
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
@ -245,6 +278,7 @@ class MessageCollector(commands.Bot):
|
|||||||
message_reference_id=message.reference.message_id
|
message_reference_id=message.reference.message_id
|
||||||
if message.reference
|
if message.reference
|
||||||
else None,
|
else None,
|
||||||
|
image_urls=image_urls,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error queuing message {message.id}: {e}")
|
logger.error(f"Error queuing message {message.id}: {e}")
|
||||||
@ -373,7 +407,11 @@ class MessageCollector(commands.Bot):
|
|||||||
logger.error(f"User {user_identifier} not found")
|
logger.error(f"User {user_identifier} not found")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
await user.send(message)
|
# Post-process mentions to convert usernames to IDs
|
||||||
|
with make_session() as session:
|
||||||
|
processed_message = process_mentions(session, message)
|
||||||
|
|
||||||
|
await user.send(processed_message)
|
||||||
logger.info(f"Sent DM to {user_identifier}")
|
logger.info(f"Sent DM to {user_identifier}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -413,7 +451,11 @@ class MessageCollector(commands.Bot):
|
|||||||
logger.error(f"Channel {channel_name} not found")
|
logger.error(f"Channel {channel_name} not found")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
await channel.send(message)
|
# Post-process mentions to convert usernames to IDs
|
||||||
|
with make_session() as session:
|
||||||
|
processed_message = process_mentions(session, message)
|
||||||
|
|
||||||
|
await channel.send(processed_message)
|
||||||
logger.info(f"Sent message to channel {channel_name}")
|
logger.info(f"Sent message to channel {channel_name}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@ -4,11 +4,13 @@ Celery tasks for Discord message processing.
|
|||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
|
import pathlib
|
||||||
import re
|
import re
|
||||||
import textwrap
|
import textwrap
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
|
import requests
|
||||||
from sqlalchemy import exc as sqlalchemy_exc
|
from sqlalchemy import exc as sqlalchemy_exc
|
||||||
from sqlalchemy.orm import Session, scoped_session
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
|
|
||||||
@ -34,6 +36,37 @@ from memory.workers.tasks.content_processing import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def download_and_save_images(image_urls: list[str], message_id: int) -> list[str]:
|
||||||
|
"""Download images from URLs and save to disk. Returns relative file paths."""
|
||||||
|
image_dir = settings.DISCORD_STORAGE_DIR / str(message_id)
|
||||||
|
image_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
saved_paths = []
|
||||||
|
for url in image_urls:
|
||||||
|
try:
|
||||||
|
response = requests.get(url, timeout=10)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# Generate filename
|
||||||
|
url_hash = hashlib.md5(url.encode()).hexdigest()
|
||||||
|
ext = pathlib.Path(url).suffix or ".jpg"
|
||||||
|
ext = ext.split("?")[0]
|
||||||
|
filename = f"{url_hash}{ext}"
|
||||||
|
local_path = image_dir / filename
|
||||||
|
|
||||||
|
# Save image
|
||||||
|
local_path.write_bytes(response.content)
|
||||||
|
|
||||||
|
# Store relative path from FILE_STORAGE_DIR
|
||||||
|
relative_path = local_path.relative_to(settings.FILE_STORAGE_DIR)
|
||||||
|
saved_paths.append(str(relative_path))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to download/save image from {url}: {e}")
|
||||||
|
|
||||||
|
return saved_paths
|
||||||
|
|
||||||
|
|
||||||
def get_prev(
|
def get_prev(
|
||||||
session: Session | scoped_session, channel_id: int, sent_at: datetime
|
session: Session | scoped_session, channel_id: int, sent_at: datetime
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
@ -87,8 +120,12 @@ def call_llm(
|
|||||||
message.channel and message.channel.id,
|
message.channel and message.channel.id,
|
||||||
max_messages=10,
|
max_messages=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Build message list: previous messages + current message + any extra text msgs
|
||||||
|
message_content = [m.as_content() for m in messages + [message]] + msgs
|
||||||
|
|
||||||
return provider.run_with_tools(
|
return provider.run_with_tools(
|
||||||
messages=provider.as_messages([m.title for m in messages] + msgs),
|
messages=provider.as_messages(message_content),
|
||||||
tools=tools,
|
tools=tools,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
||||||
@ -226,6 +263,7 @@ def add_discord_message(
|
|||||||
server_id: int | None = None,
|
server_id: int | None = None,
|
||||||
recipient_id: int | None = None,
|
recipient_id: int | None = None,
|
||||||
message_reference_id: int | None = None,
|
message_reference_id: int | None = None,
|
||||||
|
image_urls: list[str] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Add a Discord message to the database.
|
Add a Discord message to the database.
|
||||||
@ -237,6 +275,11 @@ def add_discord_message(
|
|||||||
content_hash = hashlib.sha256(f"{message_id}:{content}".encode()).digest()
|
content_hash = hashlib.sha256(f"{message_id}:{content}".encode()).digest()
|
||||||
sent_at_dt = datetime.fromisoformat(sent_at.replace("Z", "+00:00"))
|
sent_at_dt = datetime.fromisoformat(sent_at.replace("Z", "+00:00"))
|
||||||
|
|
||||||
|
# Download and save images to disk
|
||||||
|
saved_image_paths = []
|
||||||
|
if image_urls:
|
||||||
|
saved_image_paths = download_and_save_images(image_urls, message_id)
|
||||||
|
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
discord_message = DiscordMessage(
|
discord_message = DiscordMessage(
|
||||||
modality="text",
|
modality="text",
|
||||||
@ -250,6 +293,7 @@ def add_discord_message(
|
|||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
message_type="reply" if message_reference_id else "default",
|
message_type="reply" if message_reference_id else "default",
|
||||||
reply_to_message_id=message_reference_id,
|
reply_to_message_id=message_reference_id,
|
||||||
|
images=saved_image_paths or None,
|
||||||
)
|
)
|
||||||
existing_msg = check_content_exists(
|
existing_msg = check_content_exists(
|
||||||
session, DiscordMessage, message_id=message_id, sha256=content_hash
|
session, DiscordMessage, message_id=message_id, sha256=content_hash
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user