mirror of https://github.com/mruwnik/memory.git (synced 2025-06-28 23:24:43 +02:00)

initial server

parent fe15442a6d
commit cf98d38bb8
@@ -10,21 +10,28 @@ COPY src/ ./src/
# Install dependencies
RUN apt-get update && apt-get install -y \
    libpq-dev gcc pandoc \
    texlive-full texlive-fonts-recommended texlive-plain-generic \
    texlive-xetex texlive-fonts-recommended texlive-plain-generic \
    texlive-lang-greek texlive-lang-cyrillic texlive-lang-european \
    texlive-luatex texlive-latex-extra texlive-latex-recommended \
    texlive-science texlive-fonts-extra \
    fontconfig \
    # For optional LibreOffice support (uncomment if needed)
    # libreoffice-writer \
    && \
    pip install -e ".[workers]" && \
    apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
    && apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
RUN pip install -e ".[workers]"

# Create and copy entrypoint script
COPY docker/workers/entry.sh ./entry.sh
COPY docker/workers/unnest-table.lua ./unnest-table.lua
RUN chmod +x entry.sh

RUN mkdir -p /app/memory_files

# Create user and set permissions
RUN useradd -m kb && chown -R kb /app
RUN useradd -m kb
RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
    chown -R kb:kb /var/cache/fontconfig /home/kb/.cache/fontconfig /app

USER kb

# Default queues to process
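The worker image now carries the full pandoc + xelatex/luatex + fontconfig toolchain for document conversion. A minimal sanity-check sketch from inside the container; the exact binary names are assumptions, not pinned by the Dockerfile:

# Hedged sketch: confirm the converters the worker image relies on are reachable.
import shutil

for tool in ("pandoc", "xelatex", "lualatex", "fc-list"):  # assumed binary names
    print(tool, "->", shutil.which(tool) or "MISSING")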
docker/workers/unnest-table.lua (new file, 28 lines)
@@ -0,0 +1,28 @@
local function rows(tbl)
  local r = pandoc.List()
  r:extend(tbl.head.rows)
  for _,b in ipairs(tbl.bodies) do r:extend(b.body) end
  r:extend(tbl.foot.rows)
  return r
end

function Table(t)
  local newHead = pandoc.TableHead()
  for i,row in ipairs(t.head.rows) do
    for j,cell in ipairs(row.cells) do
      local inner = cell.contents[1]
      if inner and inner.t == 'Table' then
        local ins = rows(inner)
        local first = table.remove(ins,1)
        row.cells[j] = first.cells[1]
        newHead.rows:insert(row)
        newHead.rows:extend(ins)
        goto continue
      end
    end
    newHead.rows:insert(row)
    ::continue::
  end
  t.head = newHead
  return t
end
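This filter lifts the rows of a table nested inside a header cell up into the parent table's head, which keeps pandoc's PDF output from choking on nested tables. A sketch of applying it through pypandoc, mirroring the docx_to_pdf call added later in this commit; the file paths are placeholders:

import pypandoc  # assumes pandoc and a LaTeX engine are on PATH

pypandoc.convert_file(
    "/tmp/report.docx",            # placeholder input
    format="docx",
    to="pdf",
    outputfile="/tmp/report.pdf",  # placeholder output
    extra_args=[
        "--pdf-engine=xelatex",
        "--lua-filter=docker/workers/unnest-table.lua",
    ],
)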
@@ -1,4 +1,4 @@
fastapi==0.112.2
uvicorn==0.29.0
python-jose==3.3.0
python-multipart==0.0.9
python-multipart==0.0.9
requirements-mcp.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
mcp==1.7.1
httpx==0.25.1
@@ -1,23 +1,50 @@
"""
FastAPI application for the knowledge base.
"""
from fastapi import FastAPI, Depends, HTTPException
from sqlalchemy.orm import Session

from memory.common.db import get_scoped_session
from memory.common.db.models import SourceItem
import base64
import io
from collections import defaultdict
import pathlib
from typing import Annotated, List, Optional, Callable
from fastapi import FastAPI, File, UploadFile, Query, HTTPException, Form
from fastapi.responses import FileResponse
import qdrant_client
from qdrant_client.http import models as qdrant_models
from PIL import Image
from pydantic import BaseModel

from memory.common import embedding, qdrant, extract, settings
from memory.common.embedding import get_modality
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem

import logging

logger = logging.getLogger(__name__)

app = FastAPI(title="Knowledge Base API")


def get_db():
    """Database session dependency"""
    db = get_scoped_session()
    try:
        yield db
    finally:
        db.close()
class AnnotatedChunk(BaseModel):
    id: str
    score: float
    metadata: dict
    preview: Optional[str | None] = None


class SearchResponse(BaseModel):
    collection: str
    results: List[dict]


class SearchResult(BaseModel):
    id: int
    size: int
    mime_type: str
    chunks: list[AnnotatedChunk]
    content: Optional[str] = None
    filename: Optional[str] = None


@app.get("/health")
@@ -26,25 +53,195 @@ def health_check():
    return {"status": "healthy"}


@app.get("/sources")
def list_sources(
    tag: str = None,
    limit: int = 100,
    db: Session = Depends(get_db)
def annotated_chunk(
    chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
) -> tuple[SourceItem, AnnotatedChunk]:
    def serialize_item(item: bytes | str | Image.Image) -> str | None:
        if not previews and not isinstance(item, str):
            return None

        if isinstance(item, Image.Image):
            buffer = io.BytesIO()
            format = item.format or "PNG"
            item.save(buffer, format=format)
            mime_type = f"image/{format.lower()}"
            return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
        elif isinstance(item, bytes):
            return base64.b64encode(item).decode("utf-8")
        elif isinstance(item, str):
            return item
        else:
            raise ValueError(f"Unsupported item type: {type(item)}")

    metadata = search_result.payload or {}
    metadata = {
        k: v
        for k, v in metadata.items()
        if k not in ["content", "filename", "size", "content_type", "tags"]
    }
    return chunk.source, AnnotatedChunk(
        id=str(chunk.id),
        score=search_result.score,
        metadata=metadata,
        preview=serialize_item(chunk.data[0]) if chunk.data else None,
    )


def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]:
    items = defaultdict(list)
    for source, chunk in chunks:
        items[source].append(chunk)

    return [
        SearchResult(
            id=source.id,
            size=source.size,
            mime_type=source.mime_type,
            filename=source.filename
            and source.filename.replace(
                str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
            ),
            content=source.content,
            chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
        )
        for source, chunks in items.items()
    ]


def query_chunks(
    client: qdrant_client.QdrantClient,
    upload_data: list[tuple[str, list[extract.Page]]],
    allowed_modalities: set[str],
    embedder: Callable,
    min_score: float = 0.0,
    limit: int = 10,
) -> dict[str, list[qdrant_models.ScoredPoint]]:
    if not upload_data:
        return {}

    chunks = [
        chunk
        for content_type, pages in upload_data
        if get_modality(content_type) in allowed_modalities
        for page in pages
        for chunk in page["contents"]
    ]

    if not chunks:
        return {}

    vector = embedder(chunks, input_type="query")[0]

    return {
        collection: [
            r
            for r in qdrant.search_vectors(
                client=client,
                collection_name=collection,
                query_vector=vector,
                limit=limit,
            )
            if r.score >= min_score
        ]
        for collection in embedding.DEFAULT_COLLECTIONS
    }


async def input_type(item: str | UploadFile) -> tuple[str, list[extract.Page]]:
    if not item:
        return "text/plain", []

    if isinstance(item, str):
        return "text/plain", extract.extract_text(item)
    content_type = item.content_type or "application/octet-stream"
    return content_type, extract.extract_content(content_type, await item.read())


@app.post("/search", response_model=list[SearchResult])
async def search(
    query: Optional[str] = Form(None),
    previews: Optional[bool] = Form(False),
    modalities: Annotated[list[str], Query()] = [],
    files: list[UploadFile] = File([]),
    limit: int = Query(10, ge=1, le=100),
    min_text_score: float = Query(0.5, ge=0.0, le=1.0),
    min_multimodal_score: float = Query(0.3, ge=0.0, le=1.0),
):
    """List source items, optionally filtered by tag"""
    query = db.query(SourceItem)

    if tag:
        query = query.filter(SourceItem.tags.contains([tag]))

    return query.limit(limit).all()
    """
    Search across knowledge base using text query and optional files.

    Parameters:
    - query: Optional text search query
    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
    - files: Optional files to include in the search context
    - limit: Maximum number of results per modality

    Returns:
    - List of search results sorted by score
    """
    upload_data = [await input_type(item) for item in [query, *files]]
    logger.error(
        f"Querying chunks for {modalities}, query: {query}, previews: {previews}"
    )

    client = qdrant.get_qdrant_client()
    allowed_modalities = set(modalities or embedding.DEFAULT_COLLECTIONS.keys())
    text_results = query_chunks(
        client,
        upload_data,
        allowed_modalities & embedding.TEXT_COLLECTIONS,
        embedding.embed_text,
        min_score=min_text_score,
        limit=limit,
    )
    multimodal_results = query_chunks(
        client,
        upload_data,
        allowed_modalities,
        embedding.embed_mixed,
        min_score=min_multimodal_score,
        limit=limit,
    )
    search_results = {
        k: text_results.get(k, []) + multimodal_results.get(k, [])
        for k in allowed_modalities
    }

    found_chunks = {
        str(r.id): r for results in search_results.values() for r in results
    }
    with make_session() as db:
        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()

        results = group_chunks(
            [
                annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
                for chunk in chunks
            ]
        )
        return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)


@app.get("/sources/{source_id}")
def get_source(source_id: int, db: Session = Depends(get_db)):
    """Get a specific source by ID"""
    source = db.query(SourceItem).filter(SourceItem.id == source_id).first()
    if not source:
        raise HTTPException(status_code=404, detail="Source not found")
    return source
@app.get("/files/{path:path}")
def get_file_by_path(path: str):
    """
    Fetch a file by its path

    Parameters:
    - path: Path of the file to fetch (relative to FILE_STORAGE_DIR)

    Returns:
    - The file as a download
    """
    # Sanitize the path to prevent directory traversal
    sanitized_path = path.lstrip("/")
    if ".." in sanitized_path:
        raise HTTPException(status_code=400, detail="Invalid path")

    file_path = pathlib.Path(settings.FILE_STORAGE_DIR) / sanitized_path

    # Check if the file exists on disk
    if not file_path.exists() or not file_path.is_file():
        raise HTTPException(status_code=404, detail=f"File not found at path: {path}")

    return FileResponse(path=file_path, filename=file_path.name)
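A hedged client sketch for the two new endpoints; parameter placement follows the signatures above (query and previews as form fields, modalities and limit as query parameters), while the host and example values are placeholders:

import httpx

resp = httpx.post(
    "http://localhost:8000/search",
    data={"query": "quarterly report", "previews": True},
    params={"modalities": ["text", "doc"], "limit": 5},
)
for result in resp.json():
    print(result["id"], result["mime_type"], len(result["chunks"]))
    if result.get("filename"):
        # Filenames are rewritten to /files/... paths, served by get_file_by_path.
        print(httpx.get("http://localhost:8000" + result["filename"]).status_code)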
@@ -113,7 +113,7 @@ class Chunk(Base):
        if self.file_path is None:
            return [self.content]

        path = pathlib.Path(self.file_path)
        path = pathlib.Path(self.file_path.replace("/app/", ""))
        if self.file_path.endswith("*"):
            files = list(path.parent.glob(path.name))
        else:
@@ -122,7 +122,8 @@ class Chunk(Base):
        items = []
        for file_path in files:
            if file_path.suffix == ".png":
                items.append(Image.open(file_path))
                if file_path.exists():
                    items.append(Image.open(file_path))
            elif file_path.suffix == ".bin":
                items.append(file_path.read_bytes())
            else:
@@ -1,7 +1,7 @@
import logging
import pathlib
import uuid
from typing import Any, Iterable, Literal, TypedDict
from typing import Any, Iterable, Literal, NotRequired, TypedDict

import voyageai
from PIL import Image
@@ -27,20 +27,63 @@ Embedding = tuple[str, Vector, dict[str, Any]]
class Collection(TypedDict):
    dimension: int
    distance: DistanceType
    on_disk: bool
    shards: int
    model: str
    on_disk: NotRequired[bool]
    shards: NotRequired[int]


DEFAULT_COLLECTIONS: dict[str, Collection] = {
    "mail": {"dimension": 1024, "distance": "Cosine"},
    "chat": {"dimension": 1024, "distance": "Cosine"},
    "git": {"dimension": 1024, "distance": "Cosine"},
    "book": {"dimension": 1024, "distance": "Cosine"},
    "blog": {"dimension": 1024, "distance": "Cosine"},
    "text": {"dimension": 1024, "distance": "Cosine"},
    "mail": {
        "dimension": 1024,
        "distance": "Cosine",
        "model": settings.TEXT_EMBEDDING_MODEL,
    },
    "chat": {
        "dimension": 1024,
        "distance": "Cosine",
        "model": settings.TEXT_EMBEDDING_MODEL,
    },
    "git": {
        "dimension": 1024,
        "distance": "Cosine",
        "model": settings.TEXT_EMBEDDING_MODEL,
    },
    "book": {
        "dimension": 1024,
        "distance": "Cosine",
        "model": settings.TEXT_EMBEDDING_MODEL,
    },
    "blog": {
        "dimension": 1024,
        "distance": "Cosine",
        "model": settings.TEXT_EMBEDDING_MODEL,
    },
    "text": {
        "dimension": 1024,
        "distance": "Cosine",
        "model": settings.TEXT_EMBEDDING_MODEL,
    },
    # Multimodal
    "photo": {"dimension": 1024, "distance": "Cosine"},
    "doc": {"dimension": 1024, "distance": "Cosine"},
    "photo": {
        "dimension": 1024,
        "distance": "Cosine",
        "model": settings.MIXED_EMBEDDING_MODEL,
    },
    "doc": {
        "dimension": 1024,
        "distance": "Cosine",
        "model": settings.MIXED_EMBEDDING_MODEL,
    },
}
TEXT_COLLECTIONS = {
    coll
    for coll, params in DEFAULT_COLLECTIONS.items()
    if params["model"] == settings.TEXT_EMBEDDING_MODEL
}
MULTIMODAL_COLLECTIONS = {
    coll
    for coll, params in DEFAULT_COLLECTIONS.items()
    if params["model"] == settings.MIXED_EMBEDDING_MODEL
}

TYPES = {
@@ -69,22 +112,27 @@ def get_modality(mime_type: str) -> str:


def embed_chunks(
    chunks: list[extract.MulitmodalChunk], model: str = settings.TEXT_EMBEDDING_MODEL
    chunks: list[extract.MulitmodalChunk],
    model: str = settings.TEXT_EMBEDDING_MODEL,
    input_type: Literal["document", "query"] = "document",
) -> list[Vector]:
    vo = voyageai.Client()
    if model == settings.MIXED_EMBEDDING_MODEL:
        return vo.multimodal_embed(
            chunks, model=model, input_type="document"
            chunks, model=model, input_type=input_type
        ).embeddings
    return vo.embed(chunks, model=model, input_type="document").embeddings
    return vo.embed(chunks, model=model, input_type=input_type).embeddings


def embed_text(
    texts: list[str], model: str = settings.TEXT_EMBEDDING_MODEL
    texts: list[str],
    model: str = settings.TEXT_EMBEDDING_MODEL,
    input_type: Literal["document", "query"] = "document",
) -> list[Vector]:
    chunks = [
        c
        for text in texts
        if isinstance(text, str)
        for c in chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS)
        if c.strip()
    ]
@@ -92,7 +140,7 @@ def embed_text(
        return []

    try:
        return embed_chunks(chunks, model)
        return embed_chunks(chunks, model, input_type)
    except voyageai.error.InvalidRequestError as e:
        logger.error(f"Error embedding text: {e}")
        logger.debug(f"Text: {texts}")
@@ -106,7 +154,9 @@ def embed_file(


def embed_mixed(
    items: list[extract.MulitmodalChunk], model: str = settings.MIXED_EMBEDDING_MODEL
    items: list[extract.MulitmodalChunk],
    model: str = settings.MIXED_EMBEDDING_MODEL,
    input_type: Literal["document", "query"] = "document",
) -> list[Vector]:
    def to_chunks(item: extract.MulitmodalChunk) -> Iterable[str]:
        if isinstance(item, str):
@@ -116,7 +166,7 @@ def embed_mixed(
        return [item]

    chunks = [c for item in items for c in to_chunks(item)]
    return embed_chunks([chunks], model)
    return embed_chunks([chunks], model, input_type)


def embed_page(page: dict[str, Any]) -> list[Vector]:
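A small usage sketch of the new input_type parameter and the derived collection sets; it assumes Voyage AI credentials are configured, and the example strings are placeholders:

from memory.common import embedding

# Ingestion keeps the default input_type="document"...
doc_vectors = embedding.embed_text(["A note about quarterly planning."])
# ...while searches pass input_type="query", as the /search endpoint above does.
query_vector = embedding.embed_text(["quarterly planning"], input_type="query")[0]

# Collections are now grouped by the model that backs them.
print(embedding.TEXT_COLLECTIONS)        # e.g. {"mail", "chat", "git", "book", "blog", "text"}
print(embedding.MULTIMODAL_COLLECTIONS)  # e.g. {"photo", "doc"}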
@@ -7,6 +7,9 @@ import pymupdf  # PyMuPDF
from PIL import Image
from typing import Any, TypedDict, Generator, Sequence

import logging

logger = logging.getLogger(__name__)

MulitmodalChunk = Image.Image | str

@@ -57,15 +60,32 @@ def docx_to_pdf(
    if output_path is None:
        output_path = docx_path.with_suffix(".pdf")

    pypandoc.convert_file(str(docx_path), "pdf", outputfile=str(output_path))

    return output_path
    # Now that we have all packages installed, try xelatex first as it has better Unicode support
    try:
        logger.info(f"Converting {docx_path} to PDF using xelatex")
        pypandoc.convert_file(
            str(docx_path),
            format="docx",
            to="pdf",
            outputfile=str(output_path),
            extra_args=[
                "--pdf-engine=xelatex",
                "--variable=geometry:margin=1in",
                "--lua-filter=/app/unnest-table.lua",
            ],
        )
        logger.info(f"Successfully converted {docx_path} to PDF")
        return output_path
    except Exception as e:
        logger.error(f"Error converting document to PDF: {e}")
        raise


def extract_docx(docx_path: pathlib.Path) -> list[Page]:
def extract_docx(docx_path: pathlib.Path | bytes | str) -> list[Page]:
    """Extract content from DOCX by converting to PDF first, then processing"""
    with as_file(docx_path) as file_path:
        pdf_path = docx_to_pdf(file_path)
        logger.info(f"Extracted PDF from {file_path}")
        return doc_to_images(pdf_path)


@@ -89,13 +109,17 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:


def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]:
    logger.info(f"Extracting content from {mime_type}")
    if mime_type == "application/pdf":
        return doc_to_images(content)
    if isinstance(content, pathlib.Path) and mime_type in [
    if mime_type in [
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/msword",
    ]:
        return extract_docx(content)
        logger.info(f"Extracting content from {content}")
        pages = extract_docx(content)
        logger.info(f"Extracted {len(pages)} pages from {content}")
        return pages
    if mime_type.startswith("text/"):
        return extract_text(content)
    if mime_type.startswith("image/"):
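A hedged sketch of the DOCX path through extract_content; the file path is a placeholder and it assumes the pandoc/xelatex toolchain from the Dockerfile is available:

import pathlib
from memory.common import extract

docx = pathlib.Path("/tmp/contract.docx")  # placeholder
pages = extract.extract_content(
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    docx,
)
# Each page is rendered to an image via the intermediate PDF.
print(len(pages), "pages extracted")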
@@ -25,6 +25,7 @@ class EmailMessage(TypedDict):
    sent_at: datetime | None
    body: str
    attachments: list[Attachment]
    hash: bytes


RawEmailResponse = tuple[Literal["OK", "ERROR"], bytes]
@@ -33,10 +34,10 @@ RawEmailResponse = tuple[Literal["OK", "ERROR"], bytes]
def extract_recipients(msg: email.message.Message) -> list[str]:
    """
    Extract email recipients from message headers.

    Args:
        msg: Email message object

    Returns:
        List of recipient email addresses
    """
@@ -52,10 +53,10 @@ def extract_recipients(msg: email.message.Message) -> list[str]:
def extract_date(msg: email.message.Message) -> datetime | None:
    """
    Parse date from email header.

    Args:
        msg: Email message object

    Returns:
        Parsed datetime or None if parsing failed
    """
@@ -70,18 +71,18 @@ def extract_date(msg: email.message.Message) -> datetime | None:
def extract_body(msg: email.message.Message) -> str:
    """
    Extract plain text body from email message.

    Args:
        msg: Email message object

    Returns:
        Plain text body content
    """
    body = ""

    if not msg.is_multipart():
        try:
            return msg.get_payload(decode=True).decode(errors='replace')
            return msg.get_payload(decode=True).decode(errors="replace")
        except Exception as e:
            logger.error(f"Error decoding message body: {str(e)}")
            return ""
@@ -89,10 +90,10 @@ def extract_body(msg: email.message.Message) -> str:
    for part in msg.walk():
        content_type = part.get_content_type()
        content_disposition = str(part.get("Content-Disposition", ""))

        if content_type == "text/plain" and "attachment" not in content_disposition:
            try:
                body += part.get_payload(decode=True).decode(errors='replace') + "\n"
                body += part.get_payload(decode=True).decode(errors="replace") + "\n"
            except Exception as e:
                logger.error(f"Error decoding message part: {str(e)}")
    return body
@@ -101,10 +102,10 @@ def extract_body(msg: email.message.Message) -> str:
def extract_attachments(msg: email.message.Message) -> list[Attachment]:
    """
    Extract attachment metadata and content from email.

    Args:
        msg: Email message object

    Returns:
        List of attachment dictionaries with metadata and content
    """
@@ -120,14 +121,18 @@ def extract_attachments(msg: email.message.Message) -> list[Attachment]:
        if filename := part.get_filename():
            try:
                content = part.get_payload(decode=True)
                attachments.append({
                    "filename": filename,
                    "content_type": part.get_content_type(),
                    "size": len(content),
                    "content": content
                })
                attachments.append(
                    {
                        "filename": filename,
                        "content_type": part.get_content_type(),
                        "size": len(content),
                        "content": content,
                    }
                )
            except Exception as e:
                logger.error(f"Error extracting attachment content for {filename}: {str(e)}")
                logger.error(
                    f"Error extracting attachment content for {filename}: {str(e)}"
                )

    return attachments

@@ -135,13 +140,13 @@ def extract_attachments(msg: email.message.Message) -> list[Attachment]:
def compute_message_hash(msg_id: str, subject: str, sender: str, body: str) -> bytes:
    """
    Compute a SHA-256 hash of message content.

    Args:
        msg_id: Message ID
        subject: Email subject
        sender: Sender email
        body: Message body

    Returns:
        SHA-256 hash as bytes
    """
@@ -152,10 +157,10 @@ def compute_message_hash(msg_id: str, subject: str, sender: str, body: str) -> b
def parse_email_message(raw_email: str, message_id: str) -> EmailMessage:
    """
    Parse raw email into structured data.

    Args:
        raw_email: Raw email content as string

    Returns:
        Dict with parsed email data
    """
@@ -164,7 +169,7 @@ def parse_email_message(raw_email: str, message_id: str) -> EmailMessage:
    subject = msg.get("Subject", "")
    from_ = msg.get("From", "")
    body = extract_body(msg)

    return EmailMessage(
        message_id=message_id,
        subject=subject,
@@ -173,5 +178,5 @@ def parse_email_message(raw_email: str, message_id: str) -> EmailMessage:
        sent_at=extract_date(msg),
        body=body,
        attachments=extract_attachments(msg),
        hash=compute_message_hash(message_id, subject, from_, body)
        hash=compute_message_hash(message_id, subject, from_, body),
    )
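A short sketch of the parser entry point, assuming the helpers above are in scope; the raw message is a placeholder:

raw = (
    "Message-ID: <abc@example.com>\r\n"
    "From: alice@example.com\r\n"
    "To: bob@example.com\r\n"
    "Subject: Hello\r\n"
    "\r\n"
    "Hi Bob!\r\n"
)
parsed = parse_email_message(raw, message_id="<abc@example.com>")
print(parsed["subject"], parsed["sent_at"], len(parsed["attachments"]))
print(parsed["hash"].hex())  # SHA-256 over id, subject, sender and body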
@@ -5,15 +5,22 @@ import qdrant_client
from qdrant_client.http import models as qdrant_models
from qdrant_client.http.exceptions import UnexpectedResponse
from memory.common import settings
from memory.common.embedding import Collection, DEFAULT_COLLECTIONS, DistanceType, Vector
from memory.common.embedding import (
    Collection,
    DEFAULT_COLLECTIONS,
    DistanceType,
    Vector,
)

logger = logging.getLogger(__name__)


def get_qdrant_client() -> qdrant_client.QdrantClient:
    """Create and return a Qdrant client using environment configuration."""
    logger.info(f"Connecting to Qdrant at {settings.QDRANT_HOST}:{settings.QDRANT_PORT}")

    logger.info(
        f"Connecting to Qdrant at {settings.QDRANT_HOST}:{settings.QDRANT_PORT}"
    )

    return qdrant_client.QdrantClient(
        host=settings.QDRANT_HOST,
        port=settings.QDRANT_PORT,
@@ -34,7 +41,7 @@ def ensure_collection_exists(
) -> bool:
    """
    Ensure a collection exists with the specified parameters.

    Args:
        client: Qdrant client
        collection_name: Name of the collection
@@ -42,7 +49,7 @@ def ensure_collection_exists(
        distance: Distance metric (Cosine, Dot, Euclidean)
        on_disk: Whether to store vectors on disk
        shards: Number of shards for the collection

    Returns:
        True if the collection was created, False if it already existed
    """
@@ -61,21 +68,23 @@ def ensure_collection_exists(
        on_disk_payload=on_disk,
        shard_number=shards,
    )

    # Create common payload indexes
    client.create_payload_index(
        collection_name=collection_name,
        field_name="tags",
        field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
    )

    return True


def initialize_collections(client: qdrant_client.QdrantClient, collections: dict[str, Collection] = None) -> None:
def initialize_collections(
    client: qdrant_client.QdrantClient, collections: dict[str, Collection] = None
) -> None:
    """
    Initialize all required collections in Qdrant.

    Args:
        client: Qdrant client
        collections: Dictionary mapping collection names to their parameters.
@@ -83,7 +92,7 @@ def initialize_collections(client: qdrant_client.QdrantClient, collections: dict
    """
    if collections is None:
        collections = DEFAULT_COLLECTIONS

    logger.info(f"Initializing collections:")
    for name, params in collections.items():
        logger.info(f"  - {name}")
@@ -99,7 +108,7 @@ def initialize_collections(client: qdrant_client.QdrantClient, collections: dict

def setup_qdrant() -> qdrant_client.QdrantClient:
    """Get a Qdrant client and initialize collections.

    Returns:
        Configured Qdrant client
    """
@@ -109,14 +118,14 @@ def setup_qdrant() -> qdrant_client.QdrantClient:


def upsert_vectors(
    client: qdrant_client.QdrantClient,
    collection_name: str,
    client: qdrant_client.QdrantClient,
    collection_name: str,
    ids: list[str],
    vectors: list[Vector],
    payloads: list[dict[str, Any]] = None,
) -> None:
    """Upsert vectors into a collection.

    Args:
        client: Qdrant client
        collection_name: Name of the collection
@@ -126,7 +135,7 @@ def upsert_vectors(
    """
    if payloads is None:
        payloads = [{} for _ in ids]

    points = [
        qdrant_models.PointStruct(
            id=id_str,
@@ -135,12 +144,12 @@ def upsert_vectors(
        )
        for id_str, vector, payload in zip(ids, vectors, payloads)
    ]

    client.upsert(
        collection_name=collection_name,
        points=points,
    )

    logger.debug(f"Upserted {len(ids)} vectors into {collection_name}")


@@ -152,21 +161,21 @@ def search_vectors(
    limit: int = 10,
) -> list[qdrant_models.ScoredPoint]:
    """Search for similar vectors in a collection.

    Args:
        client: Qdrant client
        collection_name: Name of the collection
        query_vector: Query vector
        filter_params: Filter parameters to apply (e.g., {"tags": {"value": "work"}})
        limit: Maximum number of results to return

    Returns:
        List of scored points
    """
    filter_obj = None
    if filter_params:
        filter_obj = qdrant_models.Filter(**filter_params)

    return client.search(
        collection_name=collection_name,
        query_vector=query_vector,
@@ -182,7 +191,7 @@ def delete_vectors(
) -> None:
    """
    Delete vectors from a collection.

    Args:
        client: Qdrant client
        collection_name: Name of the collection
@@ -194,18 +203,20 @@ def delete_vectors(
            points=ids,
        ),
    )

    logger.debug(f"Deleted {len(ids)} vectors from {collection_name}")


def get_collection_info(client: qdrant_client.QdrantClient, collection_name: str) -> dict:
def get_collection_info(
    client: qdrant_client.QdrantClient, collection_name: str
) -> dict:
    """
    Get information about a collection.

    Args:
        client: Qdrant client
        collection_name: Name of the collection

    Returns:
        Dictionary with collection information
    """
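A hedged end-to-end sketch tying the helpers together; the vector values and payload are placeholders, and the 1024 dimension follows DEFAULT_COLLECTIONS:

import uuid
from memory.common import qdrant

client = qdrant.setup_qdrant()  # connects and runs initialize_collections()

qdrant.upsert_vectors(
    client,
    "text",
    ids=[str(uuid.uuid4())],
    vectors=[[0.0] * 1024],          # placeholder 1024-dim vector
    payloads=[{"tags": ["work"]}],
)
hits = qdrant.search_vectors(client, "text", query_vector=[0.0] * 1024, limit=5)
print([(h.id, h.score) for h in hits])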
@@ -52,7 +52,6 @@ QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))

# Worker settings
EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 3600))
EMAIL_SYNC_INTERVAL = 60

# Embedding settings
src/memory/mcp/server.py (new file, 58 lines)
@@ -0,0 +1,58 @@
import argparse
import logging
from typing import Any
from fastapi import UploadFile
import httpx
from mcp.server.fastmcp import FastMCP

SERVER = "http://localhost:8000"


logger = logging.getLogger(__name__)
mcp = FastMCP("memory")


async def make_request(
    path: str,
    method: str,
    data: dict | None = None,
    json: dict | None = None,
    files: list[UploadFile] | None = None,
) -> httpx.Response:
    async with httpx.AsyncClient() as client:
        return await client.request(
            method, f"{SERVER}/{path}", data=data, json=json, files=files
        )


async def post_data(path: str, data: dict | None = None) -> httpx.Response:
    return await make_request(path, "POST", data=data)


@mcp.tool()
async def search(
    query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
) -> list[dict[str, Any]]:
    logger.error(f"Searching for {query}")
    resp = await post_data(
        "search",
        {
            "query": query,
            "previews": previews,
            "modalities": modalities,
            "limit": limit,
        },
    )

    return resp.json()


if __name__ == "__main__":
    # Initialize and run the server
    args = argparse.ArgumentParser()
    args.add_argument("--server", type=str)
    args = args.parse_args()

    SERVER = args.server

    mcp.run(transport=args.transport)
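A hedged sketch of exercising the search tool directly against a running API; the import path and URL are assumptions, and it relies on the tool decorator leaving the coroutine callable as a plain function:

import asyncio

from memory.mcp import server  # assumed import path for the file above

async def demo() -> None:
    server.SERVER = "http://localhost:8000"  # point at a running API instance
    # Note: the __main__ block above reads args.transport although only
    # --server is registered, so calling the tool in-process avoids that path.
    hits = await server.search("quarterly report", previews=False, limit=3)
    for hit in hits:
        print(hit["id"], hit["mime_type"])

asyncio.run(demo())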
@@ -124,13 +124,14 @@ def create_mail_message(
        folder=folder,
    )

    if parsed_email["attachments"]:
        mail_message.attachments = process_attachments(
            parsed_email["attachments"], mail_message
        )

    db_session.add(mail_message)

    if parsed_email["attachments"]:
        attachments = process_attachments(parsed_email["attachments"], mail_message)
        db_session.add_all(attachments)
        mail_message.attachments = attachments

    db_session.add(mail_message)
    return mail_message


@@ -171,13 +172,11 @@ def check_message_exists(
    account = db.query(EmailAccount).get(account_id)
    if not account:
        logger.error(f"Account {account_id} not found")
        return None
        return False

    parsed_email = parse_email_message(raw_email, message_id)

    # Use server-provided message ID if missing
    if not parsed_email["message_id"]:
        parsed_email["message_id"] = f"generated-{message_id}"
    if "szczepalins" in raw_email.lower():
        print(parsed_email["message_id"])

    return does_message_exist(db, parsed_email["message_id"], parsed_email["hash"])

@@ -58,6 +58,7 @@ def process_message(
            db, account.tags, folder, raw_email, message_id
        )

        db.flush()
        vectorize_email(mail_message)
        db.commit()

@@ -65,12 +66,16 @@ def process_message(
        logger.info("Chunks:")
        for chunk in mail_message.chunks:
            logger.info(f"  - {chunk.id}")
        for attachment in mail_message.attachments:
            logger.info(f"  - Attachment {attachment.id}")
            for chunk in attachment.chunks:
                logger.info(f"  - {chunk.id}")

    return mail_message.id


@app.task(name=SYNC_ACCOUNT)
def sync_account(account_id: int) -> dict:
def sync_account(account_id: int, since_date: str | None = None) -> dict:
    """
    Synchronize emails from a specific account.

@@ -80,7 +85,7 @@ def sync_account(account_id: int) -> dict:
    Returns:
        dict with stats about the sync operation
    """
    logger.info(f"Syncing account {account_id}")
    logger.info(f"Syncing account {account_id} since {since_date}")

    with make_session() as db:
        account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
@@ -88,8 +93,11 @@ def sync_account(account_id: int) -> dict:
            logger.warning(f"Account {account_id} not found or inactive")
            return {"error": "Account not found or inactive"}

        folders_to_process = account.folders or ["INBOX"]
        since_date = account.last_sync_at or datetime(1970, 1, 1)
        folders_to_process: list[str] = account.folders or ["INBOX"]
        if since_date:
            cutoff_date = datetime.fromisoformat(since_date)
        else:
            cutoff_date: datetime = account.last_sync_at or datetime(1970, 1, 1)

        messages_found = 0
        new_messages = 0
@@ -106,7 +114,7 @@ def sync_account(account_id: int) -> dict:
        with imap_connection(account) as conn:
            for folder in folders_to_process:
                folder_stats = process_folder(
                    conn, folder, account, since_date, process_message_wrapper
                    conn, folder, account, cutoff_date, process_message_wrapper
                )

                messages_found += folder_stats["messages_found"]
@@ -121,6 +129,7 @@ def sync_account(account_id: int) -> dict:

        return {
            "account": account.email_address,
            "since_date": cutoff_date.isoformat(),
            "folders_processed": len(folders_to_process),
            "messages_found": messages_found,
            "new_messages": new_messages,
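A hedged sketch of kicking off a sync with the new since_date parameter through Celery; the account id and date are placeholders, the module path is an assumption, and it relies on the standard .delay API of the task decorator:

from memory.workers.tasks.email import sync_account  # assumed module path

# Re-sync everything for account 1 since the start of May, instead of
# relying on account.last_sync_at.
result = sync_account.delay(1, since_date="2025-05-01T00:00:00")
print(result.get(timeout=600)["new_messages"])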