diff --git a/docker/workers/Dockerfile b/docker/workers/Dockerfile
index 132317e..be84e43 100644
--- a/docker/workers/Dockerfile
+++ b/docker/workers/Dockerfile
@@ -10,21 +10,28 @@ COPY src/ ./src/
 
 # Install dependencies
 RUN apt-get update && apt-get install -y \
     libpq-dev gcc pandoc \
-    texlive-full texlive-fonts-recommended texlive-plain-generic \
+    texlive-xetex texlive-fonts-recommended texlive-plain-generic \
+    texlive-lang-greek texlive-lang-cyrillic texlive-lang-european \
+    texlive-luatex texlive-latex-extra texlive-latex-recommended \
+    texlive-science texlive-fonts-extra \
+    fontconfig \
     # For optional LibreOffice support (uncomment if needed)
     # libreoffice-writer \
-    && \
-    pip install -e ".[workers]" && \
-    apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
+    && apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
+
+RUN pip install -e ".[workers]"
 
 # Create and copy entrypoint script
 COPY docker/workers/entry.sh ./entry.sh
+COPY docker/workers/unnest-table.lua ./unnest-table.lua
 RUN chmod +x entry.sh
 RUN mkdir -p /app/memory_files
 
 # Create user and set permissions
-RUN useradd -m kb && chown -R kb /app
+RUN useradd -m kb
+RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
+    chown -R kb:kb /var/cache/fontconfig /home/kb/.cache/fontconfig /app
+
 USER kb
 
 # Default queues to process
diff --git a/docker/workers/unnest-table.lua b/docker/workers/unnest-table.lua
new file mode 100644
index 0000000..5420643
--- /dev/null
+++ b/docker/workers/unnest-table.lua
@@ -0,0 +1,28 @@
+local function rows(tbl)
+  local r = pandoc.List()
+  r:extend(tbl.head.rows)
+  for _, b in ipairs(tbl.bodies) do r:extend(b.body) end
+  r:extend(tbl.foot.rows)
+  return r
+end
+
+function Table(t)
+  local newHead = pandoc.TableHead()
+  for i, row in ipairs(t.head.rows) do
+    for j, cell in ipairs(row.cells) do
+      local inner = cell.contents[1]
+      if inner and inner.t == 'Table' then
+        local ins = rows(inner)
+        local first = table.remove(ins, 1)
+        row.cells[j] = first.cells[1]
+        newHead.rows:insert(row)
+        newHead.rows:extend(ins)
+        goto continue
+      end
+    end
+    newHead.rows:insert(row)
+    ::continue::
+  end
+  t.head = newHead
+  return t
+end
\ No newline at end of file
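
A note on the new Lua filter: pandoc runs it during the DOCX-to-PDF conversion to flatten tables that are nested inside header cells, which the LaTeX writer cannot render. A minimal sketch of driving the conversion with this filter from Python, assuming pandoc and texlive-xetex are installed as in the Dockerfile above (the input path is hypothetical):

    import pypandoc

    # Convert a DOCX to PDF, unnesting tables via the Lua filter.
    pypandoc.convert_file(
        "sample.docx",  # hypothetical input document
        to="pdf",
        format="docx",
        outputfile="sample.pdf",
        extra_args=[
            "--pdf-engine=xelatex",
            "--lua-filter=docker/workers/unnest-table.lua",
        ],
    )
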
""" -from fastapi import FastAPI, Depends, HTTPException -from sqlalchemy.orm import Session -from memory.common.db import get_scoped_session -from memory.common.db.models import SourceItem +import base64 +import io +from collections import defaultdict +import pathlib +from typing import Annotated, List, Optional, Callable +from fastapi import FastAPI, File, UploadFile, Query, HTTPException, Form +from fastapi.responses import FileResponse +import qdrant_client +from qdrant_client.http import models as qdrant_models +from PIL import Image +from pydantic import BaseModel +from memory.common import embedding, qdrant, extract, settings +from memory.common.embedding import get_modality +from memory.common.db.connection import make_session +from memory.common.db.models import Chunk, SourceItem + +import logging + +logger = logging.getLogger(__name__) app = FastAPI(title="Knowledge Base API") -def get_db(): - """Database session dependency""" - db = get_scoped_session() - try: - yield db - finally: - db.close() +class AnnotatedChunk(BaseModel): + id: str + score: float + metadata: dict + preview: Optional[str | None] = None + + +class SearchResponse(BaseModel): + collection: str + results: List[dict] + + +class SearchResult(BaseModel): + id: int + size: int + mime_type: str + chunks: list[AnnotatedChunk] + content: Optional[str] = None + filename: Optional[str] = None @app.get("/health") @@ -26,25 +53,195 @@ def health_check(): return {"status": "healthy"} -@app.get("/sources") -def list_sources( - tag: str = None, - limit: int = 100, - db: Session = Depends(get_db) +def annotated_chunk( + chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool +) -> tuple[SourceItem, AnnotatedChunk]: + def serialize_item(item: bytes | str | Image.Image) -> str | None: + if not previews and not isinstance(item, str): + return None + + if isinstance(item, Image.Image): + buffer = io.BytesIO() + format = item.format or "PNG" + item.save(buffer, format=format) + mime_type = f"image/{format.lower()}" + return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}" + elif isinstance(item, bytes): + return base64.b64encode(item).decode("utf-8") + elif isinstance(item, str): + return item + else: + raise ValueError(f"Unsupported item type: {type(item)}") + + metadata = search_result.payload or {} + metadata = { + k: v + for k, v in metadata.items() + if k not in ["content", "filename", "size", "content_type", "tags"] + } + return chunk.source, AnnotatedChunk( + id=str(chunk.id), + score=search_result.score, + metadata=metadata, + preview=serialize_item(chunk.data[0]) if chunk.data else None, + ) + + +def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]: + items = defaultdict(list) + for source, chunk in chunks: + items[source].append(chunk) + + return [ + SearchResult( + id=source.id, + size=source.size, + mime_type=source.mime_type, + filename=source.filename + and source.filename.replace( + str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files" + ), + content=source.content, + chunks=sorted(chunks, key=lambda x: x.score, reverse=True), + ) + for source, chunks in items.items() + ] + + +def query_chunks( + client: qdrant_client.QdrantClient, + upload_data: list[tuple[str, list[extract.Page]]], + allowed_modalities: set[str], + embedder: Callable, + min_score: float = 0.0, + limit: int = 10, +) -> dict[str, list[qdrant_models.ScoredPoint]]: + if not upload_data: + return {} + + chunks = [ + chunk + for content_type, pages in upload_data + if 
+
+
+def query_chunks(
+    client: qdrant_client.QdrantClient,
+    upload_data: list[tuple[str, list[extract.Page]]],
+    allowed_modalities: set[str],
+    embedder: Callable,
+    min_score: float = 0.0,
+    limit: int = 10,
+) -> dict[str, list[qdrant_models.ScoredPoint]]:
+    if not upload_data:
+        return {}
+
+    chunks = [
+        chunk
+        for content_type, pages in upload_data
+        if get_modality(content_type) in allowed_modalities
+        for page in pages
+        for chunk in page["contents"]
+    ]
+
+    if not chunks:
+        return {}
+
+    vector = embedder(chunks, input_type="query")[0]
+
+    return {
+        collection: [
+            r
+            for r in qdrant.search_vectors(
+                client=client,
+                collection_name=collection,
+                query_vector=vector,
+                limit=limit,
+            )
+            if r.score >= min_score
+        ]
+        for collection in embedding.DEFAULT_COLLECTIONS
+    }
+
+
+async def input_type(item: str | UploadFile) -> tuple[str, list[extract.Page]]:
+    if not item:
+        return "text/plain", []
+
+    if isinstance(item, str):
+        return "text/plain", extract.extract_text(item)
+
+    content_type = item.content_type or "application/octet-stream"
+    return content_type, extract.extract_content(content_type, await item.read())
+
+
+@app.post("/search", response_model=list[SearchResult])
+async def search(
+    query: Optional[str] = Form(None),
+    previews: Optional[bool] = Form(False),
+    modalities: Annotated[list[str], Query()] = [],
+    files: list[UploadFile] = File([]),
+    limit: int = Query(10, ge=1, le=100),
+    min_text_score: float = Query(0.5, ge=0.0, le=1.0),
+    min_multimodal_score: float = Query(0.3, ge=0.0, le=1.0),
 ):
-    """List source items, optionally filtered by tag"""
-    query = db.query(SourceItem)
-
-    if tag:
-        query = query.filter(SourceItem.tags.contains([tag]))
-
-    return query.limit(limit).all()
+    """
+    Search across the knowledge base using a text query and optional files.
+
+    Parameters:
+    - query: Optional text search query
+    - previews: Whether to include base64-encoded previews of matched chunks
+    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
+    - files: Optional files to include in the search context
+    - limit: Maximum number of results per modality
+
+    Returns:
+    - List of search results sorted by score
+    """
+    upload_data = [await input_type(item) for item in [query, *files]]
+    logger.info(
+        f"Querying chunks for {modalities}, query: {query}, previews: {previews}"
+    )
+
+    client = qdrant.get_qdrant_client()
+    allowed_modalities = set(modalities or embedding.DEFAULT_COLLECTIONS.keys())
+    text_results = query_chunks(
+        client,
+        upload_data,
+        allowed_modalities & embedding.TEXT_COLLECTIONS,
+        embedding.embed_text,
+        min_score=min_text_score,
+        limit=limit,
+    )
+    multimodal_results = query_chunks(
+        client,
+        upload_data,
+        allowed_modalities,
+        embedding.embed_mixed,
+        min_score=min_multimodal_score,
+        limit=limit,
+    )
+    search_results = {
+        k: text_results.get(k, []) + multimodal_results.get(k, [])
+        for k in allowed_modalities
+    }
+
+    found_chunks = {
+        str(r.id): r for results in search_results.values() for r in results
+    }
+    with make_session() as db:
+        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
+
+        results = group_chunks(
+            [
+                annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
+                for chunk in chunks
+            ]
+        )
+    return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
 
 
-@app.get("/sources/{source_id}")
-def get_source(source_id: int, db: Session = Depends(get_db)):
-    """Get a specific source by ID"""
-    source = db.query(SourceItem).filter(SourceItem.id == source_id).first()
-    if not source:
-        raise HTTPException(status_code=404, detail="Source not found")
-    return source
\ No newline at end of file
+@app.get("/files/{path:path}")
+def get_file_by_path(path: str):
+    """
+    Fetch a file by its path.
+
+    Parameters:
+    - path: Path of the file to fetch (relative to FILE_STORAGE_DIR)
+
+    Returns:
+    - The file as a download
+    """
+    # Sanitize the path to prevent directory traversal
+    sanitized_path = path.lstrip("/")
+    if ".." in sanitized_path:
+        raise HTTPException(status_code=400, detail="Invalid path")
+
+    file_path = pathlib.Path(settings.FILE_STORAGE_DIR) / sanitized_path
+
+    # Check if the file exists on disk
+    if not file_path.exists() or not file_path.is_file():
+        raise HTTPException(status_code=404, detail=f"File not found at path: {path}")
+
+    return FileResponse(path=file_path, filename=file_path.name)
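
Since /search accepts multipart form fields alongside query parameters, a client call looks roughly like this sketch (the base URL and field values are illustrative, assuming a locally running instance):

    import httpx

    # Text-only query; uploads could be added via httpx's `files=` argument.
    response = httpx.post(
        "http://localhost:8000/search",
        params={"limit": 5, "modalities": ["text"]},
        data={"query": "quarterly report", "previews": True},
    )
    for result in response.json():
        print(result["id"], result["mime_type"], len(result["chunks"]))
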
path.lstrip("/") + if ".." in sanitized_path: + raise HTTPException(status_code=400, detail="Invalid path") + + file_path = pathlib.Path(settings.FILE_STORAGE_DIR) / sanitized_path + + # Check if the file exists on disk + if not file_path.exists() or not file_path.is_file(): + raise HTTPException(status_code=404, detail=f"File not found at path: {path}") + + return FileResponse(path=file_path, filename=file_path.name) diff --git a/src/memory/common/db/models.py b/src/memory/common/db/models.py index 12cbe2d..b73df0d 100644 --- a/src/memory/common/db/models.py +++ b/src/memory/common/db/models.py @@ -113,7 +113,7 @@ class Chunk(Base): if self.file_path is None: return [self.content] - path = pathlib.Path(self.file_path) + path = pathlib.Path(self.file_path.replace("/app/", "")) if self.file_path.endswith("*"): files = list(path.parent.glob(path.name)) else: @@ -122,7 +122,8 @@ class Chunk(Base): items = [] for file_path in files: if file_path.suffix == ".png": - items.append(Image.open(file_path)) + if file_path.exists(): + items.append(Image.open(file_path)) elif file_path.suffix == ".bin": items.append(file_path.read_bytes()) else: diff --git a/src/memory/common/embedding.py b/src/memory/common/embedding.py index 97f5fe7..e575f0a 100644 --- a/src/memory/common/embedding.py +++ b/src/memory/common/embedding.py @@ -1,7 +1,7 @@ import logging import pathlib import uuid -from typing import Any, Iterable, Literal, TypedDict +from typing import Any, Iterable, Literal, NotRequired, TypedDict import voyageai from PIL import Image @@ -27,20 +27,63 @@ Embedding = tuple[str, Vector, dict[str, Any]] class Collection(TypedDict): dimension: int distance: DistanceType - on_disk: bool - shards: int + model: str + on_disk: NotRequired[bool] + shards: NotRequired[int] DEFAULT_COLLECTIONS: dict[str, Collection] = { - "mail": {"dimension": 1024, "distance": "Cosine"}, - "chat": {"dimension": 1024, "distance": "Cosine"}, - "git": {"dimension": 1024, "distance": "Cosine"}, - "book": {"dimension": 1024, "distance": "Cosine"}, - "blog": {"dimension": 1024, "distance": "Cosine"}, - "text": {"dimension": 1024, "distance": "Cosine"}, + "mail": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, + "chat": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, + "git": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, + "book": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, + "blog": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, + "text": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, # Multimodal - "photo": {"dimension": 1024, "distance": "Cosine"}, - "doc": {"dimension": 1024, "distance": "Cosine"}, + "photo": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.MIXED_EMBEDDING_MODEL, + }, + "doc": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.MIXED_EMBEDDING_MODEL, + }, +} +TEXT_COLLECTIONS = { + coll + for coll, params in DEFAULT_COLLECTIONS.items() + if params["model"] == settings.TEXT_EMBEDDING_MODEL +} +MULTIMODAL_COLLECTIONS = { + coll + for coll, params in DEFAULT_COLLECTIONS.items() + if params["model"] == settings.MIXED_EMBEDDING_MODEL } TYPES = { @@ -69,22 +112,27 @@ def get_modality(mime_type: str) -> str: def embed_chunks( - chunks: list[extract.MulitmodalChunk], model: str = settings.TEXT_EMBEDDING_MODEL + 
diff --git a/src/memory/common/embedding.py b/src/memory/common/embedding.py
index 97f5fe7..e575f0a 100644
--- a/src/memory/common/embedding.py
+++ b/src/memory/common/embedding.py
@@ -1,7 +1,7 @@
 import logging
 import pathlib
 import uuid
-from typing import Any, Iterable, Literal, TypedDict
+from typing import Any, Iterable, Literal, NotRequired, TypedDict
 
 import voyageai
 from PIL import Image
@@ -27,20 +27,63 @@ Embedding = tuple[str, Vector, dict[str, Any]]
 class Collection(TypedDict):
     dimension: int
     distance: DistanceType
-    on_disk: bool
-    shards: int
+    model: str
+    on_disk: NotRequired[bool]
+    shards: NotRequired[int]
 
 
 DEFAULT_COLLECTIONS: dict[str, Collection] = {
-    "mail": {"dimension": 1024, "distance": "Cosine"},
-    "chat": {"dimension": 1024, "distance": "Cosine"},
-    "git": {"dimension": 1024, "distance": "Cosine"},
-    "book": {"dimension": 1024, "distance": "Cosine"},
-    "blog": {"dimension": 1024, "distance": "Cosine"},
-    "text": {"dimension": 1024, "distance": "Cosine"},
+    "mail": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "chat": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "git": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "book": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "blog": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "text": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
     # Multimodal
-    "photo": {"dimension": 1024, "distance": "Cosine"},
-    "doc": {"dimension": 1024, "distance": "Cosine"},
+    "photo": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.MIXED_EMBEDDING_MODEL,
+    },
+    "doc": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.MIXED_EMBEDDING_MODEL,
+    },
 }
+TEXT_COLLECTIONS = {
+    coll
+    for coll, params in DEFAULT_COLLECTIONS.items()
+    if params["model"] == settings.TEXT_EMBEDDING_MODEL
+}
+MULTIMODAL_COLLECTIONS = {
+    coll
+    for coll, params in DEFAULT_COLLECTIONS.items()
+    if params["model"] == settings.MIXED_EMBEDDING_MODEL
+}
 
 TYPES = {
@@ -69,22 +112,27 @@ def get_modality(mime_type: str) -> str:
 
 
 def embed_chunks(
-    chunks: list[extract.MulitmodalChunk], model: str = settings.TEXT_EMBEDDING_MODEL
+    chunks: list[extract.MulitmodalChunk],
+    model: str = settings.TEXT_EMBEDDING_MODEL,
+    input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
     vo = voyageai.Client()
     if model == settings.MIXED_EMBEDDING_MODEL:
         return vo.multimodal_embed(
-            chunks, model=model, input_type="document"
+            chunks, model=model, input_type=input_type
         ).embeddings
-    return vo.embed(chunks, model=model, input_type="document").embeddings
+    return vo.embed(chunks, model=model, input_type=input_type).embeddings
 
 
 def embed_text(
-    texts: list[str], model: str = settings.TEXT_EMBEDDING_MODEL
+    texts: list[str],
+    model: str = settings.TEXT_EMBEDDING_MODEL,
+    input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
     chunks = [
         c
         for text in texts
+        if isinstance(text, str)
         for c in chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS)
         if c.strip()
     ]
@@ -92,7 +140,7 @@ def embed_text(
         return []
 
     try:
-        return embed_chunks(chunks, model)
+        return embed_chunks(chunks, model, input_type)
     except voyageai.error.InvalidRequestError as e:
         logger.error(f"Error embedding text: {e}")
         logger.debug(f"Text: {texts}")
@@ -106,7 +154,7 @@ def embed_file(
 
 
 def embed_mixed(
-    items: list[extract.MulitmodalChunk], model: str = settings.MIXED_EMBEDDING_MODEL
+    items: list[extract.MulitmodalChunk],
+    model: str = settings.MIXED_EMBEDDING_MODEL,
+    input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
     def to_chunks(item: extract.MulitmodalChunk) -> Iterable[str]:
         if isinstance(item, str):
@@ -116,7 +166,7 @@ def embed_mixed(
             return [item]
 
     chunks = [c for item in items for c in to_chunks(item)]
-    return embed_chunks([chunks], model)
+    return embed_chunks([chunks], model, input_type)
 
 
 def embed_page(page: dict[str, Any]) -> list[Vector]:
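
The input_type plumbing matters because Voyage embeds queries and documents asymmetrically. A usage sketch (the texts are made up; the models come from settings):

    from memory.common import embedding

    # At ingest time, chunks are embedded as documents (the default).
    doc_vectors = embedding.embed_text(["stored page text"])

    # At search time, the query side is embedded with input_type="query",
    # which is what query_chunks passes to its embedder argument.
    query_vector = embedding.embed_text(
        ["what was in the report?"], input_type="query"
    )[0]
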
diff --git a/src/memory/common/extract.py b/src/memory/common/extract.py
index 514244d..5e4cf1c 100644
--- a/src/memory/common/extract.py
+++ b/src/memory/common/extract.py
@@ -7,6 +7,9 @@ import pymupdf  # PyMuPDF
 from PIL import Image
 from typing import Any, TypedDict, Generator, Sequence
 
+import logging
+
+logger = logging.getLogger(__name__)
 
 MulitmodalChunk = Image.Image | str
@@ -57,15 +60,32 @@ def docx_to_pdf(
     if output_path is None:
         output_path = docx_path.with_suffix(".pdf")
 
-    pypandoc.convert_file(str(docx_path), "pdf", outputfile=str(output_path))
-
-    return output_path
+    # Try xelatex first, since it has much better Unicode support
+    try:
+        logger.info(f"Converting {docx_path} to PDF using xelatex")
+        pypandoc.convert_file(
+            str(docx_path),
+            format="docx",
+            to="pdf",
+            outputfile=str(output_path),
+            extra_args=[
+                "--pdf-engine=xelatex",
+                "--variable=geometry:margin=1in",
+                "--lua-filter=/app/unnest-table.lua",
+            ],
+        )
+        logger.info(f"Successfully converted {docx_path} to PDF")
+        return output_path
+    except Exception as e:
+        logger.error(f"Error converting document to PDF: {e}")
+        raise
 
 
-def extract_docx(docx_path: pathlib.Path) -> list[Page]:
+def extract_docx(docx_path: pathlib.Path | bytes | str) -> list[Page]:
     """Extract content from DOCX by converting to PDF first, then processing"""
     with as_file(docx_path) as file_path:
         pdf_path = docx_to_pdf(file_path)
+        logger.info(f"Converted {file_path} to PDF")
         return doc_to_images(pdf_path)
@@ -89,13 +109,17 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
 
 
 def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]:
+    logger.info(f"Extracting content from {mime_type}")
    if mime_type == "application/pdf":
         return doc_to_images(content)
-    if isinstance(content, pathlib.Path) and mime_type in [
+    if mime_type in [
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
         "application/msword",
     ]:
-        return extract_docx(content)
+        logger.info(f"Extracting content from {content}")
+        pages = extract_docx(content)
+        logger.info(f"Extracted {len(pages)} pages from {content}")
+        return pages
     if mime_type.startswith("text/"):
         return extract_text(content)
     if mime_type.startswith("image/"):
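
With the dispatch above, a caller only needs a MIME type and the raw payload. A sketch of the DOCX path end to end (the file path is hypothetical):

    import pathlib
    from memory.common import extract

    # DOCX bytes are routed through docx_to_pdf and doc_to_images.
    docx_bytes = pathlib.Path("sample.docx").read_bytes()  # hypothetical file
    pages = extract.extract_content(
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        docx_bytes,
    )
    print(f"{len(pages)} pages extracted")
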
diff --git a/src/memory/common/parsers/email.py b/src/memory/common/parsers/email.py
index 285c779..99313f6 100644
--- a/src/memory/common/parsers/email.py
+++ b/src/memory/common/parsers/email.py
@@ -25,6 +25,7 @@ class EmailMessage(TypedDict):
     sent_at: datetime | None
     body: str
     attachments: list[Attachment]
+    hash: bytes
 
 
 RawEmailResponse = tuple[Literal["OK", "ERROR"], bytes]
@@ -33,10 +34,10 @@
 def extract_recipients(msg: email.message.Message) -> list[str]:
     """
     Extract email recipients from message headers.
-    
+
     Args:
         msg: Email message object
-    
+
     Returns:
         List of recipient email addresses
     """
@@ -52,10 +53,10 @@
 def extract_date(msg: email.message.Message) -> datetime | None:
     """
     Parse date from email header.
-    
+
     Args:
         msg: Email message object
-    
+
     Returns:
         Parsed datetime or None if parsing failed
     """
@@ -70,18 +71,18 @@
 def extract_body(msg: email.message.Message) -> str:
     """
     Extract plain text body from email message.
-    
+
     Args:
         msg: Email message object
-    
+
     Returns:
         Plain text body content
     """
     body = ""
-    
+
     if not msg.is_multipart():
         try:
-            return msg.get_payload(decode=True).decode(errors='replace')
+            return msg.get_payload(decode=True).decode(errors="replace")
         except Exception as e:
             logger.error(f"Error decoding message body: {str(e)}")
             return ""
@@ -89,10 +90,10 @@ def extract_body(msg: email.message.Message) -> str:
     for part in msg.walk():
         content_type = part.get_content_type()
         content_disposition = str(part.get("Content-Disposition", ""))
-        
+
         if content_type == "text/plain" and "attachment" not in content_disposition:
             try:
-                body += part.get_payload(decode=True).decode(errors='replace') + "\n"
+                body += part.get_payload(decode=True).decode(errors="replace") + "\n"
             except Exception as e:
                 logger.error(f"Error decoding message part: {str(e)}")
     return body
@@ -101,10 +102,10 @@
 def extract_attachments(msg: email.message.Message) -> list[Attachment]:
     """
     Extract attachment metadata and content from email.
-    
+
     Args:
         msg: Email message object
-    
+
     Returns:
         List of attachment dictionaries with metadata and content
     """
@@ -120,14 +121,18 @@ def extract_attachments(msg: email.message.Message) -> list[Attachment]:
         if filename := part.get_filename():
             try:
                 content = part.get_payload(decode=True)
-                attachments.append({
-                    "filename": filename,
-                    "content_type": part.get_content_type(),
-                    "size": len(content),
-                    "content": content
-                })
+                attachments.append(
+                    {
+                        "filename": filename,
+                        "content_type": part.get_content_type(),
+                        "size": len(content),
+                        "content": content,
+                    }
+                )
             except Exception as e:
-                logger.error(f"Error extracting attachment content for {filename}: {str(e)}")
+                logger.error(
+                    f"Error extracting attachment content for {filename}: {str(e)}"
+                )
 
     return attachments
@@ -135,13 +140,13 @@ def extract_attachments(msg: email.message.Message) -> list[Attachment]:
 def compute_message_hash(msg_id: str, subject: str, sender: str, body: str) -> bytes:
     """
     Compute a SHA-256 hash of message content.
-    
+
     Args:
         msg_id: Message ID
         subject: Email subject
         sender: Sender email
         body: Message body
-    
+
     Returns:
         SHA-256 hash as bytes
     """
@@ -152,10 +157,10 @@ def compute_message_hash(msg_id: str, subject: str, sender: str, body: str) -> b
 def parse_email_message(raw_email: str, message_id: str) -> EmailMessage:
     """
     Parse raw email into structured data.
-    
+
     Args:
         raw_email: Raw email content as string
-    
+
     Returns:
         Dict with parsed email data
     """
@@ -164,7 +169,7 @@ def parse_email_message(raw_email: str, message_id: str) -> EmailMessage:
     subject = msg.get("Subject", "")
     from_ = msg.get("From", "")
     body = extract_body(msg)
-    
+
     return EmailMessage(
         message_id=message_id,
         subject=subject,
@@ -173,5 +178,5 @@ def parse_email_message(raw_email: str, message_id: str) -> EmailMessage:
         sent_at=extract_date(msg),
         body=body,
         attachments=extract_attachments(msg),
-        hash=compute_message_hash(message_id, subject, from_, body)
+        hash=compute_message_hash(message_id, subject, from_, body),
     )
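
The new hash field is what deduplication keys on downstream (see does_message_exist in the worker changes below). A quick sketch of the parser output (the raw message is made up):

    from memory.common.parsers.email import parse_email_message

    raw = "Subject: hi\nFrom: a@example.com\nTo: b@example.com\n\nhello"
    parsed = parse_email_message(raw, message_id="<msg-1@example.com>")
    print(parsed["subject"])
    print(parsed["hash"].hex())  # SHA-256 digest used for dedup
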
diff --git a/src/memory/common/qdrant.py b/src/memory/common/qdrant.py
index 6adcb2e..282f5f1 100644
--- a/src/memory/common/qdrant.py
+++ b/src/memory/common/qdrant.py
@@ -5,15 +5,22 @@
 import qdrant_client
 from qdrant_client.http import models as qdrant_models
 from qdrant_client.http.exceptions import UnexpectedResponse
 from memory.common import settings
-from memory.common.embedding import Collection, DEFAULT_COLLECTIONS, DistanceType, Vector
+from memory.common.embedding import (
+    Collection,
+    DEFAULT_COLLECTIONS,
+    DistanceType,
+    Vector,
+)
 
 logger = logging.getLogger(__name__)
 
 
 def get_qdrant_client() -> qdrant_client.QdrantClient:
     """Create and return a Qdrant client using environment configuration."""
-    logger.info(f"Connecting to Qdrant at {settings.QDRANT_HOST}:{settings.QDRANT_PORT}")
-    
+    logger.info(
+        f"Connecting to Qdrant at {settings.QDRANT_HOST}:{settings.QDRANT_PORT}"
+    )
+
     return qdrant_client.QdrantClient(
         host=settings.QDRANT_HOST,
         port=settings.QDRANT_PORT,
@@ -34,7 +41,7 @@ def ensure_collection_exists(
 ) -> bool:
     """
     Ensure a collection exists with the specified parameters.
-    
+
     Args:
         client: Qdrant client
         collection_name: Name of the collection
@@ -42,7 +49,7 @@ def ensure_collection_exists(
         distance: Distance metric (Cosine, Dot, Euclidean)
         on_disk: Whether to store vectors on disk
         shards: Number of shards for the collection
-    
+
     Returns:
         True if the collection was created, False if it already existed
     """
@@ -61,21 +68,23 @@ def ensure_collection_exists(
         on_disk_payload=on_disk,
         shard_number=shards,
     )
-    
+
     # Create common payload indexes
     client.create_payload_index(
         collection_name=collection_name,
         field_name="tags",
         field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
     )
-    
+
     return True
 
 
-def initialize_collections(client: qdrant_client.QdrantClient, collections: dict[str, Collection] = None) -> None:
+def initialize_collections(
+    client: qdrant_client.QdrantClient, collections: dict[str, Collection] = None
+) -> None:
     """
     Initialize all required collections in Qdrant.
-    
+
     Args:
         client: Qdrant client
         collections: Dictionary mapping collection names to their parameters.
@@ -83,7 +92,7 @@ def initialize_collections(client: qdrant_client.QdrantClient, collections: dict
     """
     if collections is None:
         collections = DEFAULT_COLLECTIONS
-    
+
     logger.info(f"Initializing collections:")
     for name, params in collections.items():
         logger.info(f"  - {name}")
@@ -99,7 +108,7 @@ def initialize_collections(client: qdrant_client.QdrantClient, collections: dict
 
 def setup_qdrant() -> qdrant_client.QdrantClient:
     """Get a Qdrant client and initialize collections.
-    
+
     Returns:
         Configured Qdrant client
     """
@@ -109,14 +118,14 @@ def setup_qdrant() -> qdrant_client.QdrantClient:
 
 
 def upsert_vectors(
-    client: qdrant_client.QdrantClient, 
-    collection_name: str, 
+    client: qdrant_client.QdrantClient,
+    collection_name: str,
     ids: list[str],
     vectors: list[Vector],
     payloads: list[dict[str, Any]] = None,
 ) -> None:
     """Upsert vectors into a collection.
-    
+
     Args:
         client: Qdrant client
         collection_name: Name of the collection
@@ -126,7 +135,7 @@
     """
     if payloads is None:
         payloads = [{} for _ in ids]
-    
+
     points = [
         qdrant_models.PointStruct(
             id=id_str,
@@ -135,12 +144,12 @@
         )
         for id_str, vector, payload in zip(ids, vectors, payloads)
     ]
-    
+
     client.upsert(
         collection_name=collection_name,
         points=points,
     )
-    
+
     logger.debug(f"Upserted {len(ids)} vectors into {collection_name}")
 
 
@@ -152,21 +161,21 @@ def search_vectors(
     limit: int = 10,
 ) -> list[qdrant_models.ScoredPoint]:
     """Search for similar vectors in a collection.
-    
+
     Args:
         client: Qdrant client
         collection_name: Name of the collection
         query_vector: Query vector
         filter_params: Filter parameters to apply (e.g., {"tags": {"value": "work"}})
         limit: Maximum number of results to return
-    
+
     Returns:
         List of scored points
     """
     filter_obj = None
     if filter_params:
         filter_obj = qdrant_models.Filter(**filter_params)
-    
+
     return client.search(
         collection_name=collection_name,
         query_vector=query_vector,
@@ -182,7 +191,7 @@ def delete_vectors(
 ) -> None:
     """
     Delete vectors from a collection.
-    
+
     Args:
         client: Qdrant client
         collection_name: Name of the collection
@@ -194,18 +203,20 @@ def delete_vectors(
             points=ids,
         ),
     )
-    
+
     logger.debug(f"Deleted {len(ids)} vectors from {collection_name}")
 
 
-def get_collection_info(client: qdrant_client.QdrantClient, collection_name: str) -> dict:
+def get_collection_info(
+    client: qdrant_client.QdrantClient, collection_name: str
+) -> dict:
     """
     Get information about a collection.
-    
+
     Args:
         client: Qdrant client
         collection_name: Name of the collection
-    
+
     Returns:
         Dictionary with collection information
     """
diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py
index 0fa634c..b1d3715 100644
--- a/src/memory/common/settings.py
+++ b/src/memory/common/settings.py
@@ -52,7 +52,6 @@ QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))
 
 # Worker settings
 EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 3600))
-EMAIL_SYNC_INTERVAL = 60
 
 
 # Embedding settings
diff --git a/src/memory/mcp/server.py b/src/memory/mcp/server.py
new file mode 100644
index 0000000..001dc9d
--- /dev/null
+++ b/src/memory/mcp/server.py
@@ -0,0 +1,59 @@
+import argparse
+import logging
+from typing import Any
+from fastapi import UploadFile
+import httpx
+from mcp.server.fastmcp import FastMCP
+
+SERVER = "http://localhost:8000"
+
+
+logger = logging.getLogger(__name__)
+mcp = FastMCP("memory")
+
+
+async def make_request(
+    path: str,
+    method: str,
+    data: dict | None = None,
+    json: dict | None = None,
+    files: list[UploadFile] | None = None,
+) -> httpx.Response:
+    async with httpx.AsyncClient() as client:
+        return await client.request(
+            method, f"{SERVER}/{path}", data=data, json=json, files=files
+        )
+
+
+async def post_data(path: str, data: dict | None = None) -> httpx.Response:
+    return await make_request(path, "POST", data=data)
+
+
+@mcp.tool()
+async def search(
+    query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
+) -> list[dict[str, Any]]:
+    logger.info(f"Searching for {query}")
+    resp = await post_data(
+        "search",
+        {
+            "query": query,
+            "previews": previews,
+            "modalities": modalities,
+            "limit": limit,
+        },
+    )
+    return resp.json()
+
+
+if __name__ == "__main__":
+    # Initialize and run the server
+    args = argparse.ArgumentParser()
+    args.add_argument("--server", type=str, default=SERVER)
+    args.add_argument("--transport", type=str, default="stdio")
+    args = args.parse_args()
+
+    SERVER = args.server
+
+    mcp.run(transport=args.transport)
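
For local testing, the MCP server is started against a running API instance, along the lines of (flag values illustrative):

    # python -m memory.mcp.server --server http://localhost:8000 --transport stdio

And a client-side smoke test of the same /search round-trip the search tool performs:

    import asyncio
    import httpx

    async def smoke_test() -> None:
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                "http://localhost:8000/search", data={"query": "test"}
            )
            print(resp.status_code, len(resp.json()))

    asyncio.run(smoke_test())
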
diff --git a/src/memory/workers/email.py b/src/memory/workers/email.py
index efbc615..adcc133 100644
--- a/src/memory/workers/email.py
+++ b/src/memory/workers/email.py
@@ -124,13 +124,14 @@
         folder=folder,
     )
 
-    if parsed_email["attachments"]:
-        mail_message.attachments = process_attachments(
-            parsed_email["attachments"], mail_message
-        )
-
-    db_session.add(mail_message)
+    if parsed_email["attachments"]:
+        attachments = process_attachments(parsed_email["attachments"], mail_message)
+        db_session.add_all(attachments)
+        mail_message.attachments = attachments
+
+    db_session.add(mail_message)
 
     return mail_message
@@ -171,13 +172,9 @@
     account = db.query(EmailAccount).get(account_id)
     if not account:
         logger.error(f"Account {account_id} not found")
-        return None
+        return False
 
     parsed_email = parse_email_message(raw_email, message_id)
-
-    # Use server-provided message ID if missing
-    if not parsed_email["message_id"]:
-        parsed_email["message_id"] = f"generated-{message_id}"
 
     return does_message_exist(db, parsed_email["message_id"], parsed_email["hash"])
diff --git a/src/memory/workers/tasks/email.py b/src/memory/workers/tasks/email.py
index cbb8a85..37a84b4 100644
--- a/src/memory/workers/tasks/email.py
+++ b/src/memory/workers/tasks/email.py
@@ -58,6 +58,7 @@ def process_message(
             db, account.tags, folder, raw_email, message_id
         )
 
+        db.flush()
         vectorize_email(mail_message)
 
         db.commit()
@@ -65,12 +66,16 @@
         logger.info("Chunks:")
         for chunk in mail_message.chunks:
             logger.info(f"  - {chunk.id}")
+        for attachment in mail_message.attachments:
+            logger.info(f"  - Attachment {attachment.id}")
+            for chunk in attachment.chunks:
+                logger.info(f"    - {chunk.id}")
 
         return mail_message.id
 
 
 @app.task(name=SYNC_ACCOUNT)
-def sync_account(account_id: int) -> dict:
+def sync_account(account_id: int, since_date: str | None = None) -> dict:
     """
     Synchronize emails from a specific account.
 
@@ -80,7 +85,7 @@ def sync_account(account_id: int) -> dict:
     Returns:
         dict with stats about the sync operation
     """
-    logger.info(f"Syncing account {account_id}")
+    logger.info(f"Syncing account {account_id} since {since_date}")
 
     with make_session() as db:
         account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
@@ -88,8 +93,11 @@
             logger.warning(f"Account {account_id} not found or inactive")
             return {"error": "Account not found or inactive"}
 
-        folders_to_process = account.folders or ["INBOX"]
-        since_date = account.last_sync_at or datetime(1970, 1, 1)
+        folders_to_process: list[str] = account.folders or ["INBOX"]
+        if since_date:
+            cutoff_date = datetime.fromisoformat(since_date)
+        else:
+            cutoff_date = account.last_sync_at or datetime(1970, 1, 1)
 
         messages_found = 0
         new_messages = 0
@@ -106,7 +114,7 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
         with imap_connection(account) as conn:
             for folder in folders_to_process:
                 folder_stats = process_folder(
-                    conn, folder, account, since_date, process_message_wrapper
+                    conn, folder, account, cutoff_date, process_message_wrapper
                 )
 
                 messages_found += folder_stats["messages_found"]
@@ -121,6 +129,7 @@
 
         return {
             "account": account.email_address,
+            "since_date": cutoff_date.isoformat(),
             "folders_processed": len(folders_to_process),
             "messages_found": messages_found,
             "new_messages": new_messages,
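
The new since_date parameter makes targeted backfills possible. A dispatch sketch (the account id and date are made up):

    from memory.workers.tasks.email import sync_account

    # Via Celery; since_date must be ISO-8601, as it is parsed with
    # datetime.fromisoformat on the worker side.
    sync_account.delay(1, since_date="2024-01-01T00:00:00")

    # Or synchronously, e.g. from a debugging shell:
    stats = sync_account(1, since_date="2024-01-01T00:00:00")
    print(stats["new_messages"], "new messages since", stats["since_date"])
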