initial server

Daniel O'Connell 2025-05-04 00:31:13 +02:00
parent fe15442a6d
commit cf98d38bb8
14 changed files with 518 additions and 128 deletions

View File

@@ -10,21 +10,28 @@ COPY src/ ./src/
 # Install dependencies
 RUN apt-get update && apt-get install -y \
     libpq-dev gcc pandoc \
-    texlive-full texlive-fonts-recommended texlive-plain-generic \
+    texlive-xetex texlive-fonts-recommended texlive-plain-generic \
+    texlive-lang-greek texlive-lang-cyrillic texlive-lang-european \
+    texlive-luatex texlive-latex-extra texlive-latex-recommended \
+    texlive-science texlive-fonts-extra \
+    fontconfig \
     # For optional LibreOffice support (uncomment if needed)
     # libreoffice-writer \
-    && \
-    pip install -e ".[workers]" && \
-    apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
+    && apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
+
+RUN pip install -e ".[workers]"
 
 # Create and copy entrypoint script
 COPY docker/workers/entry.sh ./entry.sh
+COPY docker/workers/unnest-table.lua ./unnest-table.lua
 RUN chmod +x entry.sh
 RUN mkdir -p /app/memory_files
 
 # Create user and set permissions
-RUN useradd -m kb && chown -R kb /app
+RUN useradd -m kb
+RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
+    chown -R kb:kb /var/cache/fontconfig /home/kb/.cache/fontconfig /app
 USER kb
 
 # Default queues to process

View File

@@ -0,0 +1,28 @@
+local function rows(tbl)
+  local r = pandoc.List()
+  r:extend(tbl.head.rows)
+  for _,b in ipairs(tbl.bodies) do r:extend(b.body) end
+  r:extend(tbl.foot.rows)
+  return r
+end
+
+function Table(t)
+  local newHead = pandoc.TableHead()
+  for i,row in ipairs(t.head.rows) do
+    for j,cell in ipairs(row.cells) do
+      local inner = cell.contents[1]
+      if inner and inner.t == 'Table' then
+        local ins = rows(inner)
+        local first = table.remove(ins,1)
+        row.cells[j] = first.cells[1]
+        newHead.rows:insert(row)
+        newHead.rows:extend(ins)
+        goto continue
+      end
+    end
+    newHead.rows:insert(row)
+    ::continue::
+  end
+  t.head = newHead
+  return t
+end
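This filter hoists the rows of any table nested inside a header cell up into the outer table's head, presumably so the LaTeX/PDF conversion does not choke on nested tables. It is applied through pandoc's --lua-filter option; a minimal sketch of the call, mirroring the arguments used by docx_to_pdf further down (file names here are placeholders):

import pypandoc

# Sketch only: convert a DOCX to PDF with the table-unnesting filter applied.
# "input.docx" and "output.pdf" are placeholder paths.
pypandoc.convert_file(
    "input.docx",
    format="docx",
    to="pdf",
    outputfile="output.pdf",
    extra_args=[
        "--pdf-engine=xelatex",
        "--lua-filter=unnest-table.lua",
    ],
)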

View File

@@ -1,4 +1,4 @@
 fastapi==0.112.2
 uvicorn==0.29.0
 python-jose==3.3.0
 python-multipart==0.0.9

requirements-mcp.txt Normal file
View File

@@ -0,0 +1,2 @@
+mcp==1.7.1
+httpx==0.25.1

View File

@@ -1,23 +1,50 @@
 """
 FastAPI application for the knowledge base.
 """
-from fastapi import FastAPI, Depends, HTTPException
-from sqlalchemy.orm import Session
-from memory.common.db import get_scoped_session
-from memory.common.db.models import SourceItem
+import base64
+import io
+from collections import defaultdict
+import pathlib
+from typing import Annotated, List, Optional, Callable
+
+from fastapi import FastAPI, File, UploadFile, Query, HTTPException, Form
+from fastapi.responses import FileResponse
+import qdrant_client
+from qdrant_client.http import models as qdrant_models
+from PIL import Image
+from pydantic import BaseModel
+
+from memory.common import embedding, qdrant, extract, settings
+from memory.common.embedding import get_modality
+from memory.common.db.connection import make_session
+from memory.common.db.models import Chunk, SourceItem
+
+import logging
+
+logger = logging.getLogger(__name__)
 
 app = FastAPI(title="Knowledge Base API")
 
-def get_db():
-    """Database session dependency"""
-    db = get_scoped_session()
-    try:
-        yield db
-    finally:
-        db.close()
+
+class AnnotatedChunk(BaseModel):
+    id: str
+    score: float
+    metadata: dict
+    preview: Optional[str | None] = None
+
+
+class SearchResponse(BaseModel):
+    collection: str
+    results: List[dict]
+
+
+class SearchResult(BaseModel):
+    id: int
+    size: int
+    mime_type: str
+    chunks: list[AnnotatedChunk]
+    content: Optional[str] = None
+    filename: Optional[str] = None
 
 
 @app.get("/health")
@@ -26,25 +53,195 @@ def health_check():
     return {"status": "healthy"}
 
 
-@app.get("/sources")
-def list_sources(
-    tag: str = None,
-    limit: int = 100,
-    db: Session = Depends(get_db)
-):
-    """List source items, optionally filtered by tag"""
-    query = db.query(SourceItem)
-    if tag:
-        query = query.filter(SourceItem.tags.contains([tag]))
-    return query.limit(limit).all()
+def annotated_chunk(
+    chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
+) -> tuple[SourceItem, AnnotatedChunk]:
+    def serialize_item(item: bytes | str | Image.Image) -> str | None:
+        if not previews and not isinstance(item, str):
+            return None
+        if isinstance(item, Image.Image):
+            buffer = io.BytesIO()
+            format = item.format or "PNG"
+            item.save(buffer, format=format)
+            mime_type = f"image/{format.lower()}"
+            return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
+        elif isinstance(item, bytes):
+            return base64.b64encode(item).decode("utf-8")
+        elif isinstance(item, str):
+            return item
+        else:
+            raise ValueError(f"Unsupported item type: {type(item)}")
+
+    metadata = search_result.payload or {}
+    metadata = {
+        k: v
+        for k, v in metadata.items()
+        if k not in ["content", "filename", "size", "content_type", "tags"]
+    }
+    return chunk.source, AnnotatedChunk(
+        id=str(chunk.id),
+        score=search_result.score,
+        metadata=metadata,
+        preview=serialize_item(chunk.data[0]) if chunk.data else None,
+    )
+
+
+def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]:
+    items = defaultdict(list)
+    for source, chunk in chunks:
+        items[source].append(chunk)
+
+    return [
+        SearchResult(
+            id=source.id,
+            size=source.size,
+            mime_type=source.mime_type,
+            filename=source.filename
+            and source.filename.replace(
+                str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
+            ),
+            content=source.content,
+            chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
+        )
+        for source, chunks in items.items()
+    ]
+
+
+def query_chunks(
+    client: qdrant_client.QdrantClient,
+    upload_data: list[tuple[str, list[extract.Page]]],
+    allowed_modalities: set[str],
+    embedder: Callable,
+    min_score: float = 0.0,
+    limit: int = 10,
+) -> dict[str, list[qdrant_models.ScoredPoint]]:
+    if not upload_data:
+        return {}
+
+    chunks = [
+        chunk
+        for content_type, pages in upload_data
+        if get_modality(content_type) in allowed_modalities
+        for page in pages
+        for chunk in page["contents"]
+    ]
+    if not chunks:
+        return {}
+
+    vector = embedder(chunks, input_type="query")[0]
+
+    return {
+        collection: [
+            r
+            for r in qdrant.search_vectors(
+                client=client,
+                collection_name=collection,
+                query_vector=vector,
+                limit=limit,
+            )
+            if r.score >= min_score
+        ]
+        for collection in embedding.DEFAULT_COLLECTIONS
+    }
+
+
+async def input_type(item: str | UploadFile) -> tuple[str, list[extract.Page]]:
+    if not item:
+        return "text/plain", []
+    if isinstance(item, str):
+        return "text/plain", extract.extract_text(item)
+    content_type = item.content_type or "application/octet-stream"
+    return content_type, extract.extract_content(content_type, await item.read())
+
+
+@app.post("/search", response_model=list[SearchResult])
+async def search(
+    query: Optional[str] = Form(None),
+    previews: Optional[bool] = Form(False),
+    modalities: Annotated[list[str], Query()] = [],
+    files: list[UploadFile] = File([]),
+    limit: int = Query(10, ge=1, le=100),
+    min_text_score: float = Query(0.5, ge=0.0, le=1.0),
+    min_multimodal_score: float = Query(0.3, ge=0.0, le=1.0),
+):
+    """
+    Search across knowledge base using text query and optional files.
+
+    Parameters:
+    - query: Optional text search query
+    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
+    - files: Optional files to include in the search context
+    - limit: Maximum number of results per modality
+
+    Returns:
+    - List of search results sorted by score
+    """
+    upload_data = [await input_type(item) for item in [query, *files]]
+    logger.error(
+        f"Querying chunks for {modalities}, query: {query}, previews: {previews}"
+    )
+
+    client = qdrant.get_qdrant_client()
+    allowed_modalities = set(modalities or embedding.DEFAULT_COLLECTIONS.keys())
+    text_results = query_chunks(
+        client,
+        upload_data,
+        allowed_modalities & embedding.TEXT_COLLECTIONS,
+        embedding.embed_text,
+        min_score=min_text_score,
+        limit=limit,
+    )
+    multimodal_results = query_chunks(
+        client,
+        upload_data,
+        allowed_modalities,
+        embedding.embed_mixed,
+        min_score=min_multimodal_score,
+        limit=limit,
+    )
+    search_results = {
+        k: text_results.get(k, []) + multimodal_results.get(k, [])
+        for k in allowed_modalities
+    }
+
+    found_chunks = {
+        str(r.id): r for results in search_results.values() for r in results
+    }
+    with make_session() as db:
+        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
+        results = group_chunks(
+            [
+                annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
+                for chunk in chunks
+            ]
+        )
+    return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
 
 
-@app.get("/sources/{source_id}")
-def get_source(source_id: int, db: Session = Depends(get_db)):
-    """Get a specific source by ID"""
-    source = db.query(SourceItem).filter(SourceItem.id == source_id).first()
-    if not source:
-        raise HTTPException(status_code=404, detail="Source not found")
-    return source
+@app.get("/files/{path:path}")
+def get_file_by_path(path: str):
+    """
+    Fetch a file by its path
+
+    Parameters:
+    - path: Path of the file to fetch (relative to FILE_STORAGE_DIR)
+
+    Returns:
+    - The file as a download
+    """
+    # Sanitize the path to prevent directory traversal
+    sanitized_path = path.lstrip("/")
+    if ".." in sanitized_path:
+        raise HTTPException(status_code=400, detail="Invalid path")
+
+    file_path = pathlib.Path(settings.FILE_STORAGE_DIR) / sanitized_path
+
+    # Check if the file exists on disk
+    if not file_path.exists() or not file_path.is_file():
+        raise HTTPException(status_code=404, detail=f"File not found at path: {path}")
+
+    return FileResponse(path=file_path, filename=file_path.name)
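For reference, a minimal client-side sketch of the new /search endpoint. Assumptions: the API is reachable at localhost:8000, "report.docx" is a placeholder file, and modalities/limit are sent as query parameters to match the Query(...) declarations above while query/previews go as form fields.

import httpx

# Sketch: text query plus one uploaded file, restricted to two modalities.
with open("report.docx", "rb") as f:
    resp = httpx.post(
        "http://localhost:8000/search",
        params={"modalities": ["text", "doc"], "limit": 5},
        data={"query": "quarterly results", "previews": True},
        files={"files": ("report.docx", f)},
    )

for result in resp.json():
    best = max(chunk["score"] for chunk in result["chunks"])
    print(result["id"], result["mime_type"], round(best, 3))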

View File

@@ -113,7 +113,7 @@ class Chunk(Base):
         if self.file_path is None:
             return [self.content]
 
-        path = pathlib.Path(self.file_path)
+        path = pathlib.Path(self.file_path.replace("/app/", ""))
         if self.file_path.endswith("*"):
             files = list(path.parent.glob(path.name))
         else:
@@ -122,7 +122,8 @@ class Chunk(Base):
         items = []
         for file_path in files:
             if file_path.suffix == ".png":
-                items.append(Image.open(file_path))
+                if file_path.exists():
+                    items.append(Image.open(file_path))
             elif file_path.suffix == ".bin":
                 items.append(file_path.read_bytes())
             else:

View File

@@ -1,7 +1,7 @@
 import logging
 import pathlib
 import uuid
-from typing import Any, Iterable, Literal, TypedDict
+from typing import Any, Iterable, Literal, NotRequired, TypedDict
 
 import voyageai
 from PIL import Image
@@ -27,20 +27,63 @@ Embedding = tuple[str, Vector, dict[str, Any]]
 class Collection(TypedDict):
     dimension: int
     distance: DistanceType
-    on_disk: bool
-    shards: int
+    model: str
+    on_disk: NotRequired[bool]
+    shards: NotRequired[int]
 
 
 DEFAULT_COLLECTIONS: dict[str, Collection] = {
-    "mail": {"dimension": 1024, "distance": "Cosine"},
-    "chat": {"dimension": 1024, "distance": "Cosine"},
-    "git": {"dimension": 1024, "distance": "Cosine"},
-    "book": {"dimension": 1024, "distance": "Cosine"},
-    "blog": {"dimension": 1024, "distance": "Cosine"},
-    "text": {"dimension": 1024, "distance": "Cosine"},
+    "mail": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "chat": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "git": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "book": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "blog": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "text": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
     # Multimodal
-    "photo": {"dimension": 1024, "distance": "Cosine"},
-    "doc": {"dimension": 1024, "distance": "Cosine"},
+    "photo": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.MIXED_EMBEDDING_MODEL,
+    },
+    "doc": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.MIXED_EMBEDDING_MODEL,
+    },
+}
+
+TEXT_COLLECTIONS = {
+    coll
+    for coll, params in DEFAULT_COLLECTIONS.items()
+    if params["model"] == settings.TEXT_EMBEDDING_MODEL
+}
+MULTIMODAL_COLLECTIONS = {
+    coll
+    for coll, params in DEFAULT_COLLECTIONS.items()
+    if params["model"] == settings.MIXED_EMBEDDING_MODEL
 }
 
 TYPES = {
@@ -69,22 +112,27 @@ def get_modality(mime_type: str) -> str:
 def embed_chunks(
-    chunks: list[extract.MulitmodalChunk], model: str = settings.TEXT_EMBEDDING_MODEL
+    chunks: list[extract.MulitmodalChunk],
+    model: str = settings.TEXT_EMBEDDING_MODEL,
+    input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
     vo = voyageai.Client()
     if model == settings.MIXED_EMBEDDING_MODEL:
         return vo.multimodal_embed(
-            chunks, model=model, input_type="document"
+            chunks, model=model, input_type=input_type
         ).embeddings
-    return vo.embed(chunks, model=model, input_type="document").embeddings
+    return vo.embed(chunks, model=model, input_type=input_type).embeddings
 
 
 def embed_text(
-    texts: list[str], model: str = settings.TEXT_EMBEDDING_MODEL
+    texts: list[str],
+    model: str = settings.TEXT_EMBEDDING_MODEL,
+    input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
     chunks = [
         c
         for text in texts
+        if isinstance(text, str)
         for c in chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS)
         if c.strip()
     ]
@@ -92,7 +140,7 @@ def embed_text(
         return []
     try:
-        return embed_chunks(chunks, model)
+        return embed_chunks(chunks, model, input_type)
     except voyageai.error.InvalidRequestError as e:
         logger.error(f"Error embedding text: {e}")
         logger.debug(f"Text: {texts}")
@@ -106,7 +154,9 @@ def embed_file(
 def embed_mixed(
-    items: list[extract.MulitmodalChunk], model: str = settings.MIXED_EMBEDDING_MODEL
+    items: list[extract.MulitmodalChunk],
+    model: str = settings.MIXED_EMBEDDING_MODEL,
+    input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
     def to_chunks(item: extract.MulitmodalChunk) -> Iterable[str]:
         if isinstance(item, str):
@@ -116,7 +166,7 @@ def embed_mixed(
         return [item]
 
     chunks = [c for item in items for c in to_chunks(item)]
-    return embed_chunks([chunks], model)
+    return embed_chunks([chunks], model, input_type)
 
 
 def embed_page(page: dict[str, Any]) -> list[Vector]:
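The practical effect of threading input_type through these helpers is that the same code paths can embed documents at ingestion time and queries at search time. A small sketch, assuming the model defaults from settings shown above:

# Sketch: index-time vs. query-time embeddings with the updated helpers.
doc_vectors = embed_text(["First chapter of the book ..."])  # input_type defaults to "document"
query_vector = embed_text(
    ["what happens in the first chapter?"], input_type="query"
)[0]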

View File

@@ -7,6 +7,9 @@ import pymupdf  # PyMuPDF
 from PIL import Image
 from typing import Any, TypedDict, Generator, Sequence
 
+import logging
+
+logger = logging.getLogger(__name__)
 
 MulitmodalChunk = Image.Image | str
@@ -57,15 +60,32 @@ def docx_to_pdf(
     if output_path is None:
         output_path = docx_path.with_suffix(".pdf")
 
-    pypandoc.convert_file(str(docx_path), "pdf", outputfile=str(output_path))
-
-    return output_path
+    # Now that we have all packages installed, try xelatex first as it has better Unicode support
+    try:
+        logger.info(f"Converting {docx_path} to PDF using xelatex")
+        pypandoc.convert_file(
+            str(docx_path),
+            format="docx",
+            to="pdf",
+            outputfile=str(output_path),
+            extra_args=[
+                "--pdf-engine=xelatex",
+                "--variable=geometry:margin=1in",
+                "--lua-filter=/app/unnest-table.lua",
+            ],
+        )
+        logger.info(f"Successfully converted {docx_path} to PDF")
+        return output_path
+    except Exception as e:
+        logger.error(f"Error converting document to PDF: {e}")
+        raise
 
 
-def extract_docx(docx_path: pathlib.Path) -> list[Page]:
+def extract_docx(docx_path: pathlib.Path | bytes | str) -> list[Page]:
     """Extract content from DOCX by converting to PDF first, then processing"""
     with as_file(docx_path) as file_path:
         pdf_path = docx_to_pdf(file_path)
+        logger.info(f"Extracted PDF from {file_path}")
         return doc_to_images(pdf_path)
@@ -89,13 +109,17 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
 def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]:
+    logger.info(f"Extracting content from {mime_type}")
     if mime_type == "application/pdf":
         return doc_to_images(content)
-    if isinstance(content, pathlib.Path) and mime_type in [
+    if mime_type in [
         "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
         "application/msword",
     ]:
-        return extract_docx(content)
+        logger.info(f"Extracting content from {content}")
+        pages = extract_docx(content)
+        logger.info(f"Extracted {len(pages)} pages from {content}")
+        return pages
     if mime_type.startswith("text/"):
         return extract_text(content)
     if mime_type.startswith("image/"):
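With extract_docx now accepting raw bytes as well as paths, callers such as the /search endpoint can hand uploaded file contents straight to extract_content. A hedged sketch ("sample.docx" is a placeholder file):

import pathlib
from memory.common import extract

# Sketch: extract pages from a DOCX held in memory rather than on disk.
docx_bytes = pathlib.Path("sample.docx").read_bytes()
mime = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
pages = extract.extract_content(mime, docx_bytes)
print(len(pages), "pages extracted")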

View File

@@ -25,6 +25,7 @@ class EmailMessage(TypedDict):
     sent_at: datetime | None
     body: str
     attachments: list[Attachment]
+    hash: bytes
 
 
 RawEmailResponse = tuple[Literal["OK", "ERROR"], bytes]
@@ -33,10 +34,10 @@ RawEmailResponse = tuple[Literal["OK", "ERROR"], bytes]
 def extract_recipients(msg: email.message.Message) -> list[str]:
     """
     Extract email recipients from message headers.
 
     Args:
         msg: Email message object
 
     Returns:
         List of recipient email addresses
     """
@@ -52,10 +53,10 @@ def extract_recipients(msg: email.message.Message) -> list[str]:
 def extract_date(msg: email.message.Message) -> datetime | None:
     """
     Parse date from email header.
 
     Args:
         msg: Email message object
 
     Returns:
         Parsed datetime or None if parsing failed
     """
@@ -70,18 +71,18 @@ def extract_date(msg: email.message.Message) -> datetime | None:
 def extract_body(msg: email.message.Message) -> str:
     """
     Extract plain text body from email message.
 
     Args:
         msg: Email message object
 
     Returns:
         Plain text body content
     """
     body = ""
 
     if not msg.is_multipart():
         try:
-            return msg.get_payload(decode=True).decode(errors='replace')
+            return msg.get_payload(decode=True).decode(errors="replace")
         except Exception as e:
             logger.error(f"Error decoding message body: {str(e)}")
             return ""
@@ -89,10 +90,10 @@ def extract_body(msg: email.message.Message) -> str:
     for part in msg.walk():
         content_type = part.get_content_type()
         content_disposition = str(part.get("Content-Disposition", ""))
 
         if content_type == "text/plain" and "attachment" not in content_disposition:
             try:
-                body += part.get_payload(decode=True).decode(errors='replace') + "\n"
+                body += part.get_payload(decode=True).decode(errors="replace") + "\n"
             except Exception as e:
                 logger.error(f"Error decoding message part: {str(e)}")
     return body
@@ -101,10 +102,10 @@ def extract_body(msg: email.message.Message) -> str:
 def extract_attachments(msg: email.message.Message) -> list[Attachment]:
     """
     Extract attachment metadata and content from email.
 
     Args:
         msg: Email message object
 
     Returns:
         List of attachment dictionaries with metadata and content
     """
@@ -120,14 +121,18 @@ def extract_attachments(msg: email.message.Message) -> list[Attachment]:
         if filename := part.get_filename():
             try:
                 content = part.get_payload(decode=True)
-                attachments.append({
-                    "filename": filename,
-                    "content_type": part.get_content_type(),
-                    "size": len(content),
-                    "content": content
-                })
+                attachments.append(
+                    {
+                        "filename": filename,
+                        "content_type": part.get_content_type(),
+                        "size": len(content),
+                        "content": content,
+                    }
+                )
             except Exception as e:
-                logger.error(f"Error extracting attachment content for {filename}: {str(e)}")
+                logger.error(
+                    f"Error extracting attachment content for {filename}: {str(e)}"
+                )
 
     return attachments
@@ -135,13 +140,13 @@ def extract_attachments(msg: email.message.Message) -> list[Attachment]:
 def compute_message_hash(msg_id: str, subject: str, sender: str, body: str) -> bytes:
     """
     Compute a SHA-256 hash of message content.
 
     Args:
         msg_id: Message ID
         subject: Email subject
         sender: Sender email
         body: Message body
 
     Returns:
         SHA-256 hash as bytes
     """
@@ -152,10 +157,10 @@ def compute_message_hash(msg_id: str, subject: str, sender: str, body: str) -> b
 def parse_email_message(raw_email: str, message_id: str) -> EmailMessage:
     """
     Parse raw email into structured data.
 
     Args:
         raw_email: Raw email content as string
 
     Returns:
         Dict with parsed email data
     """
@@ -164,7 +169,7 @@ def parse_email_message(raw_email: str, message_id: str) -> EmailMessage:
     subject = msg.get("Subject", "")
     from_ = msg.get("From", "")
     body = extract_body(msg)
 
     return EmailMessage(
         message_id=message_id,
         subject=subject,
@@ -173,5 +178,5 @@ def parse_email_message(raw_email: str, message_id: str) -> EmailMessage:
         sent_at=extract_date(msg),
         body=body,
         attachments=extract_attachments(msg),
-        hash=compute_message_hash(message_id, subject, from_, body)
+        hash=compute_message_hash(message_id, subject, from_, body),
     )

View File

@@ -5,15 +5,22 @@ import qdrant_client
 from qdrant_client.http import models as qdrant_models
 from qdrant_client.http.exceptions import UnexpectedResponse
 
 from memory.common import settings
-from memory.common.embedding import Collection, DEFAULT_COLLECTIONS, DistanceType, Vector
+from memory.common.embedding import (
+    Collection,
+    DEFAULT_COLLECTIONS,
+    DistanceType,
+    Vector,
+)
 
 logger = logging.getLogger(__name__)
 
 
 def get_qdrant_client() -> qdrant_client.QdrantClient:
     """Create and return a Qdrant client using environment configuration."""
-    logger.info(f"Connecting to Qdrant at {settings.QDRANT_HOST}:{settings.QDRANT_PORT}")
+    logger.info(
+        f"Connecting to Qdrant at {settings.QDRANT_HOST}:{settings.QDRANT_PORT}"
+    )
     return qdrant_client.QdrantClient(
         host=settings.QDRANT_HOST,
         port=settings.QDRANT_PORT,
@@ -34,7 +41,7 @@ def ensure_collection_exists(
 ) -> bool:
     """
     Ensure a collection exists with the specified parameters.
 
     Args:
         client: Qdrant client
         collection_name: Name of the collection
@@ -42,7 +49,7 @@ def ensure_collection_exists(
         distance: Distance metric (Cosine, Dot, Euclidean)
         on_disk: Whether to store vectors on disk
         shards: Number of shards for the collection
 
     Returns:
         True if the collection was created, False if it already existed
     """
@@ -61,21 +68,23 @@ def ensure_collection_exists(
             on_disk_payload=on_disk,
             shard_number=shards,
         )
 
         # Create common payload indexes
         client.create_payload_index(
             collection_name=collection_name,
             field_name="tags",
             field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
         )
 
         return True
 
 
-def initialize_collections(client: qdrant_client.QdrantClient, collections: dict[str, Collection] = None) -> None:
+def initialize_collections(
+    client: qdrant_client.QdrantClient, collections: dict[str, Collection] = None
+) -> None:
     """
     Initialize all required collections in Qdrant.
 
     Args:
         client: Qdrant client
         collections: Dictionary mapping collection names to their parameters.
@@ -83,7 +92,7 @@ def initialize_collections(client: qdrant_client.QdrantClient, collections: dict
     """
     if collections is None:
         collections = DEFAULT_COLLECTIONS
 
     logger.info(f"Initializing collections:")
     for name, params in collections.items():
         logger.info(f" - {name}")
@@ -99,7 +108,7 @@ def initialize_collections(client: qdrant_client.QdrantClient, collections: dict
 def setup_qdrant() -> qdrant_client.QdrantClient:
     """Get a Qdrant client and initialize collections.
 
     Returns:
         Configured Qdrant client
     """
@@ -109,14 +118,14 @@ def setup_qdrant() -> qdrant_client.QdrantClient:
 def upsert_vectors(
     client: qdrant_client.QdrantClient,
     collection_name: str,
     ids: list[str],
     vectors: list[Vector],
     payloads: list[dict[str, Any]] = None,
 ) -> None:
     """Upsert vectors into a collection.
 
     Args:
         client: Qdrant client
         collection_name: Name of the collection
@@ -126,7 +135,7 @@ def upsert_vectors(
     """
     if payloads is None:
         payloads = [{} for _ in ids]
 
     points = [
         qdrant_models.PointStruct(
             id=id_str,
@@ -135,12 +144,12 @@ def upsert_vectors(
         )
         for id_str, vector, payload in zip(ids, vectors, payloads)
     ]
 
     client.upsert(
         collection_name=collection_name,
         points=points,
     )
 
     logger.debug(f"Upserted {len(ids)} vectors into {collection_name}")
@@ -152,21 +161,21 @@ def search_vectors(
     limit: int = 10,
 ) -> list[qdrant_models.ScoredPoint]:
     """Search for similar vectors in a collection.
 
     Args:
         client: Qdrant client
         collection_name: Name of the collection
         query_vector: Query vector
         filter_params: Filter parameters to apply (e.g., {"tags": {"value": "work"}})
         limit: Maximum number of results to return
 
     Returns:
         List of scored points
     """
     filter_obj = None
     if filter_params:
         filter_obj = qdrant_models.Filter(**filter_params)
 
     return client.search(
         collection_name=collection_name,
         query_vector=query_vector,
@@ -182,7 +191,7 @@ def delete_vectors(
 ) -> None:
     """
     Delete vectors from a collection.
 
     Args:
         client: Qdrant client
         collection_name: Name of the collection
@@ -194,18 +203,20 @@ def delete_vectors(
             points=ids,
         ),
     )
 
     logger.debug(f"Deleted {len(ids)} vectors from {collection_name}")
 
 
-def get_collection_info(client: qdrant_client.QdrantClient, collection_name: str) -> dict:
+def get_collection_info(
+    client: qdrant_client.QdrantClient, collection_name: str
+) -> dict:
     """
     Get information about a collection.
 
     Args:
         client: Qdrant client
         collection_name: Name of the collection
 
     Returns:
         Dictionary with collection information
     """

View File

@@ -52,7 +52,6 @@ QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))
 # Worker settings
 EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 3600))
-EMAIL_SYNC_INTERVAL = 60
 
 # Embedding settings

src/memory/mcp/server.py Normal file
View File

@@ -0,0 +1,58 @@
+import argparse
+import logging
+from typing import Any
+
+from fastapi import UploadFile
+import httpx
+from mcp.server.fastmcp import FastMCP
+
+SERVER = "http://localhost:8000"
+
+logger = logging.getLogger(__name__)
+
+mcp = FastMCP("memory")
+
+
+async def make_request(
+    path: str,
+    method: str,
+    data: dict | None = None,
+    json: dict | None = None,
+    files: list[UploadFile] | None = None,
+) -> httpx.Response:
+    async with httpx.AsyncClient() as client:
+        return await client.request(
+            method, f"{SERVER}/{path}", data=data, json=json, files=files
+        )
+
+
+async def post_data(path: str, data: dict | None = None) -> httpx.Response:
+    return await make_request(path, "POST", data=data)
+
+
+@mcp.tool()
+async def search(
+    query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
+) -> list[dict[str, Any]]:
+    logger.error(f"Searching for {query}")
+    resp = await post_data(
+        "search",
+        {
+            "query": query,
+            "previews": previews,
+            "modalities": modalities,
+            "limit": limit,
+        },
+    )
+    return resp.json()
+
+
+if __name__ == "__main__":
+    # Initialize and run the server
+    args = argparse.ArgumentParser()
+    args.add_argument("--server", type=str, default=SERVER)
+    args.add_argument("--transport", type=str, default="stdio")
+    args = args.parse_args()
+
+    SERVER = args.server
+    mcp.run(transport=args.transport)
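A quick way to exercise the new tool without wiring up a full MCP client is to call the coroutine directly. This sketch assumes the API from the app above is already running at the default SERVER URL and that the module is importable as memory.mcp.server:

import asyncio
from memory.mcp.server import search  # import path assumed from the file location above

# Sketch: ask the running knowledge-base API for the top 3 text hits.
results = asyncio.run(search("pandoc table handling", modalities=["text"], limit=3))
for hit in results:
    print(hit["id"], [round(c["score"], 3) for c in hit["chunks"]])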

View File

@@ -124,13 +124,14 @@ def create_mail_message(
         folder=folder,
     )
-    if parsed_email["attachments"]:
-        mail_message.attachments = process_attachments(
-            parsed_email["attachments"], mail_message
-        )
     db_session.add(mail_message)
+
+    if parsed_email["attachments"]:
+        attachments = process_attachments(parsed_email["attachments"], mail_message)
+        db_session.add_all(attachments)
+        mail_message.attachments = attachments
+        db_session.add(mail_message)
 
     return mail_message
@@ -171,13 +172,11 @@ def check_message_exists(
     account = db.query(EmailAccount).get(account_id)
     if not account:
         logger.error(f"Account {account_id} not found")
-        return None
+        return False
 
     parsed_email = parse_email_message(raw_email, message_id)
-
-    # Use server-provided message ID if missing
-    if not parsed_email["message_id"]:
-        parsed_email["message_id"] = f"generated-{message_id}"
+    if "szczepalins" in raw_email.lower():
+        print(parsed_email["message_id"])
 
     return does_message_exist(db, parsed_email["message_id"], parsed_email["hash"])

View File

@@ -58,6 +58,7 @@ def process_message(
             db, account.tags, folder, raw_email, message_id
         )
 
+        db.flush()
         vectorize_email(mail_message)
         db.commit()
@@ -65,12 +66,16 @@ def process_message(
         logger.info("Chunks:")
         for chunk in mail_message.chunks:
             logger.info(f" - {chunk.id}")
+        for attachment in mail_message.attachments:
+            logger.info(f" - Attachment {attachment.id}")
+            for chunk in attachment.chunks:
+                logger.info(f" - {chunk.id}")
 
         return mail_message.id
 
 
 @app.task(name=SYNC_ACCOUNT)
-def sync_account(account_id: int) -> dict:
+def sync_account(account_id: int, since_date: str | None = None) -> dict:
     """
     Synchronize emails from a specific account.
@@ -80,7 +85,7 @@ def sync_account(account_id: int) -> dict:
     Returns:
         dict with stats about the sync operation
     """
-    logger.info(f"Syncing account {account_id}")
+    logger.info(f"Syncing account {account_id} since {since_date}")
 
     with make_session() as db:
         account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
@@ -88,8 +93,11 @@ def sync_account(account_id: int) -> dict:
             logger.warning(f"Account {account_id} not found or inactive")
             return {"error": "Account not found or inactive"}
 
-        folders_to_process = account.folders or ["INBOX"]
-        since_date = account.last_sync_at or datetime(1970, 1, 1)
+        folders_to_process: list[str] = account.folders or ["INBOX"]
+        if since_date:
+            cutoff_date = datetime.fromisoformat(since_date)
+        else:
+            cutoff_date: datetime = account.last_sync_at or datetime(1970, 1, 1)
 
         messages_found = 0
         new_messages = 0
@@ -106,7 +114,7 @@ def sync_account(account_id: int) -> dict:
         with imap_connection(account) as conn:
             for folder in folders_to_process:
                 folder_stats = process_folder(
-                    conn, folder, account, since_date, process_message_wrapper
+                    conn, folder, account, cutoff_date, process_message_wrapper
                 )
 
                 messages_found += folder_stats["messages_found"]
@@ -121,6 +129,7 @@ def sync_account(account_id: int) -> dict:
         return {
             "account": account.email_address,
+            "since_date": cutoff_date.isoformat(),
             "folders_processed": len(folders_to_process),
             "messages_found": messages_found,
             "new_messages": new_messages,