mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-08 05:14:43 +02:00
packageable + proxy
This commit is contained in:
parent 0551ddd30c
commit f2c24cca3b

210  README.md  Normal file
@ -0,0 +1,210 @@
# Memory - Personal Knowledge Base

A personal knowledge base system that ingests, indexes, and provides semantic search over various content types, including emails, documents, notes, web pages, and more. It features MCP (Model Context Protocol) integration so AI assistants can access and learn from your personal data.

## Features

- **Multi-modal Content Ingestion**: Process emails, documents, ebooks, comics, web pages, and more
- **Semantic Search**: Vector-based search across all your content with relevance scoring
- **MCP Integration**: Direct integration with AI assistants via the Model Context Protocol
- **Observation System**: AI assistants can record and search long-term observations about user preferences and patterns
- **Note Taking**: Create and organize markdown notes with full-text search
- **User Management**: Multi-user support with authentication
- **RESTful API**: Complete API for programmatic access
- **Real-time Processing**: Celery-based background processing for content ingestion

## Quick Start

### Prerequisites

- Docker and Docker Compose
- Python 3.11+ (for tools)

### 1. Start the Development Environment

```bash
# Clone the repository and navigate to it
cd memory

# Start the core services (PostgreSQL, RabbitMQ, Qdrant)
./dev.sh
```

This will:

- Start PostgreSQL (exposed on port 5432)
- Start RabbitMQ with its management interface
- Start the Qdrant vector database
- Initialize the database schema

It will also generate secrets in `secrets/` and create a basic `.env` file for you.

### 2. Start the Full Application

```bash
# Start all services including the API and workers
docker-compose up -d

# Check that services are healthy
docker-compose ps
```

The API will be available at `http://localhost:8000`.

There is also an admin interface at `http://localhost:8000/admin` where you can see what the database contains.

## User Management

### Create a User

```bash
# Install the package in development mode
pip install -e ".[all]"

# Create a new user
python tools/add_user.py --email user@example.com --password yourpassword --name "Your Name"
```

### Authentication

The API uses session-based authentication. Log in via:

```bash
curl -X POST http://localhost:8000/auth/login \
  -H "Content-Type: application/json" \
  -d '{"email": "user@example.com", "password": "yourpassword"}'
```

This returns a session ID that should be included in subsequent requests as the `X-Session-ID` header.
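As a minimal sketch of that flow (standard library only; the `session_id` response key is an assumption — check the actual login response shape):

```python
import json
import urllib.request

API = "http://localhost:8000"

def login(email: str, password: str) -> str:
    """Log in and return the session ID (assumes a JSON `session_id` field)."""
    req = urllib.request.Request(
        f"{API}/auth/login",
        data=json.dumps({"email": email, "password": password}).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)["session_id"]

def authed_request(path: str, session_id: str) -> urllib.request.Request:
    """Build a request that carries the session in the X-Session-ID header."""
    return urllib.request.Request(f"{API}{path}", headers={"X-Session-ID": session_id})
```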

## MCP Proxy Setup

Since MCP doesn't support basic authentication, use the included proxy for AI assistants that need to connect.

### Start the Proxy

```bash
python tools/simple_proxy.py \
  --remote-server http://localhost:8000 \
  --email user@example.com \
  --password yourpassword \
  --port 8080
```

### Configure Your AI Assistant

Point your MCP-compatible AI assistant to `http://localhost:8080` instead of the direct API endpoint. The proxy will:

- Handle authentication automatically
- Forward all requests to the main API
- Add the session header to each request

### Example MCP Configuration

For Claude Desktop or other MCP clients, add this to your configuration:

```json
{
  "mcpServers": {
    "memory": {
      "type": "streamable-http",
      "url": "http://localhost:8001/mcp"
    }
  }
}
```

## Available MCP Tools

When connected via MCP, AI assistants have access to:

- `search_knowledge_base()` - Search your stored content
- `search_observations()` - Search recorded observations about you
- `observe()` - Record new observations about your preferences and behavior
- `create_note()` - Create and save notes
- `note_files()` - List existing notes
- `fetch_file()` - Read file contents
- `get_all_tags()` - Get all content tags
- `get_all_subjects()` - Get observation subjects

## Content Ingestion

### Via Workers

Content is processed asynchronously by Celery workers. Supported formats include:

- PDFs, DOCX, TXT files
- Emails (mbox, EML formats)
- Web pages (HTML)
- Ebooks (EPUB, PDF)
- Images with OCR
- And more...

## Development

### Environment Setup

```bash
# Install development dependencies
pip install -e ".[dev]"

# Run tests
pytest

# Run with auto-reload
RELOAD=true python -m memory.api.app
```

### Architecture

- **FastAPI**: REST API and MCP server
- **PostgreSQL**: Primary database for metadata and users
- **Qdrant**: Vector database for semantic search
- **RabbitMQ**: Message queue for background processing
- **Celery**: Distributed task processing
- **SQLAdmin**: Admin interface for database management

### Configuration

Key environment variables:

- `FILE_STORAGE_DIR`: Where uploaded files are stored
- `DB_HOST`, `DB_PORT`: Database connection
- `QDRANT_HOST`: Vector database connection
- `RABBITMQ_HOST`: Message queue connection

See `docker-compose.yaml` for the full configuration options.
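For instance, a local script can pick these up with plain `os.environ` lookups (the fallback values here are only illustrative defaults for a local setup, not the application's actual defaults):

```python
import os

# Variable names from the list above; fallbacks are assumptions for local dev.
DB_HOST = os.environ.get("DB_HOST", "localhost")
DB_PORT = int(os.environ.get("DB_PORT", "5432"))
QDRANT_HOST = os.environ.get("QDRANT_HOST", "localhost")
RABBITMQ_HOST = os.environ.get("RABBITMQ_HOST", "localhost")
FILE_STORAGE_DIR = os.environ.get("FILE_STORAGE_DIR", "./memory_files")
```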

## Security Notes

- Never expose the main API directly to the internet without proper authentication
- Use the proxy for MCP connections to handle authentication securely
- Store secrets in the `secrets/` directory (see `docker-compose.yaml`)
- The application runs with minimal privileges in Docker containers

## Troubleshooting

### Common Issues

1. **Database connection errors**: Ensure PostgreSQL is running and accessible
2. **Vector search not working**: Check that Qdrant is healthy
3. **Background processing stalled**: Verify that RabbitMQ and the Celery workers are running
4. **MCP connection issues**: Use the proxy instead of direct API access

### Logs

```bash
# View API logs
docker-compose logs -f api

# View worker logs
docker-compose logs -f worker

# View all logs
docker-compose logs -f
```

## Contributing

This is a personal knowledge base system. Feel free to fork and adapt it for your own use cases.
49  db/migrations/versions/20250603_164859_add_user.py  Normal file
@ -0,0 +1,49 @@
"""Add user

Revision ID: 77cdbfc882e2
Revises: 152f8b4b52e8
Create Date: 2025-06-03 16:48:59.509683

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


revision: str = "77cdbfc882e2"
down_revision: Union[str, None] = "152f8b4b52e8"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        "users",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("email", sa.String(), nullable=False),
        sa.Column("password_hash", sa.String(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("email"),
    )
    op.create_table(
        "user_sessions",
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("user_id", sa.Integer(), nullable=False),
        sa.Column(
            "created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
        ),
        sa.Column("expires_at", sa.DateTime(), nullable=False),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["users.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )


def downgrade() -> None:
    op.drop_table("user_sessions")
    op.drop_table("users")
35  dev.sh
@ -16,15 +16,48 @@ cd "$SCRIPT_DIR"
 docker volume create memory_file_storage
 docker run --rm -v memory_file_storage:/data busybox chown -R 1000:1000 /data
+
+POSTGRES_PASSWORD=543218ZrHw8Pxbs3YXzaVHq8YKVHwCj6Pz8RQkl8
+echo $POSTGRES_PASSWORD > secrets/postgres_password.txt
+
 # Create a temporary docker-compose override file to expose PostgreSQL
 echo -e "${YELLOW}Creating docker-compose override to expose PostgreSQL...${NC}"
 if [ ! -f docker-compose.override.yml ]; then
     cat > docker-compose.override.yml << EOL
 version: "3.9"
 services:
+  qdrant:
+    ports:
+      - "6333:6333"
+
   postgres:
     ports:
-      - "5432:5432"
+      # PostgreSQL port for local Celery result backend
+      - "15432:5432"
+
+  rabbitmq:
+    ports:
+      # UI only on localhost
+      - "15672:15672"
+      # AMQP port for local Celery clients (for local workers)
+      - "15673:5672"
 EOL
 fi
+
+if [ ! -f .env ]; then
+    echo $POSTGRES_PASSWORD > .env
+    cat >> .env << EOL
+CELERY_BROKER_PASSWORD=543218ZrHw8Pxbs3YXzaVHq8YKVHwCj6Pz8RQkl8
+
+RABBITMQ_HOST=localhost
+QDRANT_HOST=localhost
+DB_HOST=localhost
+
+VOYAGE_API_KEY=
+ANTHROPIC_API_KEY=
+OPENAI_API_KEY=
+
+DB_PORT=15432
+RABBITMQ_PORT=15673
+EOL
+fi
@ -9,7 +9,6 @@ networks:
 # --------------------------------------------------------------------- secrets
 secrets:
   postgres_password: { file: ./secrets/postgres_password.txt }
-  jwt_secret: { file: ./secrets/jwt_secret.txt }
   openai_key: { file: ./secrets/openai_key.txt }
   anthropic_key: { file: ./secrets/anthropic_key.txt }

@ -23,6 +22,7 @@ volumes:
 x-common-env: &env
   RABBITMQ_USER: kb
   RABBITMQ_HOST: rabbitmq
+  CELERY_BROKER_PASSWORD: ${CELERY_BROKER_PASSWORD}
   QDRANT_HOST: qdrant
   DB_HOST: postgres
   DB_PORT: 5432

@ -81,6 +81,21 @@ services:
     cpus: "1.5"
     security_opt: [ "no-new-privileges=true" ]
+
+  migrate:
+    build:
+      context: .
+      dockerfile: docker/migrations.Dockerfile
+    networks: [kbnet]
+    depends_on:
+      postgres:
+        condition: service_healthy
+    environment:
+      <<: *env
+      POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
+    secrets: [postgres_password]
+    volumes:
+      - ./db:/app/db:ro

   rabbitmq:
     image: rabbitmq:3.13-management
     restart: unless-stopped

@ -88,7 +103,7 @@ services:
   environment:
     <<: *env
     RABBITMQ_DEFAULT_USER: "kb"
-    RABBITMQ_DEFAULT_PASS: "${RABBITMQ_PASSWORD}"
+    RABBITMQ_DEFAULT_PASS: "${CELERY_BROKER_PASSWORD}"
   volumes:
     - rabbitmq_data:/var/lib/rabbitmq:rw
   healthcheck:

@ -121,112 +136,54 @@ services:
     cap_drop: [ ALL ]

   # ------------------------------------------------------------ API / gateway
-  # api:
-  #   build:
-  #     context: .
-  #     dockerfile: docker/api/Dockerfile
-  #   restart: unless-stopped
-  #   networks: [kbnet]
-  #   depends_on: [postgres, rabbitmq, qdrant]
-  #   environment:
-  #     <<: *env
-  #     JWT_SECRET_FILE: /run/secrets/jwt_secret
-  #     OPENAI_API_KEY_FILE: /run/secrets/openai_key
-  #     POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
-  #     QDRANT_URL: http://qdrant:6333
-  #   secrets: [jwt_secret, openai_key, postgres_password]
-  #   healthcheck:
-  #     test: ["CMD-SHELL", "curl -fs http://localhost:8000/health || exit 1"]
-  #     interval: 15s
-  #     timeout: 5s
-  #     retries: 5
-  #   mem_limit: 768m
-  #   cpus: "1"
-  #   labels:
-  #     - "traefik.enable=true"
-  #     - "traefik.http.routers.kb.rule=Host(`${TRAEFIK_DOMAIN}`)"
-  #     - "traefik.http.routers.kb.entrypoints=websecure"
-  #     - "traefik.http.services.kb.loadbalancer.server.port=8000"
-
-  traefik:
-    image: traefik:v3.0
-    restart: unless-stopped
-    networks: [ kbnet ]
-    command:
-      - "--providers.docker=true"
-      - "--providers.docker.network=kbnet"
-      - "--entrypoints.web.address=:80"
-      - "--entrypoints.websecure.address=:443"
-      # - "--certificatesresolvers.le.acme.httpchallenge=true"
-      # - "--certificatesresolvers.le.acme.httpchallenge.entrypoint=web"
-      # - "--certificatesresolvers.le.acme.email=${LE_EMAIL}"
-      # - "--certificatesresolvers.le.acme.storage=/acme.json"
-      - "--log.level=INFO"
-    ports:
-      - "80:80"
-      - "443:443"
-    volumes:
-      - /var/run/docker.sock:/var/run/docker.sock:ro
-      # - ./acme.json:/acme.json:rw
+  api:
+    build:
+      context: .
+      dockerfile: docker/api/Dockerfile
+    restart: unless-stopped
+    networks: [kbnet]
+    depends_on: [postgres, rabbitmq, qdrant]
+    environment:
+      <<: *env
+      POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
+      QDRANT_URL: http://qdrant:6333
+    secrets: [postgres_password]
+    healthcheck:
+      test: ["CMD-SHELL", "curl -fs http://localhost:8000/health || exit 1"]
+      interval: 15s
+      timeout: 5s
+      retries: 5
+    ports:
+      - "8000:8000"
+    mem_limit: 768m
+    cpus: "1"
+
+  proxy:
+    build:
+      context: .
+      dockerfile: docker/api/Dockerfile
+    restart: unless-stopped
+    networks: [kbnet]
+    depends_on: [api]
+    environment:
+      <<: *env
+      PROXY_EMAIL: "${PROXY_EMAIL}"
+      PROXY_PASSWORD: "${PROXY_PASSWORD}"
+      PROXY_REMOTE_SERVER: "http://api:8000"
+    command: ["python", "/app/tools/simple_proxy.py", "--remote-server", "http://api:8000", "--email", "${PROXY_EMAIL}", "--password", "${PROXY_PASSWORD}", "--port", "8001"]
+    volumes:
+      - ./tools:/app/tools:ro
+    ports:
+      - "8001:8001"
+    mem_limit: 256m
+    cpus: "0.5"

   # ------------------------------------------------------------ Celery workers
-  worker-email:
-    <<: *worker-base
-    environment:
-      <<: *worker-env
-      QUEUES: "email"
-    # deploy: { resources: { limits: { cpus: "2", memory: 3g } } }
-
-  worker-text:
-    <<: *worker-base
-    environment:
-      <<: *worker-env
-      QUEUES: "medium_embed"
-
-  worker-ebook:
-    <<: *worker-base
-    environment:
-      <<: *worker-env
-      QUEUES: "ebooks"
-
-  worker-comic:
-    <<: *worker-base
-    environment:
-      <<: *worker-env
-      QUEUES: "comic"
-
-  worker-blogs:
-    <<: *worker-base
-    environment:
-      <<: *worker-env
-      QUEUES: "blogs"
-
-  worker-forums:
-    <<: *worker-base
-    environment:
-      <<: *worker-env
-      QUEUES: "forums"
-
-  worker-photo:
-    <<: *worker-base
-    environment:
-      <<: *worker-env
-      QUEUES: "photo_embed,comic"
-    # deploy: { resources: { limits: { cpus: "4", memory: 4g } } }
-
-  worker-notes:
-    <<: *worker-base
-    environment:
-      <<: *worker-env
-      QUEUES: "notes"
-    # deploy: { resources: { limits: { cpus: "4", memory: 4g } } }
-
-  worker-maintenance:
-    <<: *worker-base
-    environment:
-      <<: *worker-env
-      QUEUES: "maintenance"
-    # deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } }
+  worker:
+    <<: *worker-base
+    environment:
+      <<: *worker-env
+      QUEUES: "email,ebooks,comic,blogs,forums,maintenance,notes"

   ingest-hub:
     <<: *worker-base

@ -245,35 +202,9 @@ services:
     deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } }

   # ------------------------------------------------------------ watchtower (auto-update)
-  watchtower:
-    image: containrrr/watchtower
-    restart: unless-stopped
-    command: [ "--schedule", "0 0 4 * * *", "--cleanup" ]
-    volumes: [ "/var/run/docker.sock:/var/run/docker.sock:ro" ]
-    networks: [ kbnet ]
-
-  # ------------------------------------------------------------------- profiles: observability (opt-in)
-  # services:
-  #   prometheus:
-  #     image: prom/prometheus:v2.52
-  #     profiles: ["obs"]
-  #     networks: [kbnet]
-  #     volumes: [./observability/prometheus.yml:/etc/prometheus/prometheus.yml:ro]
-  #     restart: unless-stopped
-  #     ports: ["127.0.0.1:9090:9090"]
-
-  #   grafana:
-  #     image: grafana/grafana:10
-  #     profiles: ["obs"]
-  #     networks: [kbnet]
-  #     volumes: [./observability/grafana:/var/lib/grafana]
-  #     restart: unless-stopped
-  #     environment:
-  #       GF_SECURITY_ADMIN_USER: admin
-  #       GF_SECURITY_ADMIN_PASSWORD_FILE: /run/secrets/grafana_pw
-  #     secrets: [grafana_pw]
-  #     ports: ["127.0.0.1:3000:3000"]
-
-  # secrets: # extra secret for Grafana, not needed otherwise
-  #   grafana_pw:
-  #     file: ./secrets/grafana_pw.txt
+  # watchtower:
+  #   image: containrrr/watchtower
+  #   restart: unless-stopped
+  #   command: [ "--schedule", "0 0 4 * * *", "--cleanup" ]
+  #   volumes: [ "/var/run/docker.sock:/var/run/docker.sock:ro" ]
+  #   networks: [ kbnet ]
@ -1,24 +1,39 @@
-FROM python:3.10-slim
+FROM python:3.11-slim

 WORKDIR /app

-# Copy requirements files and setup
-COPY requirements-*.txt ./
-COPY setup.py ./
-COPY src/ ./src/
+# Install build dependencies
+RUN apt-get update && apt-get install -y \
+    gcc \
+    g++ \
+    python3-dev \
+    && rm -rf /var/lib/apt/lists/*

-# Install the package with API dependencies
+# Copy requirements files and setup
+COPY requirements ./requirements/
+RUN mkdir src
+COPY setup.py ./
+# Do an initial install to get the dependencies cached
+RUN pip install -e ".[api]"
+
+# Install the package with common dependencies
+COPY src/ ./src/
 RUN pip install -e ".[api]"

 # Run as non-root user
 RUN useradd -m appuser
-USER appuser
+RUN mkdir -p /app/memory_files
+ENV PYTHONPATH="/app"
+
+# Create user and set permissions
+RUN useradd -m kb
+RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
+    chown -R kb:kb /var/cache/fontconfig /home/kb/.cache/fontconfig /app
+
+USER kb

 # Set environment variables
 ENV PORT=8000
-ENV PYTHONPATH="/app"

 EXPOSE 8000

-# Run the API
 CMD ["uvicorn", "memory.api.app:app", "--host", "0.0.0.0", "--port", "8000"]
@ -5,17 +5,16 @@ WORKDIR /app
 # Install dependencies
 RUN apt-get update && apt-get install -y \
     libpq-dev gcc supervisor && \
-    pip install -e ".[workers]" && \
     apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*

-COPY requirements-*.txt ./
-RUN pip install --no-cache-dir -r requirements-common.txt
-RUN pip install --no-cache-dir -r requirements-parsers.txt
-RUN pip install --no-cache-dir -r requirements-workers.txt
-
 # Copy requirements files and setup
+COPY requirements ./requirements/
 COPY setup.py ./
+RUN mkdir src
+RUN pip install -e ".[common]"
+
 COPY src/ ./src/
+RUN pip install -e ".[common]"

 # Create and copy entrypoint script
 COPY docker/workers/entry.sh ./entry.sh
28  docker/migrations.Dockerfile  Normal file
@ -0,0 +1,28 @@
FROM python:3.11-slim

WORKDIR /app

# Copy requirements files and setup
COPY requirements ./requirements/
COPY setup.py ./
RUN mkdir src
RUN pip install -e ".[common]"

# Install the package with common dependencies
COPY src/ ./src/
RUN pip install -e ".[common]"

# Run as non-root user
RUN useradd -m appuser
RUN mkdir -p /app/memory_files
ENV PYTHONPATH="/app"

# Create user and set permissions
RUN useradd -m kb
RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
    chown -R kb:kb /var/cache/fontconfig /home/kb/.cache/fontconfig /app

USER kb

# Run the migrations
CMD ["alembic", "-c", "/app/db/migrations/alembic.ini", "upgrade", "head"]
@ -13,13 +13,12 @@ RUN apt-get update && apt-get install -y \
 # libreoffice-writer \
 && apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*

-COPY requirements-*.txt ./
-RUN pip install --no-cache-dir -r requirements-common.txt
-RUN pip install --no-cache-dir -r requirements-parsers.txt
-RUN pip install --no-cache-dir -r requirements-workers.txt
+COPY requirements ./requirements/
+COPY setup.py ./
+RUN mkdir src
+RUN pip install -e ".[common]"

 # Install Python dependencies
-COPY setup.py ./
 COPY src/ ./src/
 RUN pip install -e ".[workers]"
@ -1,7 +1,9 @@
 #!/usr/bin/env bash
 set -euo pipefail

+QUEUE_PREFIX=${QUEUE_PREFIX:-memory}
 QUEUES=${QUEUES:-default}
+QUEUES=$(IFS=,; echo "${QUEUES}" | tr ',' '\n' | sed "s/^/${QUEUE_PREFIX}-/" | paste -sd, -)
 CONCURRENCY=${CONCURRENCY:-2}
 LOGLEVEL=${LOGLEVEL:-INFO}
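The added prefixing line rewrites a comma-separated queue list; run standalone, it behaves like this (the values are illustrative):

```shell
QUEUE_PREFIX=memory
QUEUES="email,notes,maintenance"
# Prefix every queue name: email,notes -> memory-email,memory-notes
QUEUES=$(IFS=,; echo "${QUEUES}" | tr ',' '\n' | sed "s/^/${QUEUE_PREFIX}-/" | paste -sd, -)
echo "$QUEUES"   # memory-email,memory-notes,memory-maintenance
```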
@ -1,5 +0,0 @@
-PyMuPDF==1.25.5
-ebooklib==0.18.0
-beautifulsoup4==4.13.4
-markdownify==0.13.1
-pillow==10.4.0

@ -1,5 +0,0 @@
-openai==1.25.0
-pillow==10.4.0
-pypandoc==1.15.0
-beautifulsoup4==4.13.4
-feedparser==6.0.10
@ -2,6 +2,6 @@ fastapi==0.112.2
 uvicorn==0.29.0
 python-jose==3.3.0
 python-multipart==0.0.9
-sqladmin
+sqladmin==0.20.1
 mcp==1.9.2
 bm25s[full]==0.2.13
@ -1,11 +1,12 @@
 sqlalchemy==2.0.30
 psycopg2-binary==2.9.9
-pydantic==2.7.1
+pydantic==2.7.2
 alembic==1.13.1
 dotenv==0.9.9
 voyageai==0.3.2
 qdrant-client==1.9.0
 anthropic==0.18.1
+openai==1.25.0
 # Pin the httpx version, as newer versions break the anthropic client
 httpx==0.27.0
-celery==5.3.6
+celery[sqs]==5.3.6
7  requirements/requirements-parsers.txt  Normal file
@ -0,0 +1,7 @@
PyMuPDF==1.25.5
EbookLib==0.18
beautifulsoup4==4.13.4
markdownify==0.13.1
pillow==10.3.0
pypandoc==1.15
feedparser==6.0.10
14  setup.py
@ -4,7 +4,7 @@ from setuptools import setup, find_namespace_packages

 def read_requirements(filename: str) -> list[str]:
     """Read requirements from file, ignoring comments and -r directives."""
-    path = pathlib.Path(filename)
+    path = pathlib.Path(__file__).parent / "requirements" / filename
     return [
         line.strip()
         for line in path.read_text().splitlines()

@ -16,7 +16,6 @@ def read_requirements(filename: str) -> list[str]:
 common_requires = read_requirements("requirements-common.txt")
 parsers_requires = read_requirements("requirements-parsers.txt")
 api_requires = read_requirements("requirements-api.txt")
-workers_requires = read_requirements("requirements-workers.txt")
 dev_requires = read_requirements("requirements-dev.txt")

 setup(

@ -26,14 +25,9 @@ setup(
     packages=find_namespace_packages(where="src"),
     python_requires=">=3.10",
     extras_require={
-        "api": api_requires + common_requires,
-        "workers": workers_requires + common_requires + parsers_requires,
-        "common": common_requires,
+        "api": api_requires + common_requires + parsers_requires,
+        "common": common_requires + parsers_requires,
         "dev": dev_requires,
-        "all": api_requires
-        + workers_requires
-        + common_requires
-        + dev_requires
-        + parsers_requires,
+        "all": api_requires + common_requires + dev_requires + parsers_requires,
     },
 )
@ -182,6 +182,24 @@ async def observe(
     Use proactively when user expresses preferences, behaviors, beliefs, or contradictions.
     Be specific and detailed - observations should make sense months later.

+    Example call:
+    ```
+    {
+        "observations": [
+            {
+                "content": "The user is a software engineer.",
+                "subject": "user",
+                "observation_type": "belief",
+                "confidences": {"observation_accuracy": 0.9},
+                "evidence": {"quote": "I am a software engineer.", "context": "I work at Google."},
+                "tags": ["programming", "work"]
+            }
+        ],
+        "session_id": "123e4567-e89b-12d3-a456-426614174000",
+        "agent_model": "gpt-4o"
+    }
+    ```
+
     RawObservation fields:
     content (required): Detailed observation text explaining what you observed
     subject (required): Consistent identifier like "programming_style", "work_habits"
@@ -200,7 +218,7 @@ async def observe(
         observation,
     celery_app.send_task(
         SYNC_OBSERVATION,
-        queue="notes",
+        queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
         kwargs={
             "subject": observation.subject,
             "content": observation.content,
@@ -323,7 +341,7 @@ async def create_note(
     try:
         task = celery_app.send_task(
             SYNC_NOTE,
-            queue="notes",
+            queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
             kwargs={
                 "subject": subject,
                 "content": content,
@@ -20,6 +20,7 @@ from memory.common.db.models import (
     ForumPost,
     AgentObservation,
     Note,
+    User,
 )
 
 
@@ -207,6 +208,15 @@ class NoteAdmin(ModelView, model=Note):
     column_sortable_list = ["inserted_at"]
 
 
+class UserAdmin(ModelView, model=User):
+    column_list = [
+        "id",
+        "email",
+        "name",
+        "created_at",
+    ]
+
+
 def setup_admin(admin: Admin):
     """Add all admin views to the admin instance."""
     admin.add_view(SourceItemAdmin)
@@ -224,3 +234,4 @@ def setup_admin(admin: Admin):
     admin.add_view(ForumPostAdmin)
     admin.add_view(ComicAdmin)
     admin.add_view(PhotoAdmin)
+    admin.add_view(UserAdmin)
@@ -8,16 +8,30 @@ import pathlib
 import logging
 from typing import Annotated, Optional
 
-from fastapi import FastAPI, HTTPException, File, UploadFile, Query, Form
+from fastapi import (
+    FastAPI,
+    HTTPException,
+    File,
+    UploadFile,
+    Query,
+    Form,
+    Depends,
+)
 from fastapi.responses import FileResponse
 from sqladmin import Admin
 
 from memory.common import settings
 from memory.common import extract
 from memory.common.db.connection import get_engine
+from memory.common.db.models import User
 from memory.api.admin import setup_admin
 from memory.api.search import search, SearchResult
 from memory.api.MCP.tools import mcp
+from memory.api.auth import (
+    router as auth_router,
+    get_current_user,
+    AuthenticationMiddleware,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -35,6 +49,10 @@ app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
 engine = get_engine()
 admin = Admin(app, engine)
 setup_admin(admin)
 
+# Include auth router
+app.add_middleware(AuthenticationMiddleware)
+app.include_router(auth_router)
+
 app.mount("/", mcp.streamable_http_app())
 
 
@@ -63,6 +81,7 @@ async def search_endpoint(
     limit: int = Query(10, ge=1, le=100),
     min_text_score: float = Query(0.3, ge=0.0, le=1.0),
     min_multimodal_score: float = Query(0.3, ge=0.0, le=1.0),
+    current_user: User = Depends(get_current_user),
 ):
     """Search endpoint - delegates to search module"""
     upload_data = [
@@ -74,7 +93,7 @@ async def search_endpoint(
     return await search(
         upload_data,
         previews=previews,
-        modalities=modalities,
+        modalities=set(modalities),
         limit=limit,
         min_text_score=min_text_score,
         min_multimodal_score=min_multimodal_score,
@@ -82,7 +101,7 @@ async def search_endpoint(
 
 
 @app.get("/files/{path:path}")
-def get_file_by_path(path: str):
+def get_file_by_path(path: str, current_user: User = Depends(get_current_user)):
     """
     Fetch a file by its path
 
new file src/memory/api/auth.py (228 lines):

from datetime import datetime, timedelta, timezone
import textwrap
from typing import cast
import logging

from fastapi import HTTPException, Depends, Request, Response, APIRouter, Form
from fastapi.responses import HTMLResponse
from starlette.middleware.base import BaseHTTPMiddleware
from memory.common import settings
from sqlalchemy.orm import Session as DBSession, scoped_session
from pydantic import BaseModel

from memory.common.db.connection import get_session, make_session
from memory.common.db.models.users import User, UserSession

logger = logging.getLogger(__name__)


# Pydantic models
class LoginRequest(BaseModel):
    email: str
    password: str


class RegisterRequest(BaseModel):
    email: str
    password: str
    name: str


class LoginResponse(BaseModel):
    session_id: str
    user_id: int
    email: str
    name: str


# Create router
router = APIRouter(prefix="/auth", tags=["auth"])


def create_user_session(
    user_id: int, db: DBSession, valid_for: int = settings.SESSION_VALID_FOR
) -> str:
    """Create a new session for a user"""
    expires_at = datetime.now(timezone.utc) + timedelta(days=valid_for)

    session = UserSession(user_id=user_id, expires_at=expires_at)
    db.add(session)
    db.commit()

    return str(session.id)


def get_session_user(session_id: str, db: DBSession | scoped_session) -> User | None:
    """Get user from session ID if session is valid"""
    session = db.query(UserSession).get(session_id)
    if not session:
        return None
    now = datetime.now(timezone.utc)
    if session.expires_at.replace(tzinfo=timezone.utc) > now:
        return session.user
    return None


def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> User:
    """FastAPI dependency to get current authenticated user"""
    # Check for session ID in header or cookie
    session_id = request.headers.get(
        settings.SESSION_HEADER_NAME
    ) or request.cookies.get(settings.SESSION_COOKIE_NAME)

    if not session_id:
        raise HTTPException(status_code=401, detail="No session provided")

    user = get_session_user(session_id, db)
    if not user:
        raise HTTPException(status_code=401, detail="Invalid or expired session")

    return user


def create_user(email: str, password: str, name: str, db: DBSession) -> User:
    """Create a new user"""
    # Check if user already exists
    existing_user = db.query(User).filter(User.email == email).first()
    if existing_user:
        raise HTTPException(status_code=400, detail="User already exists")

    user = User.create_with_password(email, name, password)
    db.add(user)
    db.commit()
    db.refresh(user)

    return user


def authenticate_user(email: str, password: str, db: DBSession) -> User | None:
    """Authenticate a user by email and password"""
    user = db.query(User).filter(User.email == email).first()
    if user and user.is_valid_password(password):
        return user
    return None


# Auth endpoints
@router.post("/register", response_model=LoginResponse)
def register(request: RegisterRequest, db: DBSession = Depends(get_session)):
    """Register a new user"""
    if not settings.REGISTER_ENABLED:
        raise HTTPException(status_code=403, detail="Registration is disabled")

    user = create_user(request.email, request.password, request.name, db)
    session_id = create_user_session(user.id, db)  # type: ignore

    return LoginResponse(session_id=session_id, **user.serialize())


@router.get("/login", response_model=LoginResponse)
def login_page():
    """Login page"""
    return HTMLResponse(
        content=textwrap.dedent("""
            <html>
            <body>
                <h1>Login</h1>
                <form method="post" action="/auth/login-form">
                    <input type="email" name="email" placeholder="Email" />
                    <input type="password" name="password" placeholder="Password" />
                    <button type="submit">Login</button>
                </form>
            </body>
            </html>
        """),
    )


@router.post("/login", response_model=LoginResponse)
def login(
    request: LoginRequest, response: Response, db: DBSession = Depends(get_session)
):
    """Login and create a session"""
    return login_form(response, db, request.email, request.password)


@router.post("/login-form", response_model=LoginResponse)
def login_form(
    response: Response,
    db: DBSession = Depends(get_session),
    email: str = Form(),
    password: str = Form(),
):
    """Login with form data and create a session"""
    user = authenticate_user(email, password, db)
    if not user:
        raise HTTPException(status_code=401, detail="Invalid credentials")

    session_id = create_user_session(cast(int, user.id), db)

    # Set session cookie
    response.set_cookie(
        key=settings.SESSION_COOKIE_NAME,
        value=session_id,
        httponly=True,
        secure=settings.HTTPS,
        samesite="lax",
        max_age=settings.SESSION_COOKIE_MAX_AGE,
    )

    return LoginResponse(session_id=session_id, **user.serialize())


@router.post("/logout")
def logout(response: Response, user: User = Depends(get_current_user)):
    """Logout and clear session"""
    response.delete_cookie(settings.SESSION_COOKIE_NAME)
    return {"message": "Logged out successfully"}


@router.get("/me")
def get_me(user: User = Depends(get_current_user)):
    """Get current user info"""
    return user.serialize()


class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Middleware to require authentication for all endpoints except whitelisted ones."""

    # Endpoints that don't require authentication
    WHITELIST = {
        "/health",
        "/auth/login",
        "/auth/login-form",
        "/auth/register",
    }

    async def dispatch(self, request: Request, call_next):
        path = request.url.path

        # Skip authentication for whitelisted endpoints
        if any(path.startswith(whitelist_path) for whitelist_path in self.WHITELIST):
            return await call_next(request)

        # Check for session ID in header or cookie
        session_id = request.headers.get(
            settings.SESSION_HEADER_NAME
        ) or request.cookies.get(settings.SESSION_COOKIE_NAME)

        if not session_id:
            return Response(
                content="Authentication required",
                status_code=401,
                headers={"WWW-Authenticate": "Bearer"},
            )

        # Validate session and get user
        with make_session() as session:
            user = get_session_user(session_id, session)
            if not user:
                return Response(
                    content="Invalid or expired session",
                    status_code=401,
                    headers={"WWW-Authenticate": "Bearer"},
                )

            logger.debug(f"Authenticated request from user {user.email} to {path}")

            return await call_next(request)
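The session-validity check in `get_session_user` compares a naive `expires_at` timestamp from the database against the current UTC time by forcing the stored value to UTC. A minimal standalone sketch of that comparison (function name and sample lifetimes are illustrative):

```python
from datetime import datetime, timedelta, timezone


def session_valid(expires_at: datetime) -> bool:
    # mirrors get_session_user(): naive timestamps from the DB are treated as UTC
    now = datetime.now(timezone.utc)
    return expires_at.replace(tzinfo=timezone.utc) > now


# a naive timestamp 30 days out is valid; one a day in the past is not
fresh = datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(days=30)
stale = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(days=1)
print(session_valid(fresh), session_valid(stale))  # True False
```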
@@ -1,4 +1,5 @@
 from celery import Celery
+from kombu.utils.url import safequote
 from memory.common import settings
 
 EMAIL_ROOT = "memory.workers.tasks.email"
@@ -41,13 +42,17 @@ SYNC_ALL_ARTICLE_FEEDS = f"{BLOGS_ROOT}.sync_all_article_feeds"
 SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive"
 
 
-def rabbit_url() -> str:
-    return f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASSWORD}@{settings.RABBITMQ_HOST}:{settings.RABBITMQ_PORT}//"
+def get_broker_url() -> str:
+    protocol = settings.CELERY_BROKER_TYPE
+    user = safequote(settings.CELERY_BROKER_USER)
+    password = safequote(settings.CELERY_BROKER_PASSWORD)
+    host = settings.CELERY_BROKER_HOST
+    return f"{protocol}://{user}:{password}@{host}"
 
 
 app = Celery(
     "memory",
-    broker=rabbit_url(),
+    broker=get_broker_url(),
     backend=settings.CELERY_RESULT_BACKEND,
 )
 
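With the defaults from settings (assumed here: `amqp`, user and password `kb`, host `rabbitmq:5672//`), the new `get_broker_url` yields the same URL the old `rabbit_url` hard-coded. A sketch, using `urllib.parse.quote` as a stand-in for kombu's `safequote`:

```python
from urllib.parse import quote  # stand-in for kombu.utils.url.safequote


def broker_url(protocol: str = "amqp", user: str = "kb",
               password: str = "kb", host: str = "rabbitmq:5672//") -> str:
    # same f-string as get_broker_url(), with credentials URL-quoted
    return f"{protocol}://{quote(user)}:{quote(password)}@{host}"


print(broker_url())  # amqp://kb:kb@rabbitmq:5672//
```

Quoting the credentials matters once passwords contain characters like `@` or `/`, which would otherwise corrupt the URL.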
@@ -59,20 +64,22 @@ app.conf.update(
     task_reject_on_worker_lost=True,
     worker_prefetch_multiplier=1,
     task_routes={
-        f"{EMAIL_ROOT}.*": {"queue": "email"},
-        f"{PHOTO_ROOT}.*": {"queue": "photo_embed"},
-        f"{COMIC_ROOT}.*": {"queue": "comic"},
-        f"{EBOOK_ROOT}.*": {"queue": "ebooks"},
-        f"{BLOGS_ROOT}.*": {"queue": "blogs"},
-        f"{FORUMS_ROOT}.*": {"queue": "forums"},
-        f"{MAINTENANCE_ROOT}.*": {"queue": "maintenance"},
-        f"{NOTES_ROOT}.*": {"queue": "notes"},
-        f"{OBSERVATIONS_ROOT}.*": {"queue": "notes"},
+        f"{EMAIL_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-email"},
+        f"{PHOTO_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-photo-embed"},
+        f"{COMIC_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-comic"},
+        f"{EBOOK_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-ebooks"},
+        f"{BLOGS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-blogs"},
+        f"{FORUMS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-forums"},
+        f"{MAINTENANCE_ROOT}.*": {
+            "queue": f"{settings.CELERY_QUEUE_PREFIX}-maintenance"
+        },
+        f"{NOTES_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
+        f"{OBSERVATIONS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
     },
 )
 
 
-@app.on_after_configure.connect  # type: ignore
+@app.on_after_configure.connect  # type: ignore[attr-defined]
 def ensure_qdrant_initialised(sender, **_):
     from memory.common import qdrant
 
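Prefixing every queue name lets several deployments share one broker without consuming each other's tasks. A sketch of how one route resolves, assuming the default prefix `memory` from settings:

```python
CELERY_QUEUE_PREFIX = "memory"  # assumed default from settings
EMAIL_ROOT = "memory.workers.tasks.email"

# same shape as the task_routes entry above
task_routes = {
    f"{EMAIL_ROOT}.*": {"queue": f"{CELERY_QUEUE_PREFIX}-email"},
}
print(task_routes["memory.workers.tasks.email.*"]["queue"])  # memory-email
```

Workers must of course be started with the matching prefixed queue names (`-Q memory-email,...`) for the routed tasks to be picked up.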
@@ -3,8 +3,9 @@ Database connection utilities.
 """
 
 from contextlib import contextmanager
+from typing import Generator
 from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker, scoped_session
+from sqlalchemy.orm import sessionmaker, scoped_session, Session
 
 from memory.common import settings
 
@@ -23,9 +24,17 @@ def get_session_factory():
 
 def get_scoped_session():
     """Create a thread-local scoped session factory"""
-    engine = get_engine()
-    session_factory = sessionmaker(bind=engine)
-    return scoped_session(session_factory)
+    return scoped_session(get_session_factory())
+
+
+def get_session() -> Generator[Session, None, None]:
+    """FastAPI dependency for database sessions"""
+    session_factory = get_session_factory()
+    session = session_factory()
+    try:
+        yield session
+    finally:
+        session.close()
 
 
 @contextmanager
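The new `get_session` is a generator dependency: FastAPI advances it once to obtain the session for the request, then finalizes the generator after the response, which triggers the `finally` block. The cleanup guarantee can be seen without SQLAlchemy (`FakeSession` is a stand-in, not part of the codebase):

```python
class FakeSession:
    """Stand-in for a SQLAlchemy session."""

    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        self.closed = True


def get_session():
    session = FakeSession()
    try:
        yield session
    finally:
        session.close()  # runs even if the request handler raised


gen = get_session()
session = next(gen)   # what FastAPI hands to the endpoint
gen.close()           # what FastAPI does after the response
print(session.closed)  # True
```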
@@ -32,6 +32,10 @@ from memory.common.db.models.sources import (
     ArticleFeed,
     EmailAccount,
 )
+from memory.common.db.models.users import (
+    User,
+    UserSession,
+)
 
 __all__ = [
     "Base",
@@ -62,4 +66,7 @@ __all__ = [
     "Book",
     "ArticleFeed",
     "EmailAccount",
+    # Users
+    "User",
+    "UserSession",
 ]
new file src/memory/common/db/models/users.py (65 lines):

import hashlib
import secrets
from typing import cast
import uuid
from memory.common.db.models.base import Base
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship


def hash_password(password: str) -> str:
    """Hash a password using SHA-256 with salt"""
    salt = secrets.token_hex(16)
    return f"{salt}:{hashlib.sha256((salt + password).encode()).hexdigest()}"


def verify_password(password: str, password_hash: str) -> bool:
    """Verify a password against its hash"""
    try:
        salt, hash_value = password_hash.split(":", 1)
        return hashlib.sha256((salt + password).encode()).hexdigest() == hash_value
    except ValueError:
        return False


class User(Base):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    email = Column(String, nullable=False, unique=True)
    password_hash = Column(String, nullable=False)

    # Relationship to sessions
    sessions = relationship(
        "UserSession", back_populates="user", cascade="all, delete-orphan"
    )

    def serialize(self) -> dict:
        return {
            "user_id": self.id,
            "name": self.name,
            "email": self.email,
        }

    def is_valid_password(self, password: str) -> bool:
        """Check if the provided password is valid for this user"""
        return verify_password(password, cast(str, self.password_hash))

    @classmethod
    def create_with_password(cls, email: str, name: str, password: str) -> "User":
        """Create a new user with a hashed password"""
        return cls(email=email, name=name, password_hash=hash_password(password))


class UserSession(Base):
    __tablename__ = "user_sessions"

    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    created_at = Column(DateTime, server_default=func.now())
    expires_at = Column(DateTime, nullable=False)

    # Relationship to user
    user = relationship("User", back_populates="sessions")
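`hash_password` stores the random salt alongside the digest in one `salt:hex` string, so `verify_password` needs no other state to recompute and compare. The two helpers are pure functions and can be exercised standalone (copied as-is from the model file):

```python
import hashlib
import secrets


def hash_password(password: str) -> str:
    # fresh random salt per user, stored in front of the digest
    salt = secrets.token_hex(16)
    return f"{salt}:{hashlib.sha256((salt + password).encode()).hexdigest()}"


def verify_password(password: str, password_hash: str) -> bool:
    try:
        salt, hash_value = password_hash.split(":", 1)
        return hashlib.sha256((salt + password).encode()).hexdigest() == hash_value
    except ValueError:
        # a stored value without a ":" separator can never verify
        return False


stored = hash_password("hunter2")
print(verify_password("hunter2", stored))           # True
print(verify_password("wrong", stored))             # False
print(verify_password("hunter2", "malformed-hash"))  # False
```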
@@ -29,11 +29,18 @@ def make_db_url(
 
 DB_URL = os.getenv("DATABASE_URL", make_db_url())
 
-# Celery settings
-RABBITMQ_USER = os.getenv("RABBITMQ_USER", "kb")
-RABBITMQ_PASSWORD = os.getenv("RABBITMQ_PASSWORD", "kb")
-RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
-RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672")
+# Broker settings
+CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory")
+CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "amqp").lower()  # amqp or sqs
+CELERY_BROKER_USER = os.getenv("CELERY_BROKER_USER", "kb")
+CELERY_BROKER_PASSWORD = os.getenv("CELERY_BROKER_PASSWORD", "kb")
+
+CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "")
+if not CELERY_BROKER_HOST and CELERY_BROKER_TYPE == "amqp":
+    RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
+    RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672")
+    CELERY_BROKER_HOST = f"{RABBITMQ_HOST}:{RABBITMQ_PORT}//"
 
 CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
 
@@ -122,3 +129,12 @@ if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):
 else:
     ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
 SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
+
+# API settings
+HTTPS = boolean_env("HTTPS", False)
+SESSION_HEADER_NAME = os.getenv("SESSION_HEADER_NAME", "X-Session-ID")
+SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")
+SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 * 60))
+SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))
+
+REGISTER_ENABLED = boolean_env("REGISTER_ENABLED", False) or True
new file tools/add_user.py (20 lines):

#! /usr/bin/env python

import argparse
from memory.common.db.connection import make_session
from memory.common.db.models.users import User


if __name__ == "__main__":
    args = argparse.ArgumentParser()
    args.add_argument("--email", type=str, required=True)
    args.add_argument("--password", type=str, required=True)
    args.add_argument("--name", type=str, required=True)
    args = args.parse_args()

    with make_session() as session:
        # User has no plain `password` column; go through the hashing constructor
        user = User.create_with_password(args.email, args.name, args.password)
        session.add(user)
        session.commit()

    print(f"User {args.email} created")
new file tools/simple_proxy.py (120 lines):

#!/usr/bin/env python3

import argparse

import httpx
import uvicorn
from pydantic import BaseModel
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import Response


class State(BaseModel):
    email: str
    password: str
    remote_server: str
    session_header: str = "X-Session-ID"
    session_id: str | None = None
    port: int = 8080


def parse_args() -> State:
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description="Simple HTTP proxy with authentication"
    )
    parser.add_argument("--remote-server", required=True, help="Remote server URL")
    parser.add_argument("--email", required=True, help="Email for authentication")
    parser.add_argument("--password", required=True, help="Password for authentication")
    parser.add_argument(
        "--session-header", default="X-Session-ID", help="Session header name"
    )
    parser.add_argument("--port", type=int, default=8080, help="Port to run proxy on")
    return State(**vars(parser.parse_args()))


state = parse_args()


async def login() -> None:
    """Login to remote server and store session ID"""
    login_url = f"{state.remote_server}/auth/login"
    login_data = {"email": state.email, "password": state.password}

    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(login_url, json=login_data)
            response.raise_for_status()

            login_response = response.json()
            state.session_id = login_response["session_id"]
            print(f"Successfully logged in, session ID: {state.session_id}")
        except httpx.HTTPStatusError as e:
            print(
                f"Login failed with status {e.response.status_code}: {e.response.text}"
            )
            raise
        except Exception as e:
            print(f"Login failed: {e}")
            raise


async def proxy_request(request: Request) -> Response:
    """Proxy request to remote server with session header"""
    if not state.session_id:
        try:
            await login()
        except Exception as e:
            print(f"Login failed: {e}")
            raise HTTPException(status_code=401, detail="Unauthorized")

    # Build the target URL
    target_url = f"{state.remote_server}{request.url.path}"
    if request.url.query:
        target_url += f"?{request.url.query}"

    # Get request body
    body = await request.body()
    headers = dict(request.headers)

    async with httpx.AsyncClient() as client:
        try:
            response = await client.request(
                method=request.method,
                url=target_url,
                headers=headers | {state.session_header: state.session_id},  # type: ignore
                content=body,
                timeout=30.0,
            )

            # Forward response
            return Response(
                content=response.content,
                status_code=response.status_code,
                headers=dict(response.headers),
                media_type=response.headers.get("content-type"),
            )

        except httpx.RequestError as e:
            print(f"Request failed: {e}")
            raise HTTPException(status_code=502, detail=f"Proxy request failed: {e}")


# Create FastAPI app
app = FastAPI(title="Simple Proxy")


@app.api_route(
    "/{path:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
)
async def proxy_all(request: Request):
    """Proxy all requests to remote server"""
    return await proxy_request(request)


if __name__ == "__main__":
    print(f"Starting proxy server on port {state.port}")
    print(f"Proxying to: {state.remote_server}")

    uvicorn.run(app, host="0.0.0.0", port=state.port)