packageable + proxy

This commit is contained in:
Daniel O'Connell 2025-06-03 18:48:45 +02:00
parent 0551ddd30c
commit f2c24cca3b
28 changed files with 983 additions and 205 deletions

210
README.md Normal file
View File

@ -0,0 +1,210 @@
# Memory - Personal Knowledge Base
A personal knowledge base system that ingests, indexes, and provides semantic search over various content types including emails, documents, notes, web pages, and more. Features MCP (Model Context Protocol) integration for AI assistants to access and learn from your personal data.
## Features
- **Multi-modal Content Ingestion**: Process emails, documents, ebooks, comics, web pages, and more
- **Semantic Search**: Vector-based search across all your content with relevance scoring
- **MCP Integration**: Direct integration with AI assistants via Model Context Protocol
- **Observation System**: AI assistants can record and search long-term observations about user preferences and patterns
- **Note Taking**: Create and organize markdown notes with full-text search
- **User Management**: Multi-user support with authentication
- **RESTful API**: Complete API for programmatic access
- **Real-time Processing**: Celery-based background processing for content ingestion
## Quick Start
### Prerequisites
- Docker and Docker Compose
- Python 3.11+ (for tools)
### 1. Start the Development Environment
```bash
# Clone the repository and navigate to it
cd memory
# Start the core services (PostgreSQL, RabbitMQ, Qdrant)
./dev.sh
```
This will:
- Start PostgreSQL (exposed on port 5432)
- Start RabbitMQ with management interface
- Start Qdrant vector database
- Initialize the database schema
It will also generate secrets in `secrets` and make a basic `.env` file for you.
### 2. Start the Full Application
```bash
# Start all services including API and workers
docker-compose up -d
# Check that services are healthy
docker-compose ps
```
The API will be available at `http://localhost:8000`
The is also an admin interface at `http://localhost:8000/admin` where you can see what the database
contains.
## User Management
### Create a User
```bash
# Install the package in development mode
pip install -e ".[all]"
# Create a new user
python tools/add_user.py --email user@example.com --password yourpassword --name "Your Name"
```
### Authentication
The API uses session-based authentication. Login via:
```bash
curl -X POST http://localhost:8000/auth/login \
-H "Content-Type: application/json" \
-d '{"email": "user@example.com", "password": "yourpassword"}'
```
This returns a session ID that should be included in subsequent requests as the `X-Session-ID` header.
## MCP Proxy Setup
Since MCP doesn't support basic authentication, use the included proxy for AI assistants that need to connect:
### Start the Proxy
```bash
python tools/simple_proxy.py \
--remote-server http://localhost:8000 \
--email user@example.com \
--password yourpassword \
--port 8080
```
### Configure Your AI Assistant
Point your MCP-compatible AI assistant to `http://localhost:8080` instead of the direct API endpoint. The proxy will:
- Handle authentication automatically
- Forward all requests to the main API
- Add the session header to each request
### Example MCP Configuration
For Claude Desktop or other MCP clients, add to your configuration:
```json
{
"mcpServers": {
"memory": {
"type": "streamable-http",
"url": "http://localhost:8001/mcp",
}
}
}
```
## Available MCP Tools
When connected via MCP, AI assistants have access to:
- `search_knowledge_base()` - Search your stored content
- `search_observations()` - Search recorded observations about you
- `observe()` - Record new observations about your preferences/behavior
- `create_note()` - Create and save notes
- `note_files()` - List existing notes
- `fetch_file()` - Read file contents
- `get_all_tags()` - Get all content tags
- `get_all_subjects()` - Get observation subjects
## Content Ingestion
### Via Workers
Content is processed asynchronously by Celery workers. Supported formats include:
- PDFs, DOCX, TXT files
- Emails (mbox, EML formats)
- Web pages (HTML)
- Ebooks (EPUB, PDF)
- Images with OCR
- And more...
## Development
### Environment Setup
```bash
# Install development dependencies
pip install -e ".[dev]"
# Run tests
pytest
# Run with auto-reload
RELOAD=true python -m memory.api.app
```
### Architecture
- **FastAPI**: REST API and MCP server
- **PostgreSQL**: Primary database for metadata and users
- **Qdrant**: Vector database for semantic search
- **RabbitMQ**: Message queue for background processing
- **Celery**: Distributed task processing
- **SQLAdmin**: Admin interface for database management
### Configuration
Key environment variables:
- `FILE_STORAGE_DIR`: Where uploaded files are stored
- `DB_HOST`, `DB_PORT`: Database connection
- `QDRANT_HOST`: Vector database connection
- `RABBITMQ_HOST`: Message queue connection
See `docker-compose.yaml` for full configuration options.
## Security Notes
- Never expose the main API directly to the internet without proper authentication
- Use the proxy for MCP connections to handle authentication securely
- Store secrets in the `secrets/` directory (see `docker-compose.yaml`)
- The application runs with minimal privileges in Docker containers
## Troubleshooting
### Common Issues
1. **Database connection errors**: Ensure PostgreSQL is running and accessible
2. **Vector search not working**: Check that Qdrant is healthy
3. **Background processing stalled**: Verify RabbitMQ and Celery workers are running
4. **MCP connection issues**: Use the proxy instead of direct API access
### Logs
```bash
# View API logs
docker-compose logs -f api
# View worker logs
docker-compose logs -f worker
# View all logs
docker-compose logs -f
```
## Contributing
This is a personal knowledge base system. Feel free to fork and adapt for your own use cases.

View File

@ -0,0 +1,49 @@
"""Add user
Revision ID: 77cdbfc882e2
Revises: 152f8b4b52e8
Create Date: 2025-06-03 16:48:59.509683
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "77cdbfc882e2"
down_revision: Union[str, None] = "152f8b4b52e8"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"users",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("email", sa.String(), nullable=False),
sa.Column("password_hash", sa.String(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("email"),
)
op.create_table(
"user_sessions",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column(
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
),
sa.Column("expires_at", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
op.drop_table("user_sessions")
op.drop_table("users")

35
dev.sh
View File

@ -16,15 +16,48 @@ cd "$SCRIPT_DIR"
docker volume create memory_file_storage docker volume create memory_file_storage
docker run --rm -v memory_file_storage:/data busybox chown -R 1000:1000 /data docker run --rm -v memory_file_storage:/data busybox chown -R 1000:1000 /data
POSTGRES_PASSWORD=543218ZrHw8Pxbs3YXzaVHq8YKVHwCj6Pz8RQkl8
echo $POSTGRES_PASSWORD > secrets/postgres_password.txt
# Create a temporary docker-compose override file to expose PostgreSQL # Create a temporary docker-compose override file to expose PostgreSQL
echo -e "${YELLOW}Creating docker-compose override to expose PostgreSQL...${NC}" echo -e "${YELLOW}Creating docker-compose override to expose PostgreSQL...${NC}"
if [ ! -f docker-compose.override.yml ]; then if [ ! -f docker-compose.override.yml ]; then
cat > docker-compose.override.yml << EOL cat > docker-compose.override.yml << EOL
version: "3.9" version: "3.9"
services: services:
qdrant:
ports:
- "6333:6333"
postgres: postgres:
ports: ports:
- "5432:5432" # PostgreSQL port for local Celery result backend
- "15432:5432"
rabbitmq:
ports:
# UI only on localhost
- "15672:15672"
# AMQP port for local Celery clients (for local workers)
- "15673:5672"
EOL
fi
if [ ! -f .env ]; then
echo $POSTGRES_PASSWORD > .env
cat >> .env << EOL
CELERY_BROKER_PASSWORD=543218ZrHw8Pxbs3YXzaVHq8YKVHwCj6Pz8RQkl8
RABBITMQ_HOST=localhost
QDRANT_HOST=localhost
DB_HOST=localhost
VOYAGE_API_KEY=
ANTHROPIC_API_KEY=
OPENAI_API_KEY=
DB_PORT=15432
RABBITMQ_PORT=15673
EOL EOL
fi fi

View File

@ -9,7 +9,6 @@ networks:
# --------------------------------------------------------------------- secrets # --------------------------------------------------------------------- secrets
secrets: secrets:
postgres_password: { file: ./secrets/postgres_password.txt } postgres_password: { file: ./secrets/postgres_password.txt }
jwt_secret: { file: ./secrets/jwt_secret.txt }
openai_key: { file: ./secrets/openai_key.txt } openai_key: { file: ./secrets/openai_key.txt }
anthropic_key: { file: ./secrets/anthropic_key.txt } anthropic_key: { file: ./secrets/anthropic_key.txt }
@ -23,6 +22,7 @@ volumes:
x-common-env: &env x-common-env: &env
RABBITMQ_USER: kb RABBITMQ_USER: kb
RABBITMQ_HOST: rabbitmq RABBITMQ_HOST: rabbitmq
CELERY_BROKER_PASSWORD: ${CELERY_BROKER_PASSWORD}
QDRANT_HOST: qdrant QDRANT_HOST: qdrant
DB_HOST: postgres DB_HOST: postgres
DB_PORT: 5432 DB_PORT: 5432
@ -81,6 +81,21 @@ services:
cpus: "1.5" cpus: "1.5"
security_opt: [ "no-new-privileges=true" ] security_opt: [ "no-new-privileges=true" ]
migrate:
build:
context: .
dockerfile: docker/migrations.Dockerfile
networks: [kbnet]
depends_on:
postgres:
condition: service_healthy
environment:
<<: *env
POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
secrets: [postgres_password]
volumes:
- ./db:/app/db:ro
rabbitmq: rabbitmq:
image: rabbitmq:3.13-management image: rabbitmq:3.13-management
restart: unless-stopped restart: unless-stopped
@ -88,7 +103,7 @@ services:
environment: environment:
<<: *env <<: *env
RABBITMQ_DEFAULT_USER: "kb" RABBITMQ_DEFAULT_USER: "kb"
RABBITMQ_DEFAULT_PASS: "${RABBITMQ_PASSWORD}" RABBITMQ_DEFAULT_PASS: "${CELERY_BROKER_PASSWORD}"
volumes: volumes:
- rabbitmq_data:/var/lib/rabbitmq:rw - rabbitmq_data:/var/lib/rabbitmq:rw
healthcheck: healthcheck:
@ -121,112 +136,54 @@ services:
cap_drop: [ ALL ] cap_drop: [ ALL ]
# ------------------------------------------------------------ API / gateway # ------------------------------------------------------------ API / gateway
# api: api:
# build: build:
# context: . context: .
# dockerfile: docker/api/Dockerfile dockerfile: docker/api/Dockerfile
# restart: unless-stopped
# networks: [kbnet]
# depends_on: [postgres, rabbitmq, qdrant]
# environment:
# <<: *env
# JWT_SECRET_FILE: /run/secrets/jwt_secret
# OPENAI_API_KEY_FILE: /run/secrets/openai_key
# POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
# QDRANT_URL: http://qdrant:6333
# secrets: [jwt_secret, openai_key, postgres_password]
# healthcheck:
# test: ["CMD-SHELL", "curl -fs http://localhost:8000/health || exit 1"]
# interval: 15s
# timeout: 5s
# retries: 5
# mem_limit: 768m
# cpus: "1"
# labels:
# - "traefik.enable=true"
# - "traefik.http.routers.kb.rule=Host(`${TRAEFIK_DOMAIN}`)"
# - "traefik.http.routers.kb.entrypoints=websecure"
# - "traefik.http.services.kb.loadbalancer.server.port=8000"
traefik:
image: traefik:v3.0
restart: unless-stopped restart: unless-stopped
networks: [ kbnet ] networks: [kbnet]
command: depends_on: [postgres, rabbitmq, qdrant]
- "--providers.docker=true" environment:
- "--providers.docker.network=kbnet" <<: *env
- "--entrypoints.web.address=:80" POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
- "--entrypoints.websecure.address=:443" QDRANT_URL: http://qdrant:6333
# - "--certificatesresolvers.le.acme.httpchallenge=true" secrets: [postgres_password]
# - "--certificatesresolvers.le.acme.httpchallenge.entrypoint=web" healthcheck:
# - "--certificatesresolvers.le.acme.email=${LE_EMAIL}" test: ["CMD-SHELL", "curl -fs http://localhost:8000/health || exit 1"]
# - "--certificatesresolvers.le.acme.storage=/acme.json" interval: 15s
- "--log.level=INFO" timeout: 5s
retries: 5
ports: ports:
- "80:80" - "8000:8000"
- "443:443" mem_limit: 768m
cpus: "1"
proxy:
build:
context: .
dockerfile: docker/api/Dockerfile
restart: unless-stopped
networks: [kbnet]
depends_on: [api]
environment:
<<: *env
PROXY_EMAIL: "${PROXY_EMAIL}"
PROXY_PASSWORD: "${PROXY_PASSWORD}"
PROXY_REMOTE_SERVER: "http://api:8000"
command: ["python", "/app/tools/simple_proxy.py", "--remote-server", "http://api:8000", "--email", "${PROXY_EMAIL}", "--password", "${PROXY_PASSWORD}", "--port", "8001"]
volumes: volumes:
- /var/run/docker.sock:/var/run/docker.sock:ro - ./tools:/app/tools:ro
# - ./acme.json:/acme.json:rw ports:
- "8001:8001"
mem_limit: 256m
cpus: "0.5"
# ------------------------------------------------------------ Celery workers # ------------------------------------------------------------ Celery workers
worker-email: worker:
<<: *worker-base <<: *worker-base
environment: environment:
<<: *worker-env <<: *worker-env
QUEUES: "email" QUEUES: "email,ebooks,comic,blogs,forums,maintenance,notes"
# deploy: { resources: { limits: { cpus: "2", memory: 3g } } }
worker-text:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "medium_embed"
worker-ebook:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "ebooks"
worker-comic:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "comic"
worker-blogs:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "blogs"
worker-forums:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "forums"
worker-photo:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "photo_embed,comic"
# deploy: { resources: { limits: { cpus: "4", memory: 4g } } }
worker-notes:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "notes"
# deploy: { resources: { limits: { cpus: "4", memory: 4g } } }
worker-maintenance:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "maintenance"
# deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } }
ingest-hub: ingest-hub:
<<: *worker-base <<: *worker-base
@ -245,35 +202,9 @@ services:
deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } } deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } }
# ------------------------------------------------------------ watchtower (auto-update) # ------------------------------------------------------------ watchtower (auto-update)
watchtower: # watchtower:
image: containrrr/watchtower # image: containrrr/watchtower
restart: unless-stopped # restart: unless-stopped
command: [ "--schedule", "0 0 4 * * *", "--cleanup" ] # command: [ "--schedule", "0 0 4 * * *", "--cleanup" ]
volumes: [ "/var/run/docker.sock:/var/run/docker.sock:ro" ] # volumes: [ "/var/run/docker.sock:/var/run/docker.sock:ro" ]
networks: [ kbnet ] # networks: [ kbnet ]
# ------------------------------------------------------------------- profiles: observability (opt-in)
# services:
# prometheus:
# image: prom/prometheus:v2.52
# profiles: ["obs"]
# networks: [kbnet]
# volumes: [./observability/prometheus.yml:/etc/prometheus/prometheus.yml:ro]
# restart: unless-stopped
# ports: ["127.0.0.1:9090:9090"]
# grafana:
# image: grafana/grafana:10
# profiles: ["obs"]
# networks: [kbnet]
# volumes: [./observability/grafana:/var/lib/grafana]
# restart: unless-stopped
# environment:
# GF_SECURITY_ADMIN_USER: admin
# GF_SECURITY_ADMIN_PASSWORD_FILE: /run/secrets/grafana_pw
# secrets: [grafana_pw]
# ports: ["127.0.0.1:3000:3000"]
# secrets: # extra secret for Grafana, not needed otherwise
# grafana_pw:
# file: ./secrets/grafana_pw.txt

View File

@ -1,24 +1,39 @@
FROM python:3.10-slim FROM python:3.11-slim
WORKDIR /app WORKDIR /app
# Copy requirements files and setup # Install build dependencies
COPY requirements-*.txt ./ RUN apt-get update && apt-get install -y \
COPY setup.py ./ gcc \
COPY src/ ./src/ g++ \
python3-dev \
&& rm -rf /var/lib/apt/lists/*
# Install the package with API dependencies # Copy requirements files and setup
COPY requirements ./requirements/
RUN mkdir src
COPY setup.py ./
# Do an initial install to get the dependencies cached
RUN pip install -e ".[api]"
# Install the package with common dependencies
COPY src/ ./src/
RUN pip install -e ".[api]" RUN pip install -e ".[api]"
# Run as non-root user # Run as non-root user
RUN useradd -m appuser RUN useradd -m appuser
USER appuser RUN mkdir -p /app/memory_files
ENV PYTHONPATH="/app"
# Create user and set permissions
RUN useradd -m kb
RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
chown -R kb:kb /var/cache/fontconfig /home/kb/.cache/fontconfig /app
USER kb
# Set environment variables # Set environment variables
ENV PORT=8000 ENV PORT=8000
ENV PYTHONPATH="/app"
EXPOSE 8000 EXPOSE 8000
# Run the API
CMD ["uvicorn", "memory.api.app:app", "--host", "0.0.0.0", "--port", "8000"] CMD ["uvicorn", "memory.api.app:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@ -5,17 +5,16 @@ WORKDIR /app
# Install dependencies # Install dependencies
RUN apt-get update && apt-get install -y \ RUN apt-get update && apt-get install -y \
libpq-dev gcc supervisor && \ libpq-dev gcc supervisor && \
pip install -e ".[workers]" && \
apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/* apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
COPY requirements-*.txt ./
RUN pip install --no-cache-dir -r requirements-common.txt
RUN pip install --no-cache-dir -r requirements-parsers.txt
RUN pip install --no-cache-dir -r requirements-workers.txt
# Copy requirements files and setup # Copy requirements files and setup
COPY requirements ./requirements/
COPY setup.py ./ COPY setup.py ./
RUN mkdir src
RUN pip install -e ".[common]"
COPY src/ ./src/ COPY src/ ./src/
RUN pip install -e ".[common]"
# Create and copy entrypoint script # Create and copy entrypoint script
COPY docker/workers/entry.sh ./entry.sh COPY docker/workers/entry.sh ./entry.sh

View File

@ -0,0 +1,28 @@
FROM python:3.11-slim
WORKDIR /app
# Copy requirements files and setup
COPY requirements ./requirements/
COPY setup.py ./
RUN mkdir src
RUN pip install -e ".[common]"
# Install the package with common dependencies
COPY src/ ./src/
RUN pip install -e ".[common]"
# Run as non-root user
RUN useradd -m appuser
RUN mkdir -p /app/memory_files
ENV PYTHONPATH="/app"
# Create user and set permissions
RUN useradd -m kb
RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
chown -R kb:kb /var/cache/fontconfig /home/kb/.cache/fontconfig /app
USER kb
# Run the migrations
CMD ["alembic", "-c", "/app/db/migrations/alembic.ini", "upgrade", "head"]

View File

@ -13,13 +13,12 @@ RUN apt-get update && apt-get install -y \
# libreoffice-writer \ # libreoffice-writer \
&& apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/* && apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
COPY requirements-*.txt ./ COPY requirements ./requirements/
RUN pip install --no-cache-dir -r requirements-common.txt COPY setup.py ./
RUN pip install --no-cache-dir -r requirements-parsers.txt RUN mkdir src
RUN pip install --no-cache-dir -r requirements-workers.txt RUN pip install -e ".[common]"
# Install Python dependencies # Install Python dependencies
COPY setup.py ./
COPY src/ ./src/ COPY src/ ./src/
RUN pip install -e ".[workers]" RUN pip install -e ".[workers]"

View File

@ -1,7 +1,9 @@
#!/usr/bin/env bash #!/usr/bin/env bash
set -euo pipefail set -euo pipefail
QUEUE_PREFIX=${QUEUE_PREFIX:-memory}
QUEUES=${QUEUES:-default} QUEUES=${QUEUES:-default}
QUEUES=$(IFS=,; echo "${QUEUES}" | tr ',' '\n' | sed "s/^/${QUEUE_PREFIX}-/" | paste -sd, -)
CONCURRENCY=${CONCURRENCY:-2} CONCURRENCY=${CONCURRENCY:-2}
LOGLEVEL=${LOGLEVEL:-INFO} LOGLEVEL=${LOGLEVEL:-INFO}

View File

@ -1,5 +0,0 @@
PyMuPDF==1.25.5
ebooklib==0.18.0
beautifulsoup4==4.13.4
markdownify==0.13.1
pillow==10.4.0

View File

@ -1,5 +0,0 @@
openai==1.25.0
pillow==10.4.0
pypandoc==1.15.0
beautifulsoup4==4.13.4
feedparser==6.0.10

View File

@ -2,6 +2,6 @@ fastapi==0.112.2
uvicorn==0.29.0 uvicorn==0.29.0
python-jose==3.3.0 python-jose==3.3.0
python-multipart==0.0.9 python-multipart==0.0.9
sqladmin sqladmin==0.20.1
mcp==1.9.2 mcp==1.9.2
bm25s[full]==0.2.13 bm25s[full]==0.2.13

View File

@ -1,11 +1,12 @@
sqlalchemy==2.0.30 sqlalchemy==2.0.30
psycopg2-binary==2.9.9 psycopg2-binary==2.9.9
pydantic==2.7.1 pydantic==2.7.2
alembic==1.13.1 alembic==1.13.1
dotenv==0.9.9 dotenv==0.9.9
voyageai==0.3.2 voyageai==0.3.2
qdrant-client==1.9.0 qdrant-client==1.9.0
anthropic==0.18.1 anthropic==0.18.1
openai==1.25.0
# Pin the httpx version, as newer versions break the anthropic client # Pin the httpx version, as newer versions break the anthropic client
httpx==0.27.0 httpx==0.27.0
celery==5.3.6 celery[sqs]==5.3.6

View File

@ -0,0 +1,7 @@
PyMuPDF==1.25.5
EbookLib==0.18
beautifulsoup4==4.13.4
markdownify==0.13.1
pillow==10.3.0
pypandoc==1.15
feedparser==6.0.10

View File

@ -4,7 +4,7 @@ from setuptools import setup, find_namespace_packages
def read_requirements(filename: str) -> list[str]: def read_requirements(filename: str) -> list[str]:
"""Read requirements from file, ignoring comments and -r directives.""" """Read requirements from file, ignoring comments and -r directives."""
path = pathlib.Path(filename) path = pathlib.Path(__file__).parent / "requirements" / filename
return [ return [
line.strip() line.strip()
for line in path.read_text().splitlines() for line in path.read_text().splitlines()
@ -16,7 +16,6 @@ def read_requirements(filename: str) -> list[str]:
common_requires = read_requirements("requirements-common.txt") common_requires = read_requirements("requirements-common.txt")
parsers_requires = read_requirements("requirements-parsers.txt") parsers_requires = read_requirements("requirements-parsers.txt")
api_requires = read_requirements("requirements-api.txt") api_requires = read_requirements("requirements-api.txt")
workers_requires = read_requirements("requirements-workers.txt")
dev_requires = read_requirements("requirements-dev.txt") dev_requires = read_requirements("requirements-dev.txt")
setup( setup(
@ -26,14 +25,9 @@ setup(
packages=find_namespace_packages(where="src"), packages=find_namespace_packages(where="src"),
python_requires=">=3.10", python_requires=">=3.10",
extras_require={ extras_require={
"api": api_requires + common_requires, "api": api_requires + common_requires + parsers_requires,
"workers": workers_requires + common_requires + parsers_requires, "common": common_requires + parsers_requires,
"common": common_requires,
"dev": dev_requires, "dev": dev_requires,
"all": api_requires "all": api_requires + common_requires + dev_requires + parsers_requires,
+ workers_requires
+ common_requires
+ dev_requires
+ parsers_requires,
}, },
) )

View File

@ -182,6 +182,24 @@ async def observe(
Use proactively when user expresses preferences, behaviors, beliefs, or contradictions. Use proactively when user expresses preferences, behaviors, beliefs, or contradictions.
Be specific and detailed - observations should make sense months later. Be specific and detailed - observations should make sense months later.
Example call:
```
{
"observations": [
{
"content": "The user is a software engineer.",
"subject": "user",
"observation_type": "belief",
"confidences": {"observation_accuracy": 0.9},
"evidence": {"quote": "I am a software engineer.", "context": "I work at Google."},
"tags": ["programming", "work"]
}
],
"session_id": "123e4567-e89b-12d3-a456-426614174000",
"agent_model": "gpt-4o"
}
```
RawObservation fields: RawObservation fields:
content (required): Detailed observation text explaining what you observed content (required): Detailed observation text explaining what you observed
subject (required): Consistent identifier like "programming_style", "work_habits" subject (required): Consistent identifier like "programming_style", "work_habits"
@ -200,7 +218,7 @@ async def observe(
observation, observation,
celery_app.send_task( celery_app.send_task(
SYNC_OBSERVATION, SYNC_OBSERVATION,
queue="notes", queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
kwargs={ kwargs={
"subject": observation.subject, "subject": observation.subject,
"content": observation.content, "content": observation.content,
@ -323,7 +341,7 @@ async def create_note(
try: try:
task = celery_app.send_task( task = celery_app.send_task(
SYNC_NOTE, SYNC_NOTE,
queue="notes", queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
kwargs={ kwargs={
"subject": subject, "subject": subject,
"content": content, "content": content,

View File

@ -20,6 +20,7 @@ from memory.common.db.models import (
ForumPost, ForumPost,
AgentObservation, AgentObservation,
Note, Note,
User,
) )
@ -207,6 +208,15 @@ class NoteAdmin(ModelView, model=Note):
column_sortable_list = ["inserted_at"] column_sortable_list = ["inserted_at"]
class UserAdmin(ModelView, model=User):
column_list = [
"id",
"email",
"name",
"created_at",
]
def setup_admin(admin: Admin): def setup_admin(admin: Admin):
"""Add all admin views to the admin instance.""" """Add all admin views to the admin instance."""
admin.add_view(SourceItemAdmin) admin.add_view(SourceItemAdmin)
@ -224,3 +234,4 @@ def setup_admin(admin: Admin):
admin.add_view(ForumPostAdmin) admin.add_view(ForumPostAdmin)
admin.add_view(ComicAdmin) admin.add_view(ComicAdmin)
admin.add_view(PhotoAdmin) admin.add_view(PhotoAdmin)
admin.add_view(UserAdmin)

View File

@ -8,16 +8,30 @@ import pathlib
import logging import logging
from typing import Annotated, Optional from typing import Annotated, Optional
from fastapi import FastAPI, HTTPException, File, UploadFile, Query, Form from fastapi import (
FastAPI,
HTTPException,
File,
UploadFile,
Query,
Form,
Depends,
)
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from sqladmin import Admin from sqladmin import Admin
from memory.common import settings from memory.common import settings
from memory.common import extract from memory.common import extract
from memory.common.db.connection import get_engine from memory.common.db.connection import get_engine
from memory.common.db.models import User
from memory.api.admin import setup_admin from memory.api.admin import setup_admin
from memory.api.search import search, SearchResult from memory.api.search import search, SearchResult
from memory.api.MCP.tools import mcp from memory.api.MCP.tools import mcp
from memory.api.auth import (
router as auth_router,
get_current_user,
AuthenticationMiddleware,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,6 +49,10 @@ app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
engine = get_engine() engine = get_engine()
admin = Admin(app, engine) admin = Admin(app, engine)
setup_admin(admin) setup_admin(admin)
# Include auth router
app.add_middleware(AuthenticationMiddleware)
app.include_router(auth_router)
app.mount("/", mcp.streamable_http_app()) app.mount("/", mcp.streamable_http_app())
@ -63,6 +81,7 @@ async def search_endpoint(
limit: int = Query(10, ge=1, le=100), limit: int = Query(10, ge=1, le=100),
min_text_score: float = Query(0.3, ge=0.0, le=1.0), min_text_score: float = Query(0.3, ge=0.0, le=1.0),
min_multimodal_score: float = Query(0.3, ge=0.0, le=1.0), min_multimodal_score: float = Query(0.3, ge=0.0, le=1.0),
current_user: User = Depends(get_current_user),
): ):
"""Search endpoint - delegates to search module""" """Search endpoint - delegates to search module"""
upload_data = [ upload_data = [
@ -74,7 +93,7 @@ async def search_endpoint(
return await search( return await search(
upload_data, upload_data,
previews=previews, previews=previews,
modalities=modalities, modalities=set(modalities),
limit=limit, limit=limit,
min_text_score=min_text_score, min_text_score=min_text_score,
min_multimodal_score=min_multimodal_score, min_multimodal_score=min_multimodal_score,
@ -82,7 +101,7 @@ async def search_endpoint(
@app.get("/files/{path:path}") @app.get("/files/{path:path}")
def get_file_by_path(path: str): def get_file_by_path(path: str, current_user: User = Depends(get_current_user)):
""" """
Fetch a file by its path Fetch a file by its path

228
src/memory/api/auth.py Normal file
View File

@ -0,0 +1,228 @@
from datetime import datetime, timedelta, timezone
import textwrap
from typing import cast
import logging
from fastapi import HTTPException, Depends, Request, Response, APIRouter, Form
from fastapi.responses import HTMLResponse
from starlette.middleware.base import BaseHTTPMiddleware
from memory.common import settings
from sqlalchemy.orm import Session as DBSession, scoped_session
from pydantic import BaseModel
from memory.common.db.connection import get_session, make_session
from memory.common.db.models.users import User, UserSession
logger = logging.getLogger(__name__)
# Pydantic models
class LoginRequest(BaseModel):
email: str
password: str
class RegisterRequest(BaseModel):
email: str
password: str
name: str
class LoginResponse(BaseModel):
session_id: str
user_id: int
email: str
name: str
# Create router
router = APIRouter(prefix="/auth", tags=["auth"])
def create_user_session(
user_id: int, db: DBSession, valid_for: int = settings.SESSION_VALID_FOR
) -> str:
"""Create a new session for a user"""
expires_at = datetime.now(timezone.utc) + timedelta(days=valid_for)
session = UserSession(user_id=user_id, expires_at=expires_at)
db.add(session)
db.commit()
return str(session.id)
def get_session_user(session_id: str, db: DBSession | scoped_session) -> User | None:
"""Get user from session ID if session is valid"""
session = db.query(UserSession).get(session_id)
if not session:
return None
now = datetime.now(timezone.utc)
if session.expires_at.replace(tzinfo=timezone.utc) > now:
return session.user
return None
def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> User:
"""FastAPI dependency to get current authenticated user"""
# Check for session ID in header or cookie
session_id = request.headers.get(
settings.SESSION_HEADER_NAME
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
if not session_id:
raise HTTPException(status_code=401, detail="No session provided")
user = get_session_user(session_id, db)
if not user:
raise HTTPException(status_code=401, detail="Invalid or expired session")
return user
def create_user(email: str, password: str, name: str, db: DBSession) -> User:
"""Create a new user"""
# Check if user already exists
existing_user = db.query(User).filter(User.email == email).first()
if existing_user:
raise HTTPException(status_code=400, detail="User already exists")
user = User.create_with_password(email, name, password)
db.add(user)
db.commit()
db.refresh(user)
return user
def authenticate_user(email: str, password: str, db: DBSession) -> User | None:
"""Authenticate a user by email and password"""
user = db.query(User).filter(User.email == email).first()
if user and user.is_valid_password(password):
return user
return None
# Auth endpoints
@router.post("/register", response_model=LoginResponse)
def register(request: RegisterRequest, db: DBSession = Depends(get_session)):
"""Register a new user"""
if not settings.REGISTER_ENABLED:
raise HTTPException(status_code=403, detail="Registration is disabled")
user = create_user(request.email, request.password, request.name, db)
session_id = create_user_session(user.id, db) # type: ignore
return LoginResponse(session_id=session_id, **user.serialize())
@router.get("/login", response_model=LoginResponse)
def login_page():
"""Login page"""
return HTMLResponse(
content=textwrap.dedent("""
<html>
<body>
<h1>Login</h1>
<form method="post" action="/auth/login-form">
<input type="email" name="email" placeholder="Email" />
<input type="password" name="password" placeholder="Password" />
<button type="submit">Login</button>
</form>
</body>
</html>
"""),
)
@router.post("/login", response_model=LoginResponse)
def login(
request: LoginRequest, response: Response, db: DBSession = Depends(get_session)
):
"""Login and create a session"""
return login_form(response, db, request.email, request.password)
@router.post("/login-form", response_model=LoginResponse)
def login_form(
response: Response,
db: DBSession = Depends(get_session),
email: str = Form(),
password: str = Form(),
):
"""Login with form data and create a session"""
user = authenticate_user(email, password, db)
if not user:
raise HTTPException(status_code=401, detail="Invalid credentials")
session_id = create_user_session(cast(int, user.id), db)
# Set session cookie
response.set_cookie(
key=settings.SESSION_COOKIE_NAME,
value=session_id,
httponly=True,
secure=settings.HTTPS,
samesite="lax",
max_age=settings.SESSION_COOKIE_MAX_AGE,
)
return LoginResponse(session_id=session_id, **user.serialize())
@router.post("/logout")
def logout(response: Response, user: User = Depends(get_current_user)):
"""Logout and clear session"""
response.delete_cookie(settings.SESSION_COOKIE_NAME)
return {"message": "Logged out successfully"}
@router.get("/me")
def get_me(user: User = Depends(get_current_user)):
"""Get current user info"""
return user.serialize()
class AuthenticationMiddleware(BaseHTTPMiddleware):
"""Middleware to require authentication for all endpoints except whitelisted ones."""
# Endpoints that don't require authentication
WHITELIST = {
"/health",
"/auth/login",
"/auth/login-form",
"/auth/register",
}
async def dispatch(self, request: Request, call_next):
path = request.url.path
# Skip authentication for whitelisted endpoints
if any(path.startswith(whitelist_path) for whitelist_path in self.WHITELIST):
return await call_next(request)
# Check for session ID in header or cookie
session_id = request.headers.get(
settings.SESSION_HEADER_NAME
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
if not session_id:
return Response(
content="Authentication required",
status_code=401,
headers={"WWW-Authenticate": "Bearer"},
)
# Validate session and get user
with make_session() as session:
user = get_session_user(session_id, session)
if not user:
return Response(
content="Invalid or expired session",
status_code=401,
headers={"WWW-Authenticate": "Bearer"},
)
logger.debug(f"Authenticated request from user {user.email} to {path}")
return await call_next(request)

View File

@ -1,4 +1,5 @@
from celery import Celery from celery import Celery
from kombu.utils.url import safequote
from memory.common import settings from memory.common import settings
EMAIL_ROOT = "memory.workers.tasks.email" EMAIL_ROOT = "memory.workers.tasks.email"
@ -41,13 +42,17 @@ SYNC_ALL_ARTICLE_FEEDS = f"{BLOGS_ROOT}.sync_all_article_feeds"
SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive" SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive"
def rabbit_url() -> str: def get_broker_url() -> str:
return f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASSWORD}@{settings.RABBITMQ_HOST}:{settings.RABBITMQ_PORT}//" protocol = settings.CELERY_BROKER_TYPE
user = safequote(settings.CELERY_BROKER_USER)
password = safequote(settings.CELERY_BROKER_PASSWORD)
host = settings.CELERY_BROKER_HOST
return f"{protocol}://{user}:{password}@{host}"
app = Celery( app = Celery(
"memory", "memory",
broker=rabbit_url(), broker=get_broker_url(),
backend=settings.CELERY_RESULT_BACKEND, backend=settings.CELERY_RESULT_BACKEND,
) )
@ -59,20 +64,22 @@ app.conf.update(
task_reject_on_worker_lost=True, task_reject_on_worker_lost=True,
worker_prefetch_multiplier=1, worker_prefetch_multiplier=1,
task_routes={ task_routes={
f"{EMAIL_ROOT}.*": {"queue": "email"}, f"{EMAIL_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-email"},
f"{PHOTO_ROOT}.*": {"queue": "photo_embed"}, f"{PHOTO_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-photo-embed"},
f"{COMIC_ROOT}.*": {"queue": "comic"}, f"{COMIC_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-comic"},
f"{EBOOK_ROOT}.*": {"queue": "ebooks"}, f"{EBOOK_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-ebooks"},
f"{BLOGS_ROOT}.*": {"queue": "blogs"}, f"{BLOGS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-blogs"},
f"{FORUMS_ROOT}.*": {"queue": "forums"}, f"{FORUMS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-forums"},
f"{MAINTENANCE_ROOT}.*": {"queue": "maintenance"}, f"{MAINTENANCE_ROOT}.*": {
f"{NOTES_ROOT}.*": {"queue": "notes"}, "queue": f"{settings.CELERY_QUEUE_PREFIX}-maintenance"
f"{OBSERVATIONS_ROOT}.*": {"queue": "notes"}, },
f"{NOTES_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
f"{OBSERVATIONS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
}, },
) )
@app.on_after_configure.connect # type: ignore @app.on_after_configure.connect # type: ignore[attr-defined]
def ensure_qdrant_initialised(sender, **_): def ensure_qdrant_initialised(sender, **_):
from memory.common import qdrant from memory.common import qdrant

View File

@ -3,8 +3,9 @@ Database connection utilities.
""" """
from contextlib import contextmanager from contextlib import contextmanager
from typing import Generator
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.orm import sessionmaker, scoped_session, Session
from memory.common import settings from memory.common import settings
@ -23,9 +24,17 @@ def get_session_factory():
def get_scoped_session(): def get_scoped_session():
"""Create a thread-local scoped session factory""" """Create a thread-local scoped session factory"""
engine = get_engine() return scoped_session(get_session_factory())
session_factory = sessionmaker(bind=engine)
return scoped_session(session_factory)
def get_session() -> Generator[Session, None, None]:
"""FastAPI dependency for database sessions"""
session_factory = get_session_factory()
session = session_factory()
try:
yield session
finally:
session.close()
@contextmanager @contextmanager

View File

@ -32,6 +32,10 @@ from memory.common.db.models.sources import (
ArticleFeed, ArticleFeed,
EmailAccount, EmailAccount,
) )
from memory.common.db.models.users import (
User,
UserSession,
)
__all__ = [ __all__ = [
"Base", "Base",
@ -62,4 +66,7 @@ __all__ = [
"Book", "Book",
"ArticleFeed", "ArticleFeed",
"EmailAccount", "EmailAccount",
# Users
"User",
"UserSession",
] ]

View File

@ -0,0 +1,65 @@
import hashlib
import secrets
from typing import cast
import uuid
from memory.common.db.models.base import Base
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
def hash_password(password: str) -> str:
"""Hash a password using SHA-256 with salt"""
salt = secrets.token_hex(16)
return f"{salt}:{hashlib.sha256((salt + password).encode()).hexdigest()}"
def verify_password(password: str, password_hash: str) -> bool:
"""Verify a password against its hash"""
try:
salt, hash_value = password_hash.split(":", 1)
return hashlib.sha256((salt + password).encode()).hexdigest() == hash_value
except ValueError:
return False
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
name = Column(String, nullable=False)
email = Column(String, nullable=False, unique=True)
password_hash = Column(String, nullable=False)
# Relationship to sessions
sessions = relationship(
"UserSession", back_populates="user", cascade="all, delete-orphan"
)
def serialize(self) -> dict:
return {
"user_id": self.id,
"name": self.name,
"email": self.email,
}
def is_valid_password(self, password: str) -> bool:
"""Check if the provided password is valid for this user"""
return verify_password(password, cast(str, self.password_hash))
@classmethod
def create_with_password(cls, email: str, name: str, password: str) -> "User":
"""Create a new user with a hashed password"""
return cls(email=email, name=name, password_hash=hash_password(password))
class UserSession(Base):
__tablename__ = "user_sessions"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
created_at = Column(DateTime, server_default=func.now())
expires_at = Column(DateTime, nullable=False)
# Relationship to user
user = relationship("User", back_populates="sessions")

View File

@ -29,11 +29,18 @@ def make_db_url(
DB_URL = os.getenv("DATABASE_URL", make_db_url()) DB_URL = os.getenv("DATABASE_URL", make_db_url())
# Celery settings
RABBITMQ_USER = os.getenv("RABBITMQ_USER", "kb") # Broker settings
RABBITMQ_PASSWORD = os.getenv("RABBITMQ_PASSWORD", "kb") CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory")
RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq") CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "amqp").lower() # amqp or sqs
RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672") CELERY_BROKER_USER = os.getenv("CELERY_BROKER_USER", "kb")
CELERY_BROKER_PASSWORD = os.getenv("CELERY_BROKER_PASSWORD", "kb")
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "")
if not CELERY_BROKER_HOST and CELERY_BROKER_TYPE == "amqp":
RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672")
CELERY_BROKER_HOST = f"{RABBITMQ_HOST}:{RABBITMQ_PORT}//"
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}") CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
@ -122,3 +129,12 @@ if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):
else: else:
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "") ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307") SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
# API settings
HTTPS = boolean_env("HTTPS", False)
SESSION_HEADER_NAME = os.getenv("SESSION_HEADER_NAME", "X-Session-ID")
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")
SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 * 60))
SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))
REGISTER_ENABLED = boolean_env("REGISTER_ENABLED", False) or True

20
tools/add_user.py Normal file
View File

@ -0,0 +1,20 @@
#! /usr/bin/env python
import argparse
from memory.common.db.connection import make_session
from memory.common.db.models.users import User
if __name__ == "__main__":
args = argparse.ArgumentParser()
args.add_argument("--email", type=str, required=True)
args.add_argument("--password", type=str, required=True)
args.add_argument("--name", type=str, required=True)
args = args.parse_args()
with make_session() as session:
user = User(email=args.email, password=args.password, name=args.name)
session.add(user)
session.commit()
print(f"User {args.email} created")

120
tools/simple_proxy.py Normal file
View File

@ -0,0 +1,120 @@
#!/usr/bin/env python3
import argparse
import httpx
import uvicorn
from pydantic import BaseModel
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import Response
class State(BaseModel):
email: str
password: str
remote_server: str
session_header: str = "X-Session-ID"
session_id: str | None = None
port: int = 8080
def parse_args() -> State:
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="Simple HTTP proxy with authentication"
)
parser.add_argument("--remote-server", required=True, help="Remote server URL")
parser.add_argument("--email", required=True, help="Email for authentication")
parser.add_argument("--password", required=True, help="Password for authentication")
parser.add_argument(
"--session-header", default="X-Session-ID", help="Session header name"
)
parser.add_argument("--port", type=int, default=8080, help="Port to run proxy on")
return State(**vars(parser.parse_args()))
state = parse_args()
async def login() -> None:
"""Login to remote server and store session ID"""
login_url = f"{state.remote_server}/auth/login"
login_data = {"email": state.email, "password": state.password}
async with httpx.AsyncClient() as client:
try:
response = await client.post(login_url, json=login_data)
response.raise_for_status()
login_response = response.json()
state.session_id = login_response["session_id"]
print(f"Successfully logged in, session ID: {state.session_id}")
except httpx.HTTPStatusError as e:
print(
f"Login failed with status {e.response.status_code}: {e.response.text}"
)
raise
except Exception as e:
print(f"Login failed: {e}")
raise
async def proxy_request(request: Request) -> Response:
"""Proxy request to remote server with session header"""
if not state.session_id:
try:
await login()
except Exception as e:
print(f"Login failed: {e}")
raise HTTPException(status_code=401, detail="Unauthorized")
# Build the target URL
target_url = f"{state.remote_server}{request.url.path}"
if request.url.query:
target_url += f"?{request.url.query}"
# Get request body
body = await request.body()
headers = dict(request.headers)
async with httpx.AsyncClient() as client:
try:
response = await client.request(
method=request.method,
url=target_url,
headers=headers | {state.session_header: state.session_id}, # type: ignore
content=body,
timeout=30.0,
)
# Forward response
return Response(
content=response.content,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.headers.get("content-type"),
)
except httpx.RequestError as e:
print(f"Request failed: {e}")
raise HTTPException(status_code=502, detail=f"Proxy request failed: {e}")
# Create FastAPI app
app = FastAPI(title="Simple Proxy")
@app.api_route(
"/{path:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
)
async def proxy_all(request: Request):
"""Proxy all requests to remote server"""
return await proxy_request(request)
if __name__ == "__main__":
print(f"Starting proxy server on port {state.port}")
print(f"Proxying to: {state.remote_server}")
uvicorn.run(app, host="0.0.0.0", port=state.port)