mirror of https://github.com/mruwnik/memory.git (synced 2025-07-29 14:16:09 +02:00)

Comparing commits 0551ddd30c...f8090634c7 (7 commits):

- f8090634c7
- 7ac16031bb
- 3d9f8ae55f
- ac9bdb1dfc
- 69cf2844f9
- c7aa50347c
- f2c24cca3b
README.md (new file, 224 lines)

@@ -0,0 +1,224 @@
# Memory - Personal Knowledge Base

A personal knowledge base system that ingests, indexes, and provides semantic search over various content types, including emails, documents, notes, web pages, and more. Features MCP (Model Context Protocol) integration so that AI assistants can access and learn from your personal data.

## Features

- **Multi-modal Content Ingestion**: Process emails, documents, ebooks, comics, web pages, and more
- **Semantic Search**: Vector-based search across all your content with relevance scoring
- **MCP Integration**: Direct integration with AI assistants via the Model Context Protocol
- **Observation System**: AI assistants can record and search long-term observations about user preferences and patterns
- **Note Taking**: Create and organize markdown notes with full-text search
- **User Management**: Multi-user support with authentication
- **RESTful API**: Complete API for programmatic access
- **Real-time Processing**: Celery-based background processing for content ingestion

## Quick Start

### Prerequisites

- Docker and Docker Compose
- Python 3.11+ (for tools)

### 1. Start the Development Environment

```bash
# Clone the repository and navigate to it
git clone https://github.com/mruwnik/memory.git
cd memory

# Start the core services (PostgreSQL, RabbitMQ, Qdrant)
./dev.sh
```

This will:

- Start PostgreSQL (exposed on port 5432)
- Start RabbitMQ with its management interface
- Start the Qdrant vector database
- Initialize the database schema

It will also generate secrets in `secrets/` and create a basic `.env` file for you.

### 2. Start the Full Application

```bash
# Start all services, including the API and workers
docker-compose up -d

# Check that the services are healthy
docker-compose ps
```

The API will be available at `http://localhost:8000`.

There is also an admin interface at `http://localhost:8000/admin` where you can see what the database contains.

Because MCP can't yet handle basic auth, a separate authenticating proxy is provided for AI assistants (see [MCP Proxy Setup](#mcp-proxy-setup) below).

## User Management

### Create a User

```bash
# Install the package in development mode
pip install -e ".[all]"

# Create a new user
python tools/add_user.py --email user@example.com --password yourpassword --name "Your Name"
```

### Notes synchronisation

You can set up notes to be automatically pushed to a git repo whenever they are modified. Run the following job to do so:

```bash
python tools/run_celery_task.py notes setup-git-notes --origin ssh://git@github.com/some/repo.git --email bla@ble.com --name <user to send commits>
```

For this to work, you need to have set up the SSH keys in `secrets/` (see the README.md in that folder), and you will need to add the public key that is generated there to your git server.

### Authentication

The API uses session-based authentication. Log in via:

```bash
curl -X POST http://localhost:8000/auth/login \
  -H "Content-Type: application/json" \
  -d '{"email": "user@example.com", "password": "yourpassword"}'
```

This returns a session ID that should be included in subsequent requests as the `X-Session-ID` header.
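For example, you can verify the session by fetching the current user via the `/auth/me` endpoint:

```bash
# Verify the session by fetching the current user
curl http://localhost:8000/auth/me \
  -H "X-Session-ID: <session id from the login response>"
```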
## MCP Proxy Setup

Since MCP doesn't support basic authentication, use the included proxy for AI assistants that need to connect:

### Start the Proxy

```bash
python tools/simple_proxy.py \
  --remote-server http://localhost:8000 \
  --email user@example.com \
  --password yourpassword \
  --port 8080
```

### Configure Your AI Assistant

Point your MCP-compatible AI assistant to `http://localhost:8080` instead of the direct API endpoint. The proxy will:

- Handle authentication automatically
- Forward all requests to the main API
- Add the session header to each request

### Example MCP Configuration

For Claude Desktop or other MCP clients, add this to your configuration (port 8001 is where `docker-compose` runs the bundled proxy; use 8080 if you started the proxy manually as above):

```json
{
  "mcpServers": {
    "memory": {
      "type": "streamable-http",
      "url": "http://localhost:8001/mcp"
    }
  }
}
```

## Available MCP Tools

When connected via MCP, AI assistants have access to:

- `search_knowledge_base()` - Search your stored content
- `search_observations()` - Search recorded observations about you
- `observe()` - Record new observations about your preferences/behavior
- `create_note()` - Create and save notes
- `note_files()` - List existing notes
- `fetch_file()` - Read file contents
- `get_all_tags()` - Get all content tags
- `get_all_subjects()` - Get observation subjects

## Content Ingestion

### Via Workers

Content is processed asynchronously by Celery workers (a dispatch sketch follows the list). Supported formats include:

- PDFs, DOCX, TXT files
- Emails (mbox, EML formats)
- Web pages (HTML)
- Ebooks (EPUB, PDF)
- Images with OCR
- And more...
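Under the hood these are plain Celery tasks, so they can also be dispatched programmatically. A minimal sketch, mirroring how the MCP `create_note` tool enqueues work (any further `sync_note` parameters are elided here):

```python
# Sketch: enqueue a note-sync task the same way the MCP tools do.
# Queue names are prefixed with CELERY_QUEUE_PREFIX (default: "memory").
from memory.common import settings
from memory.common.celery_app import SYNC_NOTE, app as celery_app

celery_app.send_task(
    SYNC_NOTE,
    queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",  # e.g. "memory-notes"
    kwargs={"subject": "example", "content": "An example note."},
)
```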
## Development

### Environment Setup

```bash
# Install development dependencies
pip install -e ".[dev]"

# Run the tests
pytest

# Run with auto-reload
RELOAD=true python -m memory.api.app
```

### Architecture

- **FastAPI**: REST API and MCP server
- **PostgreSQL**: Primary database for metadata and users
- **Qdrant**: Vector database for semantic search
- **RabbitMQ**: Message queue for background processing
- **Celery**: Distributed task processing
- **SQLAdmin**: Admin interface for database management

### Configuration

Key environment variables:

- `FILE_STORAGE_DIR`: Where uploaded files are stored
- `DB_HOST`, `DB_PORT`: Database connection
- `QDRANT_HOST`: Vector database connection
- `RABBITMQ_HOST`: Message queue connection

See `docker-compose.yaml` for full configuration options.
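For local development, `dev.sh` writes a working `.env` for you; a hand-rolled one would look roughly like this (illustrative values, matching the ports `dev.sh` exposes):

```bash
# Illustrative .env for running tools/workers on the host against dev.sh's services
DB_HOST=localhost
DB_PORT=15432
QDRANT_HOST=localhost
RABBITMQ_HOST=localhost
RABBITMQ_PORT=15673
CELERY_BROKER_PASSWORD=<secret generated by dev.sh>
```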
## Security Notes

- Never expose the main API directly to the internet without proper authentication
- Use the proxy for MCP connections to handle authentication securely
- Store secrets in the `secrets/` directory (see `docker-compose.yaml`)
- The application runs with minimal privileges in Docker containers

## Troubleshooting

### Common Issues

1. **Database connection errors**: Ensure PostgreSQL is running and accessible
2. **Vector search not working**: Check that Qdrant is healthy
3. **Background processing stalled**: Verify that RabbitMQ and the Celery workers are running
4. **MCP connection issues**: Use the proxy instead of direct API access

### Logs

```bash
# View API logs
docker-compose logs -f api

# View worker logs
docker-compose logs -f worker

# View all logs
docker-compose logs -f
```

## Contributing

This is a personal knowledge base system. Feel free to fork and adapt for your own use cases.
db/migrations/versions/20250603_164859_add_user.py (new file, 49 lines)

@@ -0,0 +1,49 @@

"""Add user

Revision ID: 77cdbfc882e2
Revises: 152f8b4b52e8
Create Date: 2025-06-03 16:48:59.509683

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


revision: str = "77cdbfc882e2"
down_revision: Union[str, None] = "152f8b4b52e8"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        "users",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("email", sa.String(), nullable=False),
        sa.Column("password_hash", sa.String(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("email"),
    )
    op.create_table(
        "user_sessions",
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("user_id", sa.Integer(), nullable=False),
        sa.Column(
            "created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
        ),
        sa.Column("expires_at", sa.DateTime(), nullable=False),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["users.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )


def downgrade() -> None:
    op.drop_table("user_sessions")
    op.drop_table("users")
dev.sh

@@ -16,15 +16,48 @@ cd "$SCRIPT_DIR"

docker volume create memory_file_storage
docker run --rm -v memory_file_storage:/data busybox chown -R 1000:1000 /data

POSTGRES_PASSWORD=543218ZrHw8Pxbs3YXzaVHq8YKVHwCj6Pz8RQkl8
echo $POSTGRES_PASSWORD > secrets/postgres_password.txt

# Create a temporary docker-compose override file to expose PostgreSQL
echo -e "${YELLOW}Creating docker-compose override to expose PostgreSQL...${NC}"
if [ ! -f docker-compose.override.yml ]; then
    cat > docker-compose.override.yml << EOL
version: "3.9"
services:
  qdrant:
    ports:
      - "6333:6333"

  postgres:
    ports:
      - "5432:5432"
      # PostgreSQL port for local Celery result backend
      - "15432:5432"

  rabbitmq:
    ports:
      # UI only on localhost
      - "15672:15672"
      # AMQP port for local Celery clients (for local workers)
      - "15673:5672"
EOL
fi

if [ ! -f .env ]; then
    echo "POSTGRES_PASSWORD=$POSTGRES_PASSWORD" > .env
    cat >> .env << EOL
CELERY_BROKER_PASSWORD=543218ZrHw8Pxbs3YXzaVHq8YKVHwCj6Pz8RQkl8

RABBITMQ_HOST=localhost
QDRANT_HOST=localhost
DB_HOST=localhost

VOYAGE_API_KEY=
ANTHROPIC_API_KEY=
OPENAI_API_KEY=

DB_PORT=15432
RABBITMQ_PORT=15673
EOL
fi
docker-compose.yaml

@@ -9,9 +9,11 @@ networks:

# --------------------------------------------------------------------- secrets
secrets:
  postgres_password: { file: ./secrets/postgres_password.txt }
  jwt_secret: { file: ./secrets/jwt_secret.txt }
  openai_key: { file: ./secrets/openai_key.txt }
  anthropic_key: { file: ./secrets/anthropic_key.txt }
  ssh_private_key: { file: ./secrets/ssh_private_key }
  ssh_public_key: { file: ./secrets/ssh_public_key }
  ssh_known_hosts: { file: ./secrets/ssh_known_hosts }

# --------------------------------------------------------------------- volumes
volumes:

@@ -23,6 +25,7 @@ volumes:

x-common-env: &env
  RABBITMQ_USER: kb
  RABBITMQ_HOST: rabbitmq
  CELERY_BROKER_PASSWORD: ${CELERY_BROKER_PASSWORD}
  QDRANT_HOST: qdrant
  DB_HOST: postgres
  DB_PORT: 5432

@@ -46,9 +49,12 @@ x-worker-base: &worker-base

  QDRANT_URL: http://qdrant:6333
  OPENAI_API_KEY_FILE: /run/secrets/openai_key
  ANTHROPIC_API_KEY_FILE: /run/secrets/anthropic_key
  secrets: [ postgres_password, openai_key, anthropic_key ]
  secrets: [ postgres_password, openai_key, anthropic_key, ssh_private_key, ssh_public_key, ssh_known_hosts ]
  read_only: true
  tmpfs: [ /tmp, /var/tmp ]
  tmpfs:
    - /tmp
    - /var/tmp
    - /home/kb/.ssh:uid=1000,gid=1000,mode=700
  cap_drop: [ ALL ]
  volumes:
    - ./memory_files:/app/memory_files:rw

@@ -77,10 +83,23 @@ services:

      interval: 10s
      timeout: 5s
      retries: 5
    mem_limit: 4g
    cpus: "1.5"
    security_opt: [ "no-new-privileges=true" ]

  migrate:
    build:
      context: .
      dockerfile: docker/migrations.Dockerfile
    networks: [kbnet]
    depends_on:
      postgres:
        condition: service_healthy
    environment:
      <<: *env
      POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
    secrets: [postgres_password]
    volumes:
      - ./db:/app/db:ro

  rabbitmq:
    image: rabbitmq:3.13-management
    restart: unless-stopped

@@ -88,7 +107,7 @@ services:

    environment:
      <<: *env
      RABBITMQ_DEFAULT_USER: "kb"
      RABBITMQ_DEFAULT_PASS: "${RABBITMQ_PASSWORD}"
      RABBITMQ_DEFAULT_PASS: "${CELERY_BROKER_PASSWORD}"
    volumes:
      - rabbitmq_data:/var/lib/rabbitmq:rw
    healthcheck:

@@ -96,8 +115,6 @@ services:

      interval: 15s
      timeout: 5s
      retries: 5
    mem_limit: 512m
    cpus: "0.5"
    security_opt: [ "no-new-privileges=true" ]

  qdrant:

@@ -115,118 +132,52 @@ services:

      interval: 15s
      timeout: 5s
      retries: 5
    mem_limit: 4g
    cpus: "2"
    security_opt: [ "no-new-privileges=true" ]
    cap_drop: [ ALL ]

  # ------------------------------------------------------------ API / gateway
  # api:
  #   build:
  #     context: .
  #     dockerfile: docker/api/Dockerfile
  #   restart: unless-stopped
  #   networks: [kbnet]
  #   depends_on: [postgres, rabbitmq, qdrant]
  #   environment:
  #     <<: *env
  #     JWT_SECRET_FILE: /run/secrets/jwt_secret
  #     OPENAI_API_KEY_FILE: /run/secrets/openai_key
  #     POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
  #     QDRANT_URL: http://qdrant:6333
  #   secrets: [jwt_secret, openai_key, postgres_password]
  #   healthcheck:
  #     test: ["CMD-SHELL", "curl -fs http://localhost:8000/health || exit 1"]
  #     interval: 15s
  #     timeout: 5s
  #     retries: 5
  #   mem_limit: 768m
  #   cpus: "1"
  #   labels:
  #     - "traefik.enable=true"
  #     - "traefik.http.routers.kb.rule=Host(`${TRAEFIK_DOMAIN}`)"
  #     - "traefik.http.routers.kb.entrypoints=websecure"
  #     - "traefik.http.services.kb.loadbalancer.server.port=8000"

  traefik:
    image: traefik:v3.0
  api:
    build:
      context: .
      dockerfile: docker/api/Dockerfile
    restart: unless-stopped
    networks: [ kbnet ]
    command:
      - "--providers.docker=true"
      - "--providers.docker.network=kbnet"
      - "--entrypoints.web.address=:80"
      - "--entrypoints.websecure.address=:443"
      # - "--certificatesresolvers.le.acme.httpchallenge=true"
      # - "--certificatesresolvers.le.acme.httpchallenge.entrypoint=web"
      # - "--certificatesresolvers.le.acme.email=${LE_EMAIL}"
      # - "--certificatesresolvers.le.acme.storage=/acme.json"
      - "--log.level=INFO"
    ports:
      - "80:80"
      - "443:443"
    networks: [kbnet]
    depends_on: [postgres, rabbitmq, qdrant]
    environment:
      <<: *env
      POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
      QDRANT_URL: http://qdrant:6333
    secrets: [postgres_password]
    volumes:
      - /var/run/docker.sock:/var/run/docker.sock:ro
      # - ./acme.json:/acme.json:rw
      - ./memory_files:/app/memory_files:rw
    healthcheck:
      test: ["CMD-SHELL", "curl -fs http://localhost:8000/health || exit 1"]
      interval: 15s
      timeout: 5s
      retries: 5
    ports:
      - "8000:8000"

  # ------------------------------------------------------------ Celery workers
  worker-email:
  proxy:
    build:
      context: .
      dockerfile: docker/api/Dockerfile
    restart: unless-stopped
    networks: [kbnet]
    environment:
      <<: *env
    command: ["python", "/app/tools/simple_proxy.py", "--remote-server", "${PROXY_REMOTE_SERVER:-http://api:8000}", "--email", "${PROXY_EMAIL}", "--password", "${PROXY_PASSWORD}", "--port", "8001"]
    volumes:
      - ./tools:/app/tools:ro
    ports:
      - "8001:8001"
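    # NOTE: PROXY_EMAIL and PROXY_PASSWORD are read from the host environment
    # (or .env); simple_proxy.py presumably uses them to log in to the API on startup.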
  # ------------------------------------------------------------ Celery workers
  worker:
    <<: *worker-base
    environment:
      <<: *worker-env
      QUEUES: "email"
      # deploy: { resources: { limits: { cpus: "2", memory: 3g } } }

  worker-text:
    <<: *worker-base
    environment:
      <<: *worker-env
      QUEUES: "medium_embed"

  worker-ebook:
    <<: *worker-base
    environment:
      <<: *worker-env
      QUEUES: "ebooks"

  worker-comic:
    <<: *worker-base
    environment:
      <<: *worker-env
      QUEUES: "comic"

  worker-blogs:
    <<: *worker-base
    environment:
      <<: *worker-env
      QUEUES: "blogs"

  worker-forums:
    <<: *worker-base
    environment:
      <<: *worker-env
      QUEUES: "forums"

  worker-photo:
    <<: *worker-base
    environment:
      <<: *worker-env
      QUEUES: "photo_embed,comic"
      # deploy: { resources: { limits: { cpus: "4", memory: 4g } } }

  worker-notes:
    <<: *worker-base
    environment:
      <<: *worker-env
      QUEUES: "notes"
      # deploy: { resources: { limits: { cpus: "4", memory: 4g } } }

  worker-maintenance:
    <<: *worker-base
    environment:
      <<: *worker-env
      QUEUES: "maintenance"
      # deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } }
      QUEUES: "email,ebooks,comic,blogs,forums,maintenance,notes"

  ingest-hub:
    <<: *worker-base

@@ -245,35 +196,9 @@ services:

    deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } }

  # ------------------------------------------------------------ watchtower (auto-update)
  watchtower:
    image: containrrr/watchtower
    restart: unless-stopped
    command: [ "--schedule", "0 0 4 * * *", "--cleanup" ]
    volumes: [ "/var/run/docker.sock:/var/run/docker.sock:ro" ]
    networks: [ kbnet ]

# ------------------------------------------------------------------- profiles: observability (opt-in)
# services:
#   prometheus:
#     image: prom/prometheus:v2.52
#     profiles: ["obs"]
#     networks: [kbnet]
#     volumes: [./observability/prometheus.yml:/etc/prometheus/prometheus.yml:ro]
#     restart: unless-stopped
#     ports: ["127.0.0.1:9090:9090"]

#   grafana:
#     image: grafana/grafana:10
#     profiles: ["obs"]
#     networks: [kbnet]
#     volumes: [./observability/grafana:/var/lib/grafana]
#     restart: unless-stopped
#     environment:
#       GF_SECURITY_ADMIN_USER: admin
#       GF_SECURITY_ADMIN_PASSWORD_FILE: /run/secrets/grafana_pw
#     secrets: [grafana_pw]
#     ports: ["127.0.0.1:3000:3000"]

# secrets:  # extra secret for Grafana, not needed otherwise
#   grafana_pw:
#     file: ./secrets/grafana_pw.txt
#   watchtower:
#     image: containrrr/watchtower
#     restart: unless-stopped
#     command: [ "--schedule", "0 0 4 * * *", "--cleanup" ]
#     volumes: [ "/var/run/docker.sock:/var/run/docker.sock:ro" ]
#     networks: [ kbnet ]
docker/api/Dockerfile

@@ -1,24 +1,39 @@

FROM python:3.10-slim
FROM python:3.11-slim

WORKDIR /app

# Copy requirements files and setup
COPY requirements-*.txt ./
COPY setup.py ./
COPY src/ ./src/
# Install build dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    python3-dev \
    && rm -rf /var/lib/apt/lists/*

# Install the package with API dependencies
# Copy requirements files and setup
COPY requirements ./requirements/
RUN mkdir src
COPY setup.py ./
# Do an initial install to get the dependencies cached
RUN pip install -e ".[api]"

# Install the package with common dependencies
COPY src/ ./src/
RUN pip install -e ".[api]"

# Run as non-root user
RUN useradd -m appuser
USER appuser
RUN mkdir -p /app/memory_files
ENV PYTHONPATH="/app"

# Create user and set permissions
RUN useradd -m kb
RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
    chown -R kb:kb /var/cache/fontconfig /home/kb/.cache/fontconfig /app

USER kb

# Set environment variables
ENV PORT=8000
ENV PYTHONPATH="/app"

EXPOSE 8000

# Run the API
CMD ["uvicorn", "memory.api.app:app", "--host", "0.0.0.0", "--port", "8000"]
@@ -5,17 +5,16 @@ WORKDIR /app

# Install dependencies
RUN apt-get update && apt-get install -y \
    libpq-dev gcc supervisor && \
    pip install -e ".[workers]" && \
    apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*

COPY requirements-*.txt ./
RUN pip install --no-cache-dir -r requirements-common.txt
RUN pip install --no-cache-dir -r requirements-parsers.txt
RUN pip install --no-cache-dir -r requirements-workers.txt

# Copy requirements files and setup
COPY requirements ./requirements/
COPY setup.py ./
RUN mkdir src
RUN pip install -e ".[common]"

COPY src/ ./src/
RUN pip install -e ".[common]"

# Create and copy entrypoint script
COPY docker/workers/entry.sh ./entry.sh
docker/migrations.Dockerfile (new file, 28 lines)

@@ -0,0 +1,28 @@

FROM python:3.11-slim

WORKDIR /app

# Copy requirements files and setup
COPY requirements ./requirements/
COPY setup.py ./
RUN mkdir src
RUN pip install -e ".[common]"

# Install the package with common dependencies
COPY src/ ./src/
RUN pip install -e ".[common]"

# Run as non-root user
RUN useradd -m appuser
RUN mkdir -p /app/memory_files
ENV PYTHONPATH="/app"

# Create user and set permissions
RUN useradd -m kb
RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
    chown -R kb:kb /var/cache/fontconfig /home/kb/.cache/fontconfig /app

USER kb

# Run the migrations
CMD ["alembic", "-c", "/app/db/migrations/alembic.ini", "upgrade", "head"]
@@ -3,7 +3,7 @@ FROM python:3.11-slim

WORKDIR /app

RUN apt-get update && apt-get install -y \
    libpq-dev gcc pandoc \
    libpq-dev gcc pandoc git openssh-client \
    texlive-xetex texlive-fonts-recommended texlive-plain-generic \
    texlive-lang-greek texlive-lang-cyrillic texlive-lang-european \
    texlive-luatex texlive-latex-extra texlive-latex-recommended \

@@ -13,13 +13,12 @@ RUN apt-get update && apt-get install -y \

    # libreoffice-writer \
    && apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*

COPY requirements-*.txt ./
RUN pip install --no-cache-dir -r requirements-common.txt
RUN pip install --no-cache-dir -r requirements-parsers.txt
RUN pip install --no-cache-dir -r requirements-workers.txt
COPY requirements ./requirements/
COPY setup.py ./
RUN mkdir src
RUN pip install -e ".[common]"

# Install Python dependencies
COPY setup.py ./
COPY src/ ./src/
RUN pip install -e ".[workers]"

@@ -32,12 +31,18 @@ RUN mkdir -p /app/memory_files

COPY docker/workers/unnest-table.lua ./unnest-table.lua

# Create user and set permissions
RUN useradd -m kb
RUN useradd -m -u 1000 kb
RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
    chown -R kb:kb /var/cache/fontconfig /home/kb/.cache/fontconfig /app

USER kb

# Git config will be set via environment variables
ENV GIT_USER_EMAIL=${GIT_USER_EMAIL:-me@some.domain}
ENV GIT_USER_NAME=${GIT_USER_NAME:-memory}
RUN git config --global user.email "${GIT_USER_EMAIL}" && \
    git config --global user.name "${GIT_USER_NAME}"

# Default queues to process
ENV QUEUES="ebooks,email,comic,blogs,forums,photo_embed,maintenance"
ENV PYTHONPATH="/app"
docker/workers/entry.sh

@@ -1,7 +1,22 @@

#!/usr/bin/env bash
set -euo pipefail

# SSH Setup for git operations
if [ -f /run/secrets/ssh_private_key ]; then
    echo "Setting up SSH keys for git operations..."
    mkdir -p ~/.ssh
    cp /run/secrets/ssh_private_key ~/.ssh/id_rsa
    cp /run/secrets/ssh_public_key ~/.ssh/id_rsa.pub
    cp /run/secrets/ssh_known_hosts ~/.ssh/known_hosts
    chmod 700 ~/.ssh
    chmod 600 ~/.ssh/id_rsa
    chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/known_hosts
    echo "SSH keys configured successfully"
fi

QUEUE_PREFIX=${QUEUE_PREFIX:-memory}
QUEUES=${QUEUES:-default}
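# Prefix each queue name with "${QUEUE_PREFIX}-", e.g. "email,notes" -> "memory-email,memory-notes"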
QUEUES=$(IFS=,; echo "${QUEUES}" | tr ',' '\n' | sed "s/^/${QUEUE_PREFIX}-/" | paste -sd, -)
CONCURRENCY=${CONCURRENCY:-2}
LOGLEVEL=${LOGLEVEL:-INFO}
@@ -1,5 +0,0 @@

PyMuPDF==1.25.5
ebooklib==0.18.0
beautifulsoup4==4.13.4
markdownify==0.13.1
pillow==10.4.0

@@ -1,5 +0,0 @@

openai==1.25.0
pillow==10.4.0
pypandoc==1.15.0
beautifulsoup4==4.13.4
feedparser==6.0.10

@@ -2,6 +2,6 @@ fastapi==0.112.2

uvicorn==0.29.0
python-jose==3.3.0
python-multipart==0.0.9
sqladmin
sqladmin==0.20.1
mcp==1.9.2
bm25s[full]==0.2.13

@@ -1,11 +1,12 @@

sqlalchemy==2.0.30
psycopg2-binary==2.9.9
pydantic==2.7.1
pydantic==2.7.2
alembic==1.13.1
dotenv==0.9.9
voyageai==0.3.2
qdrant-client==1.9.0
anthropic==0.18.1
openai==1.25.0
# Pin the httpx version, as newer versions break the anthropic client
httpx==0.27.0
celery==5.3.6
celery[sqs]==5.3.6
requirements/requirements-parsers.txt (new file, 7 lines)

@@ -0,0 +1,7 @@

PyMuPDF==1.25.5
EbookLib==0.18
beautifulsoup4==4.13.4
markdownify==0.13.1
pillow==10.3.0
pypandoc==1.15
feedparser==6.0.10
setup.py

@@ -4,7 +4,7 @@ from setuptools import setup, find_namespace_packages

def read_requirements(filename: str) -> list[str]:
    """Read requirements from file, ignoring comments and -r directives."""
    path = pathlib.Path(filename)
    path = pathlib.Path(__file__).parent / "requirements" / filename
    return [
        line.strip()
        for line in path.read_text().splitlines()

@@ -16,7 +16,6 @@ def read_requirements(filename: str) -> list[str]:

common_requires = read_requirements("requirements-common.txt")
parsers_requires = read_requirements("requirements-parsers.txt")
api_requires = read_requirements("requirements-api.txt")
workers_requires = read_requirements("requirements-workers.txt")
dev_requires = read_requirements("requirements-dev.txt")

setup(

@@ -26,14 +25,9 @@ setup(

    packages=find_namespace_packages(where="src"),
    python_requires=">=3.10",
    extras_require={
        "api": api_requires + common_requires,
        "workers": workers_requires + common_requires + parsers_requires,
        "common": common_requires,
        "api": api_requires + common_requires + parsers_requires,
        "common": common_requires + parsers_requires,
        "dev": dev_requires,
        "all": api_requires
        + workers_requires
        + common_requires
        + dev_requires
        + parsers_requires,
        "all": api_requires + common_requires + dev_requires + parsers_requires,
    },
)
src/memory/api/MCP/tools.py

@@ -5,6 +5,7 @@ MCP tools for the epistemic sparring partner system.

import logging
import pathlib
from datetime import datetime, timezone
import base64
import mimetypes

from mcp.server.fastmcp import FastMCP
from pydantic import BaseModel

@@ -182,6 +183,24 @@ async def observe(

    Use proactively when user expresses preferences, behaviors, beliefs, or contradictions.
    Be specific and detailed - observations should make sense months later.

    Example call:
    ```
    {
        "observations": [
            {
                "content": "The user is a software engineer.",
                "subject": "user",
                "observation_type": "belief",
                "confidences": {"observation_accuracy": 0.9},
                "evidence": {"quote": "I am a software engineer.", "context": "I work at Google."},
                "tags": ["programming", "work"]
            }
        ],
        "session_id": "123e4567-e89b-12d3-a456-426614174000",
        "agent_model": "gpt-4o"
    }
    ```

    RawObservation fields:
        content (required): Detailed observation text explaining what you observed
        subject (required): Consistent identifier like "programming_style", "work_habits"

@@ -200,7 +219,7 @@ async def observe(

            observation,
            celery_app.send_task(
                SYNC_OBSERVATION,
                queue="notes",
                queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
                kwargs={
                    "subject": observation.subject,
                    "content": observation.content,

@@ -323,7 +342,7 @@ async def create_note(

    try:
        task = celery_app.send_task(
            SYNC_NOTE,
            queue="notes",
            queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
            kwargs={
                "subject": subject,
                "content": content,

@@ -367,20 +386,67 @@ async def note_files(path: str = "/"):

@mcp.tool()
def fetch_file(filename: str):
def fetch_file(filename: str) -> dict:
    """
    Read file content from user's storage.
    Use when you need to access specific content of a file that's been referenced.

    Args:
        filename: Path to file (e.g., "/notes/project.md", "/documents/report.pdf")
            Path should start with "/" and use forward slashes

    Returns: Raw bytes content (decode as UTF-8 for text files)
    Raises FileNotFoundError if file doesn't exist.
    Read file content with automatic type detection.
    Returns dict with content, mime_type, is_text, file_size.
    Text content as string, binary as base64.
    """
    path = settings.FILE_STORAGE_DIR / filename.lstrip("/")
    if not path.exists():
        raise FileNotFoundError(f"File not found: {filename}")

    return path.read_bytes()
    mime_type, _ = mimetypes.guess_type(str(path))
    mime_type = mime_type or "application/octet-stream"

    text_extensions = {
        ".md",
        ".txt",
        ".py",
        ".js",
        ".html",
        ".css",
        ".json",
        ".xml",
        ".yaml",
        ".yml",
        ".toml",
        ".ini",
        ".cfg",
        ".conf",
    }
    text_mimes = {
        "application/json",
        "application/xml",
        "application/javascript",
        "application/x-yaml",
        "application/yaml",
    }
    is_text = (
        mime_type.startswith("text/")
        or mime_type in text_mimes
        or path.suffix.lower() in text_extensions
    )

    try:
        content = (
            path.read_text(encoding="utf-8")
            if is_text
            else base64.b64encode(path.read_bytes()).decode("ascii")
        )
    except UnicodeDecodeError:
        content = base64.b64encode(path.read_bytes()).decode("ascii")
        is_text = False
        mime_type = (
            "application/octet-stream" if mime_type.startswith("text/") else mime_type
        )

    return {
        "content": content,
        "mime_type": mime_type,
        "is_text": is_text,
        "file_size": path.stat().st_size,
        "filename": filename,
    }
src/memory/api/admin.py

@@ -20,6 +20,7 @@ from memory.common.db.models import (

    ForumPost,
    AgentObservation,
    Note,
    User,
)

@@ -207,6 +208,15 @@ class NoteAdmin(ModelView, model=Note):

    column_sortable_list = ["inserted_at"]


class UserAdmin(ModelView, model=User):
    column_list = [
        "id",
        "email",
        "name",
        "created_at",
    ]


def setup_admin(admin: Admin):
    """Add all admin views to the admin instance."""
    admin.add_view(SourceItemAdmin)

@@ -224,3 +234,4 @@ def setup_admin(admin: Admin):

    admin.add_view(ForumPostAdmin)
    admin.add_view(ComicAdmin)
    admin.add_view(PhotoAdmin)
    admin.add_view(UserAdmin)
src/memory/api/app.py

@@ -8,16 +8,30 @@ import pathlib

import logging
from typing import Annotated, Optional

from fastapi import FastAPI, HTTPException, File, UploadFile, Query, Form
from fastapi import (
    FastAPI,
    HTTPException,
    File,
    UploadFile,
    Query,
    Form,
    Depends,
)
from fastapi.responses import FileResponse
from sqladmin import Admin

from memory.common import settings
from memory.common import extract
from memory.common.db.connection import get_engine
from memory.common.db.models import User
from memory.api.admin import setup_admin
from memory.api.search import search, SearchResult
from memory.api.MCP.tools import mcp
from memory.api.auth import (
    router as auth_router,
    get_current_user,
    AuthenticationMiddleware,
)

logger = logging.getLogger(__name__)

@@ -35,6 +49,10 @@ app = FastAPI(title="Knowledge Base API", lifespan=lifespan)

engine = get_engine()
admin = Admin(app, engine)
setup_admin(admin)

# Include auth router
app.add_middleware(AuthenticationMiddleware)
app.include_router(auth_router)
app.mount("/", mcp.streamable_http_app())

@@ -63,6 +81,7 @@ async def search_endpoint(

    limit: int = Query(10, ge=1, le=100),
    min_text_score: float = Query(0.3, ge=0.0, le=1.0),
    min_multimodal_score: float = Query(0.3, ge=0.0, le=1.0),
    current_user: User = Depends(get_current_user),
):
    """Search endpoint - delegates to search module"""
    upload_data = [

@@ -74,7 +93,7 @@ async def search_endpoint(

    return await search(
        upload_data,
        previews=previews,
        modalities=modalities,
        modalities=set(modalities),
        limit=limit,
        min_text_score=min_text_score,
        min_multimodal_score=min_multimodal_score,

@@ -82,7 +101,7 @@ async def search_endpoint(

@app.get("/files/{path:path}")
def get_file_by_path(path: str):
def get_file_by_path(path: str, current_user: User = Depends(get_current_user)):
    """
    Fetch a file by its path
    """
src/memory/api/auth.py (new file, 228 lines)

@@ -0,0 +1,228 @@

from datetime import datetime, timedelta, timezone
import textwrap
from typing import cast
import logging

from fastapi import HTTPException, Depends, Request, Response, APIRouter, Form
from fastapi.responses import HTMLResponse
from starlette.middleware.base import BaseHTTPMiddleware
from memory.common import settings
from sqlalchemy.orm import Session as DBSession, scoped_session
from pydantic import BaseModel

from memory.common.db.connection import get_session, make_session
from memory.common.db.models.users import User, UserSession

logger = logging.getLogger(__name__)


# Pydantic models
class LoginRequest(BaseModel):
    email: str
    password: str


class RegisterRequest(BaseModel):
    email: str
    password: str
    name: str


class LoginResponse(BaseModel):
    session_id: str
    user_id: int
    email: str
    name: str


# Create router
router = APIRouter(prefix="/auth", tags=["auth"])


def create_user_session(
    user_id: int, db: DBSession, valid_for: int = settings.SESSION_VALID_FOR
) -> str:
    """Create a new session for a user"""
    expires_at = datetime.now(timezone.utc) + timedelta(days=valid_for)

    session = UserSession(user_id=user_id, expires_at=expires_at)
    db.add(session)
    db.commit()

    return str(session.id)


def get_session_user(session_id: str, db: DBSession | scoped_session) -> User | None:
    """Get user from session ID if session is valid"""
    session = db.query(UserSession).get(session_id)
    if not session:
        return None
    now = datetime.now(timezone.utc)
    if session.expires_at.replace(tzinfo=timezone.utc) > now:
        return session.user
    return None


def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> User:
    """FastAPI dependency to get current authenticated user"""
    # Check for session ID in header or cookie
    session_id = request.headers.get(
        settings.SESSION_HEADER_NAME
    ) or request.cookies.get(settings.SESSION_COOKIE_NAME)

    if not session_id:
        raise HTTPException(status_code=401, detail="No session provided")

    user = get_session_user(session_id, db)
    if not user:
        raise HTTPException(status_code=401, detail="Invalid or expired session")

    return user


def create_user(email: str, password: str, name: str, db: DBSession) -> User:
    """Create a new user"""
    # Check if user already exists
    existing_user = db.query(User).filter(User.email == email).first()
    if existing_user:
        raise HTTPException(status_code=400, detail="User already exists")

    user = User.create_with_password(email, name, password)
    db.add(user)
    db.commit()
    db.refresh(user)

    return user


def authenticate_user(email: str, password: str, db: DBSession) -> User | None:
    """Authenticate a user by email and password"""
    user = db.query(User).filter(User.email == email).first()
    if user and user.is_valid_password(password):
        return user
    return None


# Auth endpoints
@router.post("/register", response_model=LoginResponse)
def register(request: RegisterRequest, db: DBSession = Depends(get_session)):
    """Register a new user"""
    if not settings.REGISTER_ENABLED:
        raise HTTPException(status_code=403, detail="Registration is disabled")

    user = create_user(request.email, request.password, request.name, db)
    session_id = create_user_session(user.id, db)  # type: ignore

    return LoginResponse(session_id=session_id, **user.serialize())


@router.get("/login", response_model=LoginResponse)
def login_page():
    """Login page"""
    return HTMLResponse(
        content=textwrap.dedent("""
            <html>
                <body>
                    <h1>Login</h1>
                    <form method="post" action="/auth/login-form">
                        <input type="email" name="email" placeholder="Email" />
                        <input type="password" name="password" placeholder="Password" />
                        <button type="submit">Login</button>
                    </form>
                </body>
            </html>
        """),
    )


@router.post("/login", response_model=LoginResponse)
def login(
    request: LoginRequest, response: Response, db: DBSession = Depends(get_session)
):
    """Login and create a session"""
    return login_form(response, db, request.email, request.password)


@router.post("/login-form", response_model=LoginResponse)
def login_form(
    response: Response,
    db: DBSession = Depends(get_session),
    email: str = Form(),
    password: str = Form(),
):
    """Login with form data and create a session"""
    user = authenticate_user(email, password, db)
    if not user:
        raise HTTPException(status_code=401, detail="Invalid credentials")

    session_id = create_user_session(cast(int, user.id), db)

    # Set session cookie
    response.set_cookie(
        key=settings.SESSION_COOKIE_NAME,
        value=session_id,
        httponly=True,
        secure=settings.HTTPS,
        samesite="lax",
        max_age=settings.SESSION_COOKIE_MAX_AGE,
    )

    return LoginResponse(session_id=session_id, **user.serialize())


@router.post("/logout")
def logout(response: Response, user: User = Depends(get_current_user)):
    """Logout and clear session"""
    response.delete_cookie(settings.SESSION_COOKIE_NAME)
    return {"message": "Logged out successfully"}


@router.get("/me")
def get_me(user: User = Depends(get_current_user)):
    """Get current user info"""
    return user.serialize()


class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Middleware to require authentication for all endpoints except whitelisted ones."""

    # Endpoints that don't require authentication
    WHITELIST = {
        "/health",
        "/auth/login",
        "/auth/login-form",
        "/auth/register",
    }

    async def dispatch(self, request: Request, call_next):
        path = request.url.path

        # Skip authentication for whitelisted endpoints
        if any(path.startswith(whitelist_path) for whitelist_path in self.WHITELIST):
            return await call_next(request)

        # Check for session ID in header or cookie
        session_id = request.headers.get(
            settings.SESSION_HEADER_NAME
        ) or request.cookies.get(settings.SESSION_COOKIE_NAME)

        if not session_id:
            return Response(
                content="Authentication required",
                status_code=401,
                headers={"WWW-Authenticate": "Bearer"},
            )

        # Validate session and get user
        with make_session() as session:
            user = get_session_user(session_id, session)
            if not user:
                return Response(
                    content="Invalid or expired session",
                    status_code=401,
                    headers={"WWW-Authenticate": "Bearer"},
                )

        logger.debug(f"Authenticated request from user {user.email} to {path}")

        return await call_next(request)
src/memory/common/celery_app.py

@@ -1,4 +1,5 @@

from celery import Celery
from kombu.utils.url import safequote
from memory.common import settings

EMAIL_ROOT = "memory.workers.tasks.email"

@@ -13,6 +14,7 @@ OBSERVATIONS_ROOT = "memory.workers.tasks.observations"

SYNC_NOTES = f"{NOTES_ROOT}.sync_notes"
SYNC_NOTE = f"{NOTES_ROOT}.sync_note"
SETUP_GIT_NOTES = f"{NOTES_ROOT}.setup_git_notes"
SYNC_OBSERVATION = f"{OBSERVATIONS_ROOT}.sync_observation"
SYNC_ALL_COMICS = f"{COMIC_ROOT}.sync_all_comics"
SYNC_SMBC = f"{COMIC_ROOT}.sync_smbc"

@@ -41,13 +43,17 @@ SYNC_ALL_ARTICLE_FEEDS = f"{BLOGS_ROOT}.sync_all_article_feeds"

SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive"


def rabbit_url() -> str:
    return f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASSWORD}@{settings.RABBITMQ_HOST}:{settings.RABBITMQ_PORT}//"
def get_broker_url() -> str:
    protocol = settings.CELERY_BROKER_TYPE
    user = safequote(settings.CELERY_BROKER_USER)
    password = safequote(settings.CELERY_BROKER_PASSWORD)
    host = settings.CELERY_BROKER_HOST
    return f"{protocol}://{user}:{password}@{host}"


app = Celery(
    "memory",
    broker=rabbit_url(),
    broker=get_broker_url(),
    backend=settings.CELERY_RESULT_BACKEND,
)

@@ -59,20 +65,22 @@ app.conf.update(

    task_reject_on_worker_lost=True,
    worker_prefetch_multiplier=1,
    task_routes={
        f"{EMAIL_ROOT}.*": {"queue": "email"},
        f"{PHOTO_ROOT}.*": {"queue": "photo_embed"},
        f"{COMIC_ROOT}.*": {"queue": "comic"},
        f"{EBOOK_ROOT}.*": {"queue": "ebooks"},
        f"{BLOGS_ROOT}.*": {"queue": "blogs"},
        f"{FORUMS_ROOT}.*": {"queue": "forums"},
        f"{MAINTENANCE_ROOT}.*": {"queue": "maintenance"},
        f"{NOTES_ROOT}.*": {"queue": "notes"},
        f"{OBSERVATIONS_ROOT}.*": {"queue": "notes"},
        f"{EMAIL_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-email"},
        f"{PHOTO_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-photo-embed"},
        f"{COMIC_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-comic"},
        f"{EBOOK_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-ebooks"},
        f"{BLOGS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-blogs"},
        f"{FORUMS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-forums"},
        f"{MAINTENANCE_ROOT}.*": {
            "queue": f"{settings.CELERY_QUEUE_PREFIX}-maintenance"
        },
        f"{NOTES_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
        f"{OBSERVATIONS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
    },
)


@app.on_after_configure.connect  # type: ignore
@app.on_after_configure.connect  # type: ignore[attr-defined]
def ensure_qdrant_initialised(sender, **_):
    from memory.common import qdrant
src/memory/common/db/connection.py

@@ -3,8 +3,9 @@ Database connection utilities.

"""

from contextlib import contextmanager
from typing import Generator
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.orm import sessionmaker, scoped_session, Session

from memory.common import settings

@@ -23,9 +24,17 @@ def get_session_factory():

def get_scoped_session():
    """Create a thread-local scoped session factory"""
    engine = get_engine()
    session_factory = sessionmaker(bind=engine)
    return scoped_session(session_factory)
    return scoped_session(get_session_factory())


def get_session() -> Generator[Session, None, None]:
    """FastAPI dependency for database sessions"""
    session_factory = get_session_factory()
    session = session_factory()
    try:
        yield session
    finally:
        session.close()


@contextmanager
src/memory/common/db/models/__init__.py

@@ -32,6 +32,10 @@ from memory.common.db.models.sources import (

    ArticleFeed,
    EmailAccount,
)
from memory.common.db.models.users import (
    User,
    UserSession,
)

__all__ = [
    "Base",

@@ -62,4 +66,7 @@ __all__ = [

    "Book",
    "ArticleFeed",
    "EmailAccount",
    # Users
    "User",
    "UserSession",
]
@@ -269,7 +269,7 @@ class Comic(SourceItem):

        return {k: v for k, v in payload.items() if v is not None}

    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
        image = Image.open(pathlib.Path(cast(str, self.filename)))
        image = Image.open(settings.FILE_STORAGE_DIR / cast(str, self.filename))
        description = f"{self.title} by {self.author}"
        return [extract.DataChunk(data=[image, description])]
src/memory/common/db/models/users.py (new file, 65 lines)

@@ -0,0 +1,65 @@

import hashlib
import secrets
from typing import cast
import uuid
from memory.common.db.models.base import Base
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship


def hash_password(password: str) -> str:
    """Hash a password using SHA-256 with salt"""
    salt = secrets.token_hex(16)
    return f"{salt}:{hashlib.sha256((salt + password).encode()).hexdigest()}"


def verify_password(password: str, password_hash: str) -> bool:
    """Verify a password against its hash"""
    try:
        salt, hash_value = password_hash.split(":", 1)
        return hashlib.sha256((salt + password).encode()).hexdigest() == hash_value
    except ValueError:
        return False


class User(Base):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    email = Column(String, nullable=False, unique=True)
    password_hash = Column(String, nullable=False)

    # Relationship to sessions
    sessions = relationship(
        "UserSession", back_populates="user", cascade="all, delete-orphan"
    )

    def serialize(self) -> dict:
        return {
            "user_id": self.id,
            "name": self.name,
            "email": self.email,
        }

    def is_valid_password(self, password: str) -> bool:
        """Check if the provided password is valid for this user"""
        return verify_password(password, cast(str, self.password_hash))

    @classmethod
    def create_with_password(cls, email: str, name: str, password: str) -> "User":
        """Create a new user with a hashed password"""
        return cls(email=email, name=name, password_hash=hash_password(password))


class UserSession(Base):
    __tablename__ = "user_sessions"

    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    created_at = Column(DateTime, server_default=func.now())
    expires_at = Column(DateTime, nullable=False)

    # Relationship to user
    user = relationship("User", back_populates="sessions")
src/memory/common/settings.py

@@ -29,11 +29,18 @@ def make_db_url(

DB_URL = os.getenv("DATABASE_URL", make_db_url())

# Celery settings
RABBITMQ_USER = os.getenv("RABBITMQ_USER", "kb")
RABBITMQ_PASSWORD = os.getenv("RABBITMQ_PASSWORD", "kb")
RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672")

# Broker settings
CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory")
CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "amqp").lower()  # amqp or sqs
CELERY_BROKER_USER = os.getenv("CELERY_BROKER_USER", "kb")
CELERY_BROKER_PASSWORD = os.getenv("CELERY_BROKER_PASSWORD", "kb")

CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "")
if not CELERY_BROKER_HOST and CELERY_BROKER_TYPE == "amqp":
    RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
    RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672")
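    # The trailing "//" selects RabbitMQ's default vhost ("/") in the AMQP URL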
    CELERY_BROKER_HOST = f"{RABBITMQ_HOST}:{RABBITMQ_PORT}//"

CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")

@@ -122,3 +129,12 @@ if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):

else:
    ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")

# API settings
HTTPS = boolean_env("HTTPS", False)
SESSION_HEADER_NAME = os.getenv("SESSION_HEADER_NAME", "X-Session-ID")
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")
SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 * 60))
SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))
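# NOTE: the "or True" below makes registration always enabled, regardless of the env var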
REGISTER_ENABLED = boolean_env("REGISTER_ENABLED", False) or True
@@ -1,8 +1,9 @@

import json
import logging
import traceback
from typing import Any

from bs4 import BeautifulSoup

from memory.common import settings, chunker

logger = logging.getLogger(__name__)

@@ -12,11 +13,13 @@ The following text is already concise. Please identify 3-5 relevant tags that ca

Tags should be lowercase and use hyphens instead of spaces, e.g. "machine-learning" instead of "Machine Learning".

Return your response as JSON with this format:
{{
    "summary": "{summary}",
    "tags": ["tag1", "tag2", "tag3"]
}}
Return your response as XML with this format:
<summary>{summary}</summary>
<tags>
    <tag>tag1</tag>
    <tag>tag2</tag>
    <tag>tag3</tag>
</tags>

Text:
{content}

@@ -28,17 +31,28 @@ Also provide 3-5 relevant tags that capture the main topics or themes.

Tags should be lowercase and use hyphens instead of spaces, e.g. "machine-learning" instead of "Machine Learning".

Return your response as JSON with this format:
{{
    "summary": "your summary here",
    "tags": ["tag1", "tag2", "tag3"]
}}
Return your response as XML with this format:

<summary>your summary here</summary>
<tags>
    <tag>tag1</tag>
    <tag>tag2</tag>
    <tag>tag3</tag>
</tags>

Text to summarize:
{content}
"""


def parse_response(response: str) -> dict[str, Any]:
    """Parse the response from the summarizer."""
    soup = BeautifulSoup(response, "xml")
    summary = soup.find("summary").text
    tags = [tag.text for tag in soup.find_all("tag")]
    return {"summary": summary, "tags": tags}


def _call_openai(prompt: str) -> dict[str, Any]:
    """Call OpenAI API for summarization."""
    import openai

@@ -58,7 +72,7 @@ def _call_openai(prompt: str) -> dict[str, Any]:

            temperature=0.3,
            max_tokens=2048,
        )
        return json.loads(response.choices[0].message.content or "{}")
        return parse_response(response.choices[0].message.content or "")
    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        raise

@@ -73,13 +87,14 @@ def _call_anthropic(prompt: str) -> dict[str, Any]:

        response = client.messages.create(
            model=settings.SUMMARIZER_MODEL.split("/")[1],
            messages=[{"role": "user", "content": prompt}],
            system="You are a helpful assistant that creates concise summaries and identifies key topics. Always respond with valid JSON.",
            system="You are a helpful assistant that creates concise summaries and identifies key topics. Always respond with valid XML.",
            temperature=0.3,
            max_tokens=2048,
        )
        return json.loads(response.content[0].text)
        return parse_response(response.content[0].text)
    except Exception as e:
        logger.error(f"Anthropic API error: {e}")
        logger.error(response.content[0].text)
        raise
@@ -294,4 +294,5 @@ feeds = [

    "https://www.theredhandfiles.com/",
    "https://karlin.blog/",
    "https://slatestarcodex.com/",
    "https://www.astralcodexten.com/",
]
```diff
@@ -75,6 +75,7 @@ def sync_comic(
     published_date: datetime | None = None,
 ):
     """Synchronize a comic from a URL."""
+    logger.info(f"syncing comic {url}")
     with make_session() as session:
         existing_comic = check_content_exists(session, Comic, url=url)
         if existing_comic:
@@ -101,7 +102,7 @@ def sync_comic(
         url=url,
         published=published_date,
         author=author,
-        filename=filename.resolve().as_posix(),
+        filename=filename.resolve().relative_to(settings.FILE_STORAGE_DIR).as_posix(),
         mime_type=mime_type,
         size=len(response.content),
         sha256=create_content_hash(f"{image_url}{published_date}"),
```
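The `filename` change above stores comic paths relative to the storage root rather than as absolute paths, so database rows stay valid when the storage directory is mounted somewhere else. A minimal sketch of the difference (the paths are hypothetical stand-ins):

```python
import pathlib

storage = pathlib.Path("/app/memory_files")          # stand-in for settings.FILE_STORAGE_DIR
filename = storage / "comics" / "smbc" / "page.png"

# Before: absolute path, breaks if the volume is mounted elsewhere
print(filename.as_posix())                           # /app/memory_files/comics/smbc/page.png

# After: relative to the storage root, portable across hosts/containers
print(filename.relative_to(storage).as_posix())      # comics/smbc/page.png
```

`Path.relative_to` raises `ValueError` if the file does not live under the storage root, so a file written outside `FILE_STORAGE_DIR` fails loudly instead of being recorded with a bogus path.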
```diff
@@ -1,9 +1,13 @@
 import logging
+import pathlib
+import contextlib
+import subprocess
+import shlex

 from memory.common import settings
 from memory.common.db.connection import make_session
 from memory.common.db.models import Note
-from memory.common.celery_app import app, SYNC_NOTE, SYNC_NOTES
+from memory.common.celery_app import app, SYNC_NOTE, SYNC_NOTES, SETUP_GIT_NOTES
 from memory.workers.tasks.content_processing import (
     check_content_exists,
     create_content_hash,
@@ -15,6 +19,41 @@ from memory.workers.tasks.content_processing import (
 logger = logging.getLogger(__name__)


+def git_command(repo_root: pathlib.Path, *args: str, force: bool = False):
+    if not (repo_root / ".git").exists() and not force:
+        return
+
+    # Properly escape arguments for shell execution
+    escaped_args = [shlex.quote(arg) for arg in args]
+    cmd = f"git -C {shlex.quote(repo_root.as_posix())} {' '.join(escaped_args)}"
+
+    res = subprocess.run(
+        cmd,
+        shell=True,
+        text=True,
+        capture_output=True,  # Capture both stdout and stderr
+    )
+    if res.returncode != 0:
+        logger.error(f"Git command failed: {res.returncode}")
+        logger.error(f"stderr: {res.stderr}")
+        if res.stdout:
+            logger.error(f"stdout: {res.stdout}")
+    return res
+
+
+@contextlib.contextmanager
+def git_tracking(repo_root: pathlib.Path, commit_message: str = "Sync note"):
+    git_command(repo_root, "fetch")
+    git_command(repo_root, "reset", "--hard", "origin/master")
+    git_command(repo_root, "clean", "-fd")
+
+    yield
+
+    git_command(repo_root, "add", ".")
+    git_command(repo_root, "commit", "-m", commit_message)
+    git_command(repo_root, "push")
+
+
 @app.task(name=SYNC_NOTE)
 @safe_task_execution
 def sync_note(
@@ -62,7 +101,10 @@ def sync_note(
     note.tags = tags  # type: ignore

     note.update_confidences(confidences)
-    note.save_to_file()
+    with git_tracking(
+        settings.NOTES_STORAGE_DIR, f"Sync note {filename}: {subject}"
+    ):
+        note.save_to_file()
     return process_content_item(note, session)
@@ -88,3 +130,21 @@ def sync_notes(folder: str):
         "notes_num": len(all_files),
         "new_notes": new_notes,
     }
+
+
+@app.task(name=SETUP_GIT_NOTES)
+@safe_task_execution
+def setup_git_notes(origin: str, email: str, name: str):
+    logger.info(f"Setting up git notes in {origin}")
+    if (settings.NOTES_STORAGE_DIR / ".git").exists():
+        logger.info("Git notes already setup")
+        return {"status": "already_setup"}
+
+    git_command(settings.NOTES_STORAGE_DIR, "init", "-b", "main", force=True)
+    git_command(settings.NOTES_STORAGE_DIR, "config", "user.email", email)
+    git_command(settings.NOTES_STORAGE_DIR, "config", "user.name", name)
+    git_command(settings.NOTES_STORAGE_DIR, "remote", "add", "origin", origin)
+    git_command(settings.NOTES_STORAGE_DIR, "add", ".")
+    git_command(settings.NOTES_STORAGE_DIR, "commit", "-m", "Initial commit")
+    git_command(settings.NOTES_STORAGE_DIR, "push", "-u", "origin", "main")
+    return {"status": "success"}
```
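`git_tracking` brackets a file write: on entry it force-syncs the checkout to the remote, and on exit it commits and pushes whatever changed. A minimal sketch of the flow, assuming a checkout already initialized by `setup_git_notes` (the path is a hypothetical stand-in for `settings.NOTES_STORAGE_DIR`):

```python
import pathlib

notes_dir = pathlib.Path("/app/memory_files/notes")  # hypothetical notes checkout

with git_tracking(notes_dir, commit_message="Sync note demo.md: Demo"):
    # entry side: fetch + hard reset + clean have already run, so unpushed
    # local edits in the checkout are discarded before this write
    (notes_dir / "demo.md").write_text("# Demo\n")
# exit side: `git add .`, `git commit -m "Sync note demo.md: Demo"`, `git push`
```

The entry-side hard reset makes the remote the source of truth: local edits that never made it upstream are discarded before the write. Note also that the reset targets `origin/master` while `setup_git_notes` creates and pushes a `main` branch, so the remote's default branch name matters here.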
tools/add_user.py (new file, 22 lines)

```diff
@@ -0,0 +1,22 @@
+#! /usr/bin/env python
+
+import argparse
+from memory.common.db.connection import make_session
+from memory.common.db.models.users import User
+
+
+if __name__ == "__main__":
+    args = argparse.ArgumentParser()
+    args.add_argument("--email", type=str, required=True)
+    args.add_argument("--password", type=str, required=True)
+    args.add_argument("--name", type=str, required=True)
+    args = args.parse_args()
+
+    with make_session() as session:
+        user = User.create_with_password(
+            email=args.email, password=args.password, name=args.name
+        )
+        session.add(user)
+        session.commit()
+
+    print(f"User {args.email} created")
```
```diff
@@ -23,6 +23,7 @@ from typing import Any

 import click
 from celery import Celery
+from memory.common import settings
 from memory.common.celery_app import (
     SYNC_ALL_ARTICLE_FEEDS,
     SYNC_ARTICLE_FEED,
@@ -47,6 +48,7 @@ from memory.common.celery_app import (
     REINGEST_MISSING_CHUNKS,
     UPDATE_METADATA_FOR_ITEM,
     UPDATE_METADATA_FOR_SOURCE_ITEMS,
+    SETUP_GIT_NOTES,
     app,
 )
@@ -82,11 +84,15 @@ TASK_MAPPINGS = {
         "sync_smbc": SYNC_SMBC,
         "sync_xkcd": SYNC_XKCD,
         "sync_comic": SYNC_COMIC,
+        "full_sync_comics": "memory.workers.tasks.comic.full_sync_comic",
     },
     "forums": {
         "sync_lesswrong": SYNC_LESSWRONG,
         "sync_lesswrong_post": SYNC_LESSWRONG_POST,
     },
+    "notes": {
+        "setup_git_notes": SETUP_GIT_NOTES,
+    },
 }
 QUEUE_MAPPINGS = {
     "email": "email",
@@ -106,7 +112,9 @@ def run_task(app: Celery, category: str, task_name: str, **kwargs) -> str:
     task_path = TASK_MAPPINGS[category][task_name]
     queue_name = QUEUE_MAPPINGS.get(category) or category

-    result = app.send_task(task_path, kwargs=kwargs, queue=queue_name)
+    result = app.send_task(
+        task_path, kwargs=kwargs, queue=f"{settings.CELERY_QUEUE_PREFIX}-{queue_name}"
+    )
     return result.id
@@ -224,6 +232,23 @@ def ebook_sync_book(ctx, file_path, tags):
     execute_task(ctx, "ebook", "sync_book", file_path=file_path, tags=tags)


+@cli.group()
+@click.pass_context
+def notes(ctx):
+    """Notes-related tasks."""
+    pass
+
+
+@notes.command("setup-git-notes")
+@click.option("--origin", required=True, help="Git origin")
+@click.option("--email", required=True, help="Git email")
+@click.option("--name", required=True, help="Git name")
+@click.pass_context
+def notes_setup_git_notes(ctx, origin, email, name):
+    """Setup git notes."""
+    execute_task(ctx, "notes", "setup_git_notes", origin=origin, email=email, name=name)
+
+
 @cli.group()
 @click.pass_context
 def maintenance(ctx):
@@ -398,6 +423,13 @@ def comic_sync_comic(ctx, image_url, title, author, published_date):
     )


+@comic.command("full-sync-comics")
+@click.pass_context
+def comic_full_sync_comics(ctx):
+    """Full sync comics."""
+    execute_task(ctx, "comic", "full_sync_comics")
+
+
 @cli.group()
 @click.pass_context
 def forums(ctx):
@@ -418,7 +450,7 @@ def forums_sync_lesswrong(ctx, since_date, min_karma, limit, cooldown, max_items
         ctx,
         "forums",
         "sync_lesswrong",
-        since_date=since_date,
+        since=since_date,
         min_karma=min_karma,
         limit=limit,
         cooldown=cooldown,
```
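The `send_task` change namespaces queues per deployment. A minimal sketch of how the final queue name is built (the prefix value is hypothetical; in the real code it comes from `memory.common.settings`):

```python
# Hypothetical prefix; in the codebase this is settings.CELERY_QUEUE_PREFIX
CELERY_QUEUE_PREFIX = "memory"
QUEUE_MAPPINGS = {"email": "email"}

category = "notes"
queue_name = QUEUE_MAPPINGS.get(category) or category  # falls back to the category name
print(f"{CELERY_QUEUE_PREFIX}-{queue_name}")           # memory-notes
```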
tools/simple_proxy.py (new file, 126 lines)

```diff
@@ -0,0 +1,126 @@
+#!/usr/bin/env python3
+
+import argparse
+
+import httpx
+import uvicorn
+from pydantic import BaseModel
+from fastapi import FastAPI, Request, HTTPException
+from fastapi.responses import Response
+
+
+class State(BaseModel):
+    email: str
+    password: str
+    remote_server: str
+    session_header: str = "X-Session-ID"
+    session_id: str | None = None
+    port: int = 8080
+
+
+def parse_args() -> State:
+    """Parse command line arguments"""
+    parser = argparse.ArgumentParser(
+        description="Simple HTTP proxy with authentication"
+    )
+    parser.add_argument("--remote-server", required=True, help="Remote server URL")
+    parser.add_argument("--email", required=True, help="Email for authentication")
+    parser.add_argument("--password", required=True, help="Password for authentication")
+    parser.add_argument(
+        "--session-header", default="X-Session-ID", help="Session header name"
+    )
+    parser.add_argument("--port", type=int, default=8080, help="Port to run proxy on")
+    return State(**vars(parser.parse_args()))
+
+
+state = parse_args()
+
+
+async def login() -> None:
+    """Login to remote server and store session ID"""
+    login_url = f"{state.remote_server}/auth/login"
+    login_data = {"email": state.email, "password": state.password}
+
+    async with httpx.AsyncClient() as client:
+        try:
+            response = await client.post(login_url, json=login_data)
+            response.raise_for_status()
+
+            login_response = response.json()
+            state.session_id = login_response["session_id"]
+            print(f"Successfully logged in, session ID: {state.session_id}")
+        except httpx.HTTPStatusError as e:
+            print(
+                f"Login failed with status {e.response.status_code}: {e.response.text}"
+            )
+            raise
+        except Exception as e:
+            print(f"Login failed: {e}")
+            raise
+
+
+async def proxy_request(request: Request) -> Response:
+    """Proxy request to remote server with session header"""
+    if not state.session_id:
+        try:
+            await login()
+        except Exception as e:
+            print(f"Login failed: {e}")
+            raise HTTPException(status_code=401, detail="Unauthorized")
+
+    # Build the target URL
+    target_url = f"{state.remote_server}{request.url.path}"
+    if request.url.query:
+        target_url += f"?{request.url.query}"
+
+    # Get request body
+    body = await request.body()
+    headers = dict(request.headers)
+    headers.pop("host", None)
+
+    async with httpx.AsyncClient() as client:
+        try:
+            response = await client.request(
+                method=request.method,
+                url=target_url,
+                headers=headers | {state.session_header: state.session_id},  # type: ignore
+                content=body,
+                timeout=30.0,
+            )
+
+            # Forward response
+            resp = Response(
+                content=response.content,
+                status_code=response.status_code,
+                headers={
+                    k: v.replace(state.remote_server, f"http://localhost:{state.port}")
+                    for k, v in response.headers.items()
+                },
+                media_type=response.headers.get("content-type"),
+            )
+            print(resp.headers)
+            return resp
+
+        except httpx.RequestError as e:
+            print(f"Request failed: {e}")
+            raise HTTPException(status_code=502, detail=f"Proxy request failed: {e}")
+
+
+# Create FastAPI app
+app = FastAPI(title="Simple Proxy")
+
+
+@app.api_route(
+    "/{path:path}",
+    methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
+)
+async def proxy_all(request: Request):
+    """Proxy all requests to remote server"""
+    return await proxy_request(request)
+
+
+if __name__ == "__main__":
+    print(f"Starting proxy server on port {state.port}")
+    print(f"Proxying to: {state.remote_server}")
+
+    uvicorn.run(app, host="0.0.0.0", port=state.port)
```
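The proxy logs in lazily on the first request, then forwards everything with the session header attached, rewriting response headers so any URLs pointing at the remote server point back at the proxy instead. A minimal sketch of exercising it once it is running (the port is the script's default; the path is just an example, the real routes depend on the API):

```python
import httpx

# Any path works: the proxy forwards the request verbatim and injects
# the X-Session-ID header from its cached login.
resp = httpx.get("http://localhost:8080/admin")
print(resp.status_code)
```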