mirror of https://github.com/mruwnik/memory.git
synced 2025-06-28 23:24:43 +02:00

commit cf98d38bb8 (parent fe15442a6d): initial server
@@ -10,21 +10,28 @@ COPY src/ ./src/
 # Install dependencies
 RUN apt-get update && apt-get install -y \
     libpq-dev gcc pandoc \
-    texlive-full texlive-fonts-recommended texlive-plain-generic \
+    texlive-xetex texlive-fonts-recommended texlive-plain-generic \
+    texlive-lang-greek texlive-lang-cyrillic texlive-lang-european \
+    texlive-luatex texlive-latex-extra texlive-latex-recommended \
+    texlive-science texlive-fonts-extra \
+    fontconfig \
     # For optional LibreOffice support (uncomment if needed)
     # libreoffice-writer \
-    && \
-    pip install -e ".[workers]" && \
-    apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
+    && apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
+RUN pip install -e ".[workers]"

 # Create and copy entrypoint script
 COPY docker/workers/entry.sh ./entry.sh
+COPY docker/workers/unnest-table.lua ./unnest-table.lua
 RUN chmod +x entry.sh

 RUN mkdir -p /app/memory_files

 # Create user and set permissions
-RUN useradd -m kb && chown -R kb /app
+RUN useradd -m kb
+RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
+    chown -R kb:kb /var/cache/fontconfig /home/kb/.cache/fontconfig /app

 USER kb

 # Default queues to process
docker/workers/unnest-table.lua (new file, 28 lines)
@@ -0,0 +1,28 @@
+local function rows(tbl)
+  local r = pandoc.List()
+  r:extend(tbl.head.rows)
+  for _,b in ipairs(tbl.bodies) do r:extend(b.body) end
+  r:extend(tbl.foot.rows)
+  return r
+end
+
+function Table(t)
+  local newHead = pandoc.TableHead()
+  for i,row in ipairs(t.head.rows) do
+    for j,cell in ipairs(row.cells) do
+      local inner = cell.contents[1]
+      if inner and inner.t == 'Table' then
+        local ins = rows(inner)
+        local first = table.remove(ins,1)
+        row.cells[j] = first.cells[1]
+        newHead.rows:insert(row)
+        newHead.rows:extend(ins)
+        goto continue
+      end
+    end
+    newHead.rows:insert(row)
+    ::continue::
+  end
+  t.head = newHead
+  return t
+end
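The Lua filter above hoists the rows of tables that pandoc finds nested inside header cells, so the PDF engines installed in the Dockerfile can render them. A minimal sketch of exercising it locally via pypandoc, assuming pandoc and a TeX engine are installed; the input/output file names and the relative filter path are placeholders, not part of the commit:

# Hedged sketch: apply unnest-table.lua during a DOCX -> PDF conversion.
# "sample.docx"/"sample.pdf" and the filter path are placeholders.
import pypandoc

pypandoc.convert_file(
    "sample.docx",
    to="pdf",
    format="docx",
    outputfile="sample.pdf",
    extra_args=[
        "--pdf-engine=xelatex",
        "--lua-filter=docker/workers/unnest-table.lua",
    ],
)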
requirements-mcp.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
+mcp==1.7.1
+httpx==0.25.1
@@ -1,23 +1,50 @@
 """
 FastAPI application for the knowledge base.
 """
-from fastapi import FastAPI, Depends, HTTPException
-from sqlalchemy.orm import Session
-
-from memory.common.db import get_scoped_session
-from memory.common.db.models import SourceItem
+import base64
+import io
+from collections import defaultdict
+import pathlib
+from typing import Annotated, List, Optional, Callable
+from fastapi import FastAPI, File, UploadFile, Query, HTTPException, Form
+from fastapi.responses import FileResponse
+import qdrant_client
+from qdrant_client.http import models as qdrant_models
+from PIL import Image
+from pydantic import BaseModel
+
+from memory.common import embedding, qdrant, extract, settings
+from memory.common.embedding import get_modality
+from memory.common.db.connection import make_session
+from memory.common.db.models import Chunk, SourceItem
+
+import logging
+
+logger = logging.getLogger(__name__)
+
 app = FastAPI(title="Knowledge Base API")


-def get_db():
-    """Database session dependency"""
-    db = get_scoped_session()
-    try:
-        yield db
-    finally:
-        db.close()
+class AnnotatedChunk(BaseModel):
+    id: str
+    score: float
+    metadata: dict
+    preview: Optional[str | None] = None
+
+
+class SearchResponse(BaseModel):
+    collection: str
+    results: List[dict]
+
+
+class SearchResult(BaseModel):
+    id: int
+    size: int
+    mime_type: str
+    chunks: list[AnnotatedChunk]
+    content: Optional[str] = None
+    filename: Optional[str] = None
+
+
 @app.get("/health")
@@ -26,25 +53,195 @@ def health_check():
     return {"status": "healthy"}


-@app.get("/sources")
-def list_sources(
-    tag: str = None,
-    limit: int = 100,
-    db: Session = Depends(get_db)
+def annotated_chunk(
+    chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
+) -> tuple[SourceItem, AnnotatedChunk]:
+    def serialize_item(item: bytes | str | Image.Image) -> str | None:
+        if not previews and not isinstance(item, str):
+            return None
+
+        if isinstance(item, Image.Image):
+            buffer = io.BytesIO()
+            format = item.format or "PNG"
+            item.save(buffer, format=format)
+            mime_type = f"image/{format.lower()}"
+            return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
+        elif isinstance(item, bytes):
+            return base64.b64encode(item).decode("utf-8")
+        elif isinstance(item, str):
+            return item
+        else:
+            raise ValueError(f"Unsupported item type: {type(item)}")
+
+    metadata = search_result.payload or {}
+    metadata = {
+        k: v
+        for k, v in metadata.items()
+        if k not in ["content", "filename", "size", "content_type", "tags"]
+    }
+    return chunk.source, AnnotatedChunk(
+        id=str(chunk.id),
+        score=search_result.score,
+        metadata=metadata,
+        preview=serialize_item(chunk.data[0]) if chunk.data else None,
+    )
+
+
+def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]:
+    items = defaultdict(list)
+    for source, chunk in chunks:
+        items[source].append(chunk)
+
+    return [
+        SearchResult(
+            id=source.id,
+            size=source.size,
+            mime_type=source.mime_type,
+            filename=source.filename
+            and source.filename.replace(
+                str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
+            ),
+            content=source.content,
+            chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
+        )
+        for source, chunks in items.items()
+    ]
+
+
+def query_chunks(
+    client: qdrant_client.QdrantClient,
+    upload_data: list[tuple[str, list[extract.Page]]],
+    allowed_modalities: set[str],
+    embedder: Callable,
+    min_score: float = 0.0,
+    limit: int = 10,
+) -> dict[str, list[qdrant_models.ScoredPoint]]:
+    if not upload_data:
+        return {}
+
+    chunks = [
+        chunk
+        for content_type, pages in upload_data
+        if get_modality(content_type) in allowed_modalities
+        for page in pages
+        for chunk in page["contents"]
+    ]
+
+    if not chunks:
+        return {}
+
+    vector = embedder(chunks, input_type="query")[0]
+
+    return {
+        collection: [
+            r
+            for r in qdrant.search_vectors(
+                client=client,
+                collection_name=collection,
+                query_vector=vector,
+                limit=limit,
+            )
+            if r.score >= min_score
+        ]
+        for collection in embedding.DEFAULT_COLLECTIONS
+    }
+
+
+async def input_type(item: str | UploadFile) -> tuple[str, list[extract.Page]]:
+    if not item:
+        return "text/plain", []
+
+    if isinstance(item, str):
+        return "text/plain", extract.extract_text(item)
+    content_type = item.content_type or "application/octet-stream"
+    return content_type, extract.extract_content(content_type, await item.read())
+
+
+@app.post("/search", response_model=list[SearchResult])
+async def search(
+    query: Optional[str] = Form(None),
+    previews: Optional[bool] = Form(False),
+    modalities: Annotated[list[str], Query()] = [],
+    files: list[UploadFile] = File([]),
+    limit: int = Query(10, ge=1, le=100),
+    min_text_score: float = Query(0.5, ge=0.0, le=1.0),
+    min_multimodal_score: float = Query(0.3, ge=0.0, le=1.0),
 ):
-    """List source items, optionally filtered by tag"""
-    query = db.query(SourceItem)
-
-    if tag:
-        query = query.filter(SourceItem.tags.contains([tag]))
-
-    return query.limit(limit).all()
+    """
+    Search across knowledge base using text query and optional files.
+
+    Parameters:
+    - query: Optional text search query
+    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
+    - files: Optional files to include in the search context
+    - limit: Maximum number of results per modality
+
+    Returns:
+    - List of search results sorted by score
+    """
+    upload_data = [await input_type(item) for item in [query, *files]]
+    logger.error(
+        f"Querying chunks for {modalities}, query: {query}, previews: {previews}"
+    )
+
+    client = qdrant.get_qdrant_client()
+    allowed_modalities = set(modalities or embedding.DEFAULT_COLLECTIONS.keys())
+    text_results = query_chunks(
+        client,
+        upload_data,
+        allowed_modalities & embedding.TEXT_COLLECTIONS,
+        embedding.embed_text,
+        min_score=min_text_score,
+        limit=limit,
+    )
+    multimodal_results = query_chunks(
+        client,
+        upload_data,
+        allowed_modalities,
+        embedding.embed_mixed,
+        min_score=min_multimodal_score,
+        limit=limit,
+    )
+    search_results = {
+        k: text_results.get(k, []) + multimodal_results.get(k, [])
+        for k in allowed_modalities
+    }
+
+    found_chunks = {
+        str(r.id): r for results in search_results.values() for r in results
+    }
+    with make_session() as db:
+        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
+
+        results = group_chunks(
+            [
+                annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
+                for chunk in chunks
+            ]
+        )
+        return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
+
+
-@app.get("/sources/{source_id}")
-def get_source(source_id: int, db: Session = Depends(get_db)):
-    """Get a specific source by ID"""
-    source = db.query(SourceItem).filter(SourceItem.id == source_id).first()
-    if not source:
-        raise HTTPException(status_code=404, detail="Source not found")
-    return source
+@app.get("/files/{path:path}")
+def get_file_by_path(path: str):
+    """
+    Fetch a file by its path
+
+    Parameters:
+    - path: Path of the file to fetch (relative to FILE_STORAGE_DIR)
+
+    Returns:
+    - The file as a download
+    """
+    # Sanitize the path to prevent directory traversal
+    sanitized_path = path.lstrip("/")
+    if ".." in sanitized_path:
+        raise HTTPException(status_code=400, detail="Invalid path")
+
+    file_path = pathlib.Path(settings.FILE_STORAGE_DIR) / sanitized_path
+
+    # Check if the file exists on disk
+    if not file_path.exists() or not file_path.is_file():
+        raise HTTPException(status_code=404, detail=f"File not found at path: {path}")
+
+    return FileResponse(path=file_path, filename=file_path.name)
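For orientation, here is a hedged client-side sketch of calling the /search endpoint defined above. It assumes the API is reachable at http://localhost:8000; the query text and the sample upload are placeholders, not part of the commit.

# Hedged sketch: query /search with a text query, one modality and a small file.
import asyncio
import httpx


async def demo_search() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.post(
            "/search",
            params={"modalities": ["text"], "limit": 5},
            data={"query": "quarterly report", "previews": "true"},
            files=[("files", ("notes.txt", b"some plain-text notes", "text/plain"))],
        )
        resp.raise_for_status()
        for result in resp.json():
            print(result["id"], result["mime_type"], len(result["chunks"]))


asyncio.run(demo_search())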
@@ -113,7 +113,7 @@ class Chunk(Base):
         if self.file_path is None:
             return [self.content]

-        path = pathlib.Path(self.file_path)
+        path = pathlib.Path(self.file_path.replace("/app/", ""))
         if self.file_path.endswith("*"):
             files = list(path.parent.glob(path.name))
         else:
@@ -122,6 +122,7 @@ class Chunk(Base):
         items = []
         for file_path in files:
             if file_path.suffix == ".png":
+                if file_path.exists():
                     items.append(Image.open(file_path))
             elif file_path.suffix == ".bin":
                 items.append(file_path.read_bytes())
@@ -1,7 +1,7 @@
 import logging
 import pathlib
 import uuid
-from typing import Any, Iterable, Literal, TypedDict
+from typing import Any, Iterable, Literal, NotRequired, TypedDict

 import voyageai
 from PIL import Image
@@ -27,20 +27,63 @@ Embedding = tuple[str, Vector, dict[str, Any]]
 class Collection(TypedDict):
     dimension: int
     distance: DistanceType
-    on_disk: bool
-    shards: int
+    model: str
+    on_disk: NotRequired[bool]
+    shards: NotRequired[int]


 DEFAULT_COLLECTIONS: dict[str, Collection] = {
-    "mail": {"dimension": 1024, "distance": "Cosine"},
-    "chat": {"dimension": 1024, "distance": "Cosine"},
-    "git": {"dimension": 1024, "distance": "Cosine"},
-    "book": {"dimension": 1024, "distance": "Cosine"},
-    "blog": {"dimension": 1024, "distance": "Cosine"},
-    "text": {"dimension": 1024, "distance": "Cosine"},
+    "mail": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "chat": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "git": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "book": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "blog": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "text": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
     # Multimodal
-    "photo": {"dimension": 1024, "distance": "Cosine"},
-    "doc": {"dimension": 1024, "distance": "Cosine"},
+    "photo": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.MIXED_EMBEDDING_MODEL,
+    },
+    "doc": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.MIXED_EMBEDDING_MODEL,
+    },
+}
+TEXT_COLLECTIONS = {
+    coll
+    for coll, params in DEFAULT_COLLECTIONS.items()
+    if params["model"] == settings.TEXT_EMBEDDING_MODEL
+}
+MULTIMODAL_COLLECTIONS = {
+    coll
+    for coll, params in DEFAULT_COLLECTIONS.items()
+    if params["model"] == settings.MIXED_EMBEDDING_MODEL
 }

 TYPES = {
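Since each collection now records the model that serves it, the TEXT_COLLECTIONS / MULTIMODAL_COLLECTIONS split determines which embedder handles a given collection. A small illustrative helper (not part of the commit) showing that lookup:

# Illustrative only: map a collection name to the embedding function that
# should serve it, mirroring how TEXT_COLLECTIONS is used by the /search endpoint.
from memory.common import embedding, settings


def embedder_for(collection: str):
    model = embedding.DEFAULT_COLLECTIONS[collection]["model"]
    if model == settings.TEXT_EMBEDDING_MODEL:
        return embedding.embed_text  # "mail", "chat", "git", "book", "blog", "text"
    return embedding.embed_mixed  # "photo", "doc"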
@@ -69,22 +112,27 @@ def get_modality(mime_type: str) -> str:


 def embed_chunks(
-    chunks: list[extract.MulitmodalChunk], model: str = settings.TEXT_EMBEDDING_MODEL
+    chunks: list[extract.MulitmodalChunk],
+    model: str = settings.TEXT_EMBEDDING_MODEL,
+    input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
     vo = voyageai.Client()
     if model == settings.MIXED_EMBEDDING_MODEL:
         return vo.multimodal_embed(
-            chunks, model=model, input_type="document"
+            chunks, model=model, input_type=input_type
         ).embeddings
-    return vo.embed(chunks, model=model, input_type="document").embeddings
+    return vo.embed(chunks, model=model, input_type=input_type).embeddings


 def embed_text(
-    texts: list[str], model: str = settings.TEXT_EMBEDDING_MODEL
+    texts: list[str],
+    model: str = settings.TEXT_EMBEDDING_MODEL,
+    input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
     chunks = [
         c
         for text in texts
+        if isinstance(text, str)
         for c in chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS)
         if c.strip()
     ]
@@ -92,7 +140,7 @@ def embed_text(
         return []

     try:
-        return embed_chunks(chunks, model)
+        return embed_chunks(chunks, model, input_type)
     except voyageai.error.InvalidRequestError as e:
         logger.error(f"Error embedding text: {e}")
         logger.debug(f"Text: {texts}")
@@ -106,7 +154,9 @@ def embed_file(


 def embed_mixed(
-    items: list[extract.MulitmodalChunk], model: str = settings.MIXED_EMBEDDING_MODEL
+    items: list[extract.MulitmodalChunk],
+    model: str = settings.MIXED_EMBEDDING_MODEL,
+    input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
     def to_chunks(item: extract.MulitmodalChunk) -> Iterable[str]:
         if isinstance(item, str):
@@ -116,7 +166,7 @@ def embed_mixed(
             return [item]

     chunks = [c for item in items for c in to_chunks(item)]
-    return embed_chunks([chunks], model)
+    return embed_chunks([chunks], model, input_type)


 def embed_page(page: dict[str, Any]) -> list[Vector]:
@@ -7,6 +7,9 @@ import pymupdf  # PyMuPDF
 from PIL import Image
 from typing import Any, TypedDict, Generator, Sequence

+import logging
+
+logger = logging.getLogger(__name__)

 MulitmodalChunk = Image.Image | str

@@ -57,15 +60,32 @@ def docx_to_pdf(
     if output_path is None:
         output_path = docx_path.with_suffix(".pdf")

-    pypandoc.convert_file(str(docx_path), "pdf", outputfile=str(output_path))
+    # Now that we have all packages installed, try xelatex first as it has better Unicode support
+    try:
+        logger.info(f"Converting {docx_path} to PDF using xelatex")
+        pypandoc.convert_file(
+            str(docx_path),
+            format="docx",
+            to="pdf",
+            outputfile=str(output_path),
+            extra_args=[
+                "--pdf-engine=xelatex",
+                "--variable=geometry:margin=1in",
+                "--lua-filter=/app/unnest-table.lua",
+            ],
+        )
+        logger.info(f"Successfully converted {docx_path} to PDF")
         return output_path
+    except Exception as e:
+        logger.error(f"Error converting document to PDF: {e}")
+        raise


-def extract_docx(docx_path: pathlib.Path) -> list[Page]:
+def extract_docx(docx_path: pathlib.Path | bytes | str) -> list[Page]:
     """Extract content from DOCX by converting to PDF first, then processing"""
     with as_file(docx_path) as file_path:
         pdf_path = docx_to_pdf(file_path)
+        logger.info(f"Extracted PDF from {file_path}")
         return doc_to_images(pdf_path)


@@ -89,13 +109,17 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:


 def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]:
+    logger.info(f"Extracting content from {mime_type}")
     if mime_type == "application/pdf":
         return doc_to_images(content)
-    if isinstance(content, pathlib.Path) and mime_type in [
+    if mime_type in [
         "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
         "application/msword",
     ]:
-        return extract_docx(content)
+        logger.info(f"Extracting content from {content}")
+        pages = extract_docx(content)
+        logger.info(f"Extracted {len(pages)} pages from {content}")
+        return pages
     if mime_type.startswith("text/"):
         return extract_text(content)
     if mime_type.startswith("image/"):
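A hedged sketch of driving the extract_content dispatcher above for a Word document; the path is a placeholder and the call assumes pandoc plus the TeX Live packages from the Dockerfile are available:

# Illustrative only: run a local .docx through extract_content and count pages.
import pathlib
from memory.common import extract

docx = pathlib.Path("sample.docx")  # placeholder path
pages = extract.extract_content(
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    docx,
)
print(f"extracted {len(pages)} page(s)")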
@@ -25,6 +25,7 @@ class EmailMessage(TypedDict):
     sent_at: datetime | None
     body: str
     attachments: list[Attachment]
+    hash: bytes


 RawEmailResponse = tuple[Literal["OK", "ERROR"], bytes]
@@ -81,7 +82,7 @@ def extract_body(msg: email.message.Message) -> str:

     if not msg.is_multipart():
         try:
-            return msg.get_payload(decode=True).decode(errors='replace')
+            return msg.get_payload(decode=True).decode(errors="replace")
         except Exception as e:
             logger.error(f"Error decoding message body: {str(e)}")
             return ""
@@ -92,7 +93,7 @@ def extract_body(msg: email.message.Message) -> str:

         if content_type == "text/plain" and "attachment" not in content_disposition:
             try:
-                body += part.get_payload(decode=True).decode(errors='replace') + "\n"
+                body += part.get_payload(decode=True).decode(errors="replace") + "\n"
             except Exception as e:
                 logger.error(f"Error decoding message part: {str(e)}")
     return body
@@ -120,14 +121,18 @@ def extract_attachments(msg: email.message.Message) -> list[Attachment]:
         if filename := part.get_filename():
             try:
                 content = part.get_payload(decode=True)
-                attachments.append({
-                    "filename": filename,
-                    "content_type": part.get_content_type(),
-                    "size": len(content),
-                    "content": content
-                })
+                attachments.append(
+                    {
+                        "filename": filename,
+                        "content_type": part.get_content_type(),
+                        "size": len(content),
+                        "content": content,
+                    }
+                )
             except Exception as e:
-                logger.error(f"Error extracting attachment content for {filename}: {str(e)}")
+                logger.error(
+                    f"Error extracting attachment content for {filename}: {str(e)}"
+                )

     return attachments

@@ -173,5 +178,5 @@ def parse_email_message(raw_email: str, message_id: str) -> EmailMessage:
         sent_at=extract_date(msg),
         body=body,
         attachments=extract_attachments(msg),
-        hash=compute_message_hash(message_id, subject, from_, body)
+        hash=compute_message_hash(message_id, subject, from_, body),
     )
@@ -5,14 +5,21 @@ import qdrant_client
 from qdrant_client.http import models as qdrant_models
 from qdrant_client.http.exceptions import UnexpectedResponse
 from memory.common import settings
-from memory.common.embedding import Collection, DEFAULT_COLLECTIONS, DistanceType, Vector
+from memory.common.embedding import (
+    Collection,
+    DEFAULT_COLLECTIONS,
+    DistanceType,
+    Vector,
+)

 logger = logging.getLogger(__name__)


 def get_qdrant_client() -> qdrant_client.QdrantClient:
     """Create and return a Qdrant client using environment configuration."""
-    logger.info(f"Connecting to Qdrant at {settings.QDRANT_HOST}:{settings.QDRANT_PORT}")
+    logger.info(
+        f"Connecting to Qdrant at {settings.QDRANT_HOST}:{settings.QDRANT_PORT}"
+    )

     return qdrant_client.QdrantClient(
         host=settings.QDRANT_HOST,
@@ -72,7 +79,9 @@ def ensure_collection_exists(
     return True


-def initialize_collections(client: qdrant_client.QdrantClient, collections: dict[str, Collection] = None) -> None:
+def initialize_collections(
+    client: qdrant_client.QdrantClient, collections: dict[str, Collection] = None
+) -> None:
     """
     Initialize all required collections in Qdrant.

@@ -198,7 +207,9 @@ def delete_vectors(
     logger.debug(f"Deleted {len(ids)} vectors from {collection_name}")


-def get_collection_info(client: qdrant_client.QdrantClient, collection_name: str) -> dict:
+def get_collection_info(
+    client: qdrant_client.QdrantClient, collection_name: str
+) -> dict:
     """
     Get information about a collection.

@@ -52,7 +52,6 @@ QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))

 # Worker settings
 EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 3600))
-EMAIL_SYNC_INTERVAL = 60


 # Embedding settings
src/memory/mcp/server.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+import argparse
+import logging
+from typing import Any
+from fastapi import UploadFile
+import httpx
+from mcp.server.fastmcp import FastMCP
+
+SERVER = "http://localhost:8000"
+
+
+logger = logging.getLogger(__name__)
+mcp = FastMCP("memory")
+
+
+async def make_request(
+    path: str,
+    method: str,
+    data: dict | None = None,
+    json: dict | None = None,
+    files: list[UploadFile] | None = None,
+) -> httpx.Response:
+    async with httpx.AsyncClient() as client:
+        return await client.request(
+            method, f"{SERVER}/{path}", data=data, json=json, files=files
+        )
+
+
+async def post_data(path: str, data: dict | None = None) -> httpx.Response:
+    return await make_request(path, "POST", data=data)
+
+
+@mcp.tool()
+async def search(
+    query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
+) -> list[dict[str, Any]]:
+    logger.error(f"Searching for {query}")
+    resp = await post_data(
+        "search",
+        {
+            "query": query,
+            "previews": previews,
+            "modalities": modalities,
+            "limit": limit,
+        },
+    )
+
+    return resp.json()
+
+
+if __name__ == "__main__":
+    # Initialize and run the server
+    args = argparse.ArgumentParser()
+    args.add_argument("--server", type=str)
+    args = args.parse_args()
+
+    SERVER = args.server
+
+    mcp.run(transport=args.transport)
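A hedged sketch of exercising the MCP search tool directly during development; it assumes the FastAPI app is listening on localhost:8000 and that the module is importable as memory.mcp.server (inferred from the file path, not stated in the commit):

# Illustrative only: call the MCP "search" tool coroutine against a local API.
import asyncio

from memory.mcp import server  # import path inferred from src/memory/mcp/server.py


async def main() -> None:
    results = await server.search("quarterly report", previews=False, limit=5)
    for item in results:
        print(item.get("id"), item.get("mime_type"))


asyncio.run(main())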
@@ -124,13 +124,14 @@ def create_mail_message(
         folder=folder,
     )

-    if parsed_email["attachments"]:
-        mail_message.attachments = process_attachments(
-            parsed_email["attachments"], mail_message
-        )
-
     db_session.add(mail_message)
+
+    if parsed_email["attachments"]:
+        attachments = process_attachments(parsed_email["attachments"], mail_message)
+        db_session.add_all(attachments)
+        mail_message.attachments = attachments
+
+    db_session.add(mail_message)
     return mail_message


@@ -171,13 +172,11 @@ def check_message_exists(
     account = db.query(EmailAccount).get(account_id)
     if not account:
         logger.error(f"Account {account_id} not found")
-        return None
+        return False

     parsed_email = parse_email_message(raw_email, message_id)
-    # Use server-provided message ID if missing
-    if not parsed_email["message_id"]:
-        parsed_email["message_id"] = f"generated-{message_id}"
+    if "szczepalins" in raw_email.lower():
+        print(parsed_email["message_id"])

     return does_message_exist(db, parsed_email["message_id"], parsed_email["hash"])

@@ -58,6 +58,7 @@ def process_message(
         db, account.tags, folder, raw_email, message_id
     )

+    db.flush()
     vectorize_email(mail_message)
     db.commit()

@@ -65,12 +66,16 @@ def process_message(
     logger.info("Chunks:")
     for chunk in mail_message.chunks:
         logger.info(f" - {chunk.id}")
+    for attachment in mail_message.attachments:
+        logger.info(f" - Attachment {attachment.id}")
+        for chunk in attachment.chunks:
+            logger.info(f" - {chunk.id}")

     return mail_message.id


 @app.task(name=SYNC_ACCOUNT)
-def sync_account(account_id: int) -> dict:
+def sync_account(account_id: int, since_date: str | None = None) -> dict:
     """
     Synchronize emails from a specific account.

@@ -80,7 +85,7 @@ def sync_account(account_id: int) -> dict:
     Returns:
         dict with stats about the sync operation
     """
-    logger.info(f"Syncing account {account_id}")
+    logger.info(f"Syncing account {account_id} since {since_date}")

     with make_session() as db:
         account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
@@ -88,8 +93,11 @@ def sync_account(account_id: int) -> dict:
             logger.warning(f"Account {account_id} not found or inactive")
             return {"error": "Account not found or inactive"}

-        folders_to_process = account.folders or ["INBOX"]
-        since_date = account.last_sync_at or datetime(1970, 1, 1)
+        folders_to_process: list[str] = account.folders or ["INBOX"]
+        if since_date:
+            cutoff_date = datetime.fromisoformat(since_date)
+        else:
+            cutoff_date: datetime = account.last_sync_at or datetime(1970, 1, 1)

         messages_found = 0
         new_messages = 0
@@ -106,7 +114,7 @@ def sync_account(account_id: int) -> dict:
         with imap_connection(account) as conn:
             for folder in folders_to_process:
                 folder_stats = process_folder(
-                    conn, folder, account, since_date, process_message_wrapper
+                    conn, folder, account, cutoff_date, process_message_wrapper
                 )

                 messages_found += folder_stats["messages_found"]
@@ -121,6 +129,7 @@ def sync_account(account_id: int) -> dict:

     return {
         "account": account.email_address,
+        "since_date": cutoff_date.isoformat(),
         "folders_processed": len(folders_to_process),
         "messages_found": messages_found,
         "new_messages": new_messages,
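With the new since_date parameter, a sync window can be passed explicitly when the task is queued instead of relying on account.last_sync_at. A hedged sketch, assuming a configured Celery broker; the import path, account id and date are placeholders:

# Illustrative only: enqueue a sync from an explicit ISO-format cutoff.
from memory.workers.tasks.email import sync_account  # hypothetical import path

sync_account.delay(account_id=1, since_date="2025-01-01T00:00:00")

# Or call it synchronously, e.g. from a shell, and inspect the returned stats:
stats = sync_account(account_id=1, since_date="2025-01-01T00:00:00")
print(stats["since_date"], stats["new_messages"])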