add PRs and People

mruwnik 2025-12-24 13:25:34 +00:00
parent efbc469ee3
commit 47629fc5fb
23 changed files with 4902 additions and 94 deletions

.gitignore
View File

@@ -1,5 +1,9 @@
Books
books.md
clean_books
scripts
CLAUDE.md
memory_files
venv

View File

@@ -0,0 +1,56 @@
"""Add people tracking
Revision ID: c9d0e1f2a3b4
Revises: b7c8d9e0f1a2
Create Date: 2025-12-24 12:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "c9d0e1f2a3b4"
down_revision: Union[str, None] = "b7c8d9e0f1a2"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"people",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("identifier", sa.Text(), nullable=False),
sa.Column("display_name", sa.Text(), nullable=False),
sa.Column(
"aliases",
postgresql.ARRAY(sa.Text()),
server_default="{}",
nullable=False,
),
sa.Column(
"contact_info",
postgresql.JSONB(astext_type=sa.Text()),
server_default="{}",
nullable=False,
),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("identifier"),
)
op.create_index("person_identifier_idx", "people", ["identifier"], unique=False)
op.create_index("person_display_name_idx", "people", ["display_name"], unique=False)
op.create_index(
"person_aliases_idx", "people", ["aliases"], unique=False, postgresql_using="gin"
)
def downgrade() -> None:
op.drop_index("person_aliases_idx", table_name="people")
op.drop_index("person_display_name_idx", table_name="people")
op.drop_index("person_identifier_idx", table_name="people")
op.drop_table("people")

View File

@@ -0,0 +1,57 @@
"""Add github_pr_data table for PR-specific data
Revision ID: d0e1f2a3b4c5
Revises: c9d0e1f2a3b4
Create Date: 2025-12-24 15:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "d0e1f2a3b4c5"
down_revision: Union[str, None] = "c9d0e1f2a3b4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"github_pr_data",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("github_item_id", sa.BigInteger(), nullable=False),
# Diff stored compressed with zlib
sa.Column("diff_compressed", sa.LargeBinary(), nullable=True),
# File changes as structured data
# [{filename, status, additions, deletions, patch?}]
sa.Column("files", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
# Stats
sa.Column("additions", sa.Integer(), nullable=True),
sa.Column("deletions", sa.Integer(), nullable=True),
sa.Column("changed_files_count", sa.Integer(), nullable=True),
# Reviews - [{user, state, body, submitted_at}]
sa.Column("reviews", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
# Review comments (line-by-line code comments)
# [{user, body, path, line, diff_hunk, created_at}]
sa.Column(
"review_comments", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
sa.ForeignKeyConstraint(
["github_item_id"], ["github_item.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("github_item_id"),
)
op.create_index(
"github_pr_data_item_idx", "github_pr_data", ["github_item_id"], unique=True
)
def downgrade() -> None:
op.drop_index("github_pr_data_item_idx", table_name="github_pr_data")
op.drop_table("github_pr_data")

View File

@@ -206,7 +206,7 @@ services:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "backup,blogs,comic,discord,ebooks,email,forums,github,photo_embed,maintenance,notes,scheduler"
QUEUES: "backup,blogs,comic,discord,ebooks,email,forums,github,people,photo_embed,maintenance,notes,scheduler"
ingest-hub:
<<: *worker-base

View File

@@ -44,7 +44,7 @@ RUN git config --global user.email "${GIT_USER_EMAIL}" && \
git config --global user.name "${GIT_USER_NAME}"
# Default queues to process
-ENV QUEUES="backup,blogs,comic,discord,ebooks,email,forums,github,photo_embed,maintenance"
ENV QUEUES="backup,blogs,comic,discord,ebooks,email,forums,github,people,photo_embed,maintenance"
ENV PYTHONPATH="/app"
ENTRYPOINT ["./entry.sh"]

View File

@@ -4,3 +4,5 @@ import memory.api.MCP.metadata
import memory.api.MCP.schedules
import memory.api.MCP.books
import memory.api.MCP.manifest
import memory.api.MCP.github
import memory.api.MCP.people

View File

@@ -0,0 +1,518 @@
"""MCP tools for GitHub issue tracking and management."""
import logging
from datetime import datetime, timezone
from typing import Any
from sqlalchemy import Text, case, desc, func
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
from memory.api.MCP.base import mcp
from memory.api.search.search import search
from memory.api.search.types import SearchConfig, SearchFilters
from memory.common import extract
from memory.common.db.connection import make_session
from memory.common.db.models import GithubItem
logger = logging.getLogger(__name__)
def _build_github_url(repo_path: str, number: int | None, kind: str) -> str:
"""Build GitHub URL from repo path and issue/PR number."""
if number is None:
return f"https://github.com/{repo_path}"
url_type = "pull" if kind == "pr" else "issues"
return f"https://github.com/{repo_path}/{url_type}/{number}"
def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[str, Any]:
"""Serialize a GithubItem to a dict for API response."""
result = {
"id": item.id,
"number": item.number,
"kind": item.kind,
"repo_path": item.repo_path,
"title": item.title,
"state": item.state,
"author": item.author,
"assignees": item.assignees or [],
"labels": item.labels or [],
"milestone": item.milestone,
"project_status": item.project_status,
"project_priority": item.project_priority,
"project_fields": item.project_fields,
"comment_count": item.comment_count,
"created_at": item.created_at.isoformat() if item.created_at else None,
"closed_at": item.closed_at.isoformat() if item.closed_at else None,
"merged_at": item.merged_at.isoformat() if item.merged_at else None,
"github_updated_at": (
item.github_updated_at.isoformat() if item.github_updated_at else None
),
"url": _build_github_url(item.repo_path, item.number, item.kind),
}
if include_content:
result["content"] = item.content
# Include PR-specific data if available
if item.kind == "pr" and item.pr_data:
result["pr_data"] = {
"additions": item.pr_data.additions,
"deletions": item.pr_data.deletions,
"changed_files_count": item.pr_data.changed_files_count,
"files": item.pr_data.files,
"reviews": item.pr_data.reviews,
"review_comments": item.pr_data.review_comments,
"diff": item.pr_data.diff, # Decompressed via property
}
return result
@mcp.tool()
async def list_github_issues(
repo: str | None = None,
assignee: str | None = None,
author: str | None = None,
state: str | None = None,
kind: str | None = None,
labels: list[str] | None = None,
project_status: str | None = None,
project_field: dict[str, str] | None = None,
updated_since: str | None = None,
updated_before: str | None = None,
limit: int = 50,
order_by: str = "updated",
) -> list[dict]:
"""
List GitHub issues and PRs with flexible filtering.
Use for daily triage, finding assigned issues, tracking stale issues, etc.
Args:
repo: Filter by repository path (e.g., "owner/name")
assignee: Filter by assignee username
author: Filter by author username
state: Filter by state: "open", "closed", "merged" (default: all)
kind: Filter by type: "issue" or "pr" (default: both)
labels: Filter by GitHub labels (matches ANY label in list)
project_status: Filter by project status (e.g., "In Progress", "Backlog")
project_field: Filter by project field values (e.g., {"EquiStamp.Client": "Redwood"})
updated_since: ISO date - only issues updated after this time
updated_before: ISO date - only issues updated before this (for finding stale issues)
limit: Maximum results (default 50, max 200)
order_by: Sort order: "updated", "created", or "number" (default: "updated" descending)
Returns: List of issues with id, number, title, state, assignees, labels, project_fields, timestamps, url
"""
logger.info(f"list_github_issues called: repo={repo}, assignee={assignee}, state={state}")
limit = min(limit, 200)
with make_session() as session:
query = session.query(GithubItem)
# Apply filters
if repo:
query = query.filter(GithubItem.repo_path == repo)
if assignee:
query = query.filter(GithubItem.assignees.any(assignee))
if author:
query = query.filter(GithubItem.author == author)
if state:
query = query.filter(GithubItem.state == state)
if kind:
query = query.filter(GithubItem.kind == kind)
else:
# Exclude comments by default, only show issues and PRs
query = query.filter(GithubItem.kind.in_(["issue", "pr"]))
if labels:
# Match any label in the list using PostgreSQL array overlap
query = query.filter(
GithubItem.labels.op("&&")(sql_cast(labels, ARRAY(Text)))
)
if project_status:
query = query.filter(GithubItem.project_status == project_status)
if project_field:
for key, value in project_field.items():
query = query.filter(
GithubItem.project_fields[key].astext == value
)
if updated_since:
since_dt = datetime.fromisoformat(updated_since.replace("Z", "+00:00"))
query = query.filter(GithubItem.github_updated_at >= since_dt)
if updated_before:
before_dt = datetime.fromisoformat(updated_before.replace("Z", "+00:00"))
query = query.filter(GithubItem.github_updated_at <= before_dt)
# Apply ordering
if order_by == "created":
query = query.order_by(desc(GithubItem.created_at))
elif order_by == "number":
query = query.order_by(desc(GithubItem.number))
else: # default: updated
query = query.order_by(desc(GithubItem.github_updated_at))
query = query.limit(limit)
items = query.all()
return [_serialize_issue(item) for item in items]
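A hypothetical triage call using only the parameters defined above (repo name and date are illustrative):

# Open bugs in one repo that haven't moved in a week -- a stale-issue sweep
stale = await list_github_issues(
    repo="owner/name",
    state="open",
    labels=["bug"],
    updated_before="2025-12-17",
)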
@mcp.tool()
async def search_github_issues(
query: str,
repo: str | None = None,
state: str | None = None,
kind: str | None = None,
limit: int = 20,
) -> list[dict]:
"""
Search GitHub issues using natural language.
Searches across issue titles, bodies, and comments.
Args:
query: Natural language search query (e.g., "authentication bug", "database migration")
repo: Optional filter by repository path
state: Optional filter: "open", "closed", "merged"
kind: Optional filter: "issue" or "pr"
limit: Maximum results (default 20, max 100)
Returns: List of matching issues with search score
"""
logger.info(f"search_github_issues called: query={query}, repo={repo}")
limit = min(limit, 100)
# Pre-filter source_ids if repo/state/kind filters are specified
source_ids = None
if repo or state or kind:
with make_session() as session:
q = session.query(GithubItem.id)
if repo:
q = q.filter(GithubItem.repo_path == repo)
if state:
q = q.filter(GithubItem.state == state)
if kind:
q = q.filter(GithubItem.kind == kind)
else:
q = q.filter(GithubItem.kind.in_(["issue", "pr"]))
source_ids = [item.id for item in q.all()]
# Use the existing search infrastructure
data = extract.extract_text(query, skip_summary=True)
config = SearchConfig(limit=limit, previews=True)
filters = SearchFilters()
if source_ids is not None:
filters["source_ids"] = source_ids
results = await search(
data,
modalities={"github"},
filters=filters,
config=config,
)
# Fetch full issue details for the results
output = []
with make_session() as session:
for result in results:
item = session.get(GithubItem, result.id)
if item:
serialized = _serialize_issue(item)
serialized["search_score"] = result.score
output.append(serialized)
return output
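Called from an async context, usage is just a natural-language query plus optional filters (values illustrative):

hits = await search_github_issues(
    query="authentication bug in login flow",
    repo="owner/name",
    kind="issue",
)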
@mcp.tool()
async def github_issue_details(
repo: str,
number: int,
) -> dict:
"""
Get full details of a specific GitHub issue or PR including all comments.
Args:
repo: Repository path (e.g., "owner/name")
number: Issue or PR number
Returns: Full issue details including content (body + comments), project fields, timestamps.
For PRs, also includes pr_data with: diff (full), files changed, reviews, review comments.
"""
logger.info(f"github_issue_details called: repo={repo}, number={number}")
with make_session() as session:
item = (
session.query(GithubItem)
.filter(
GithubItem.repo_path == repo,
GithubItem.number == number,
GithubItem.kind.in_(["issue", "pr"]),
)
.first()
)
if not item:
raise ValueError(f"Issue #{number} not found in {repo}")
return _serialize_issue(item, include_content=True)
@mcp.tool()
async def github_work_summary(
since: str,
until: str | None = None,
group_by: str = "client",
repo: str | None = None,
) -> dict:
"""
Summarize GitHub work activity for billing and time tracking.
Groups issues by client, author, status, or repository.
Args:
since: ISO date - start of period (e.g., "2025-12-16")
until: ISO date - end of period (default: now)
group_by: How to group results: "client", "status", "author", "repo", "task_type"
repo: Optional filter by repository path
Returns: Summary with grouped counts and sample issues for each group
"""
logger.info(f"github_work_summary called: since={since}, group_by={group_by}")
since_dt = datetime.fromisoformat(since.replace("Z", "+00:00"))
if until:
until_dt = datetime.fromisoformat(until.replace("Z", "+00:00"))
else:
until_dt = datetime.now(timezone.utc)
# Map group_by to SQL expression
group_mappings = {
"client": GithubItem.project_fields["EquiStamp.Client"].astext,
"status": GithubItem.project_status,
"author": GithubItem.author,
"repo": GithubItem.repo_path,
"task_type": GithubItem.project_fields["EquiStamp.Task Type"].astext,
}
if group_by not in group_mappings:
raise ValueError(
f"Invalid group_by: {group_by}. Must be one of: {list(group_mappings.keys())}"
)
group_col = group_mappings[group_by]
with make_session() as session:
# Build base query for the period
base_query = session.query(GithubItem).filter(
GithubItem.github_updated_at >= since_dt,
GithubItem.github_updated_at <= until_dt,
GithubItem.kind.in_(["issue", "pr"]),
)
if repo:
base_query = base_query.filter(GithubItem.repo_path == repo)
# Get aggregated counts by group
agg_query = (
session.query(
group_col.label("group_name"),
func.count(GithubItem.id).label("total"),
func.count(case((GithubItem.kind == "issue", 1))).label("issue_count"),
func.count(case((GithubItem.kind == "pr", 1))).label("pr_count"),
func.count(
case((GithubItem.state.in_(["closed", "merged"]), 1))
).label("closed_count"),
)
.filter(
GithubItem.github_updated_at >= since_dt,
GithubItem.github_updated_at <= until_dt,
GithubItem.kind.in_(["issue", "pr"]),
)
.group_by(group_col)
.order_by(desc("total"))
)
if repo:
agg_query = agg_query.filter(GithubItem.repo_path == repo)
groups = agg_query.all()
# Build summary with sample issues for each group
summary = []
total_issues = 0
total_prs = 0
for group_name, total, issue_count, pr_count, closed_count in groups:
    total_issues += issue_count
    total_prs += pr_count
    # Get sample issues for this group; NULL groups need is_(None),
    # since comparing against the "(unset)" placeholder would match nothing
    if group_name is None:
        group_name = "(unset)"
        sample_query = base_query.filter(group_col.is_(None)).limit(5)
    else:
        sample_query = base_query.filter(group_col == group_name).limit(5)
samples = [
{
"number": item.number,
"title": item.title,
"repo_path": item.repo_path,
"state": item.state,
"url": _build_github_url(item.repo_path, item.number, item.kind),
}
for item in sample_query.all()
]
summary.append(
{
"group": group_name,
"total": total,
"issue_count": issue_count,
"pr_count": pr_count,
"closed_count": closed_count,
"issues": samples,
}
)
return {
"period": {
"since": since_dt.isoformat(),
"until": until_dt.isoformat(),
},
"group_by": group_by,
"summary": summary,
"total_issues": total_issues,
"total_prs": total_prs,
}
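For example, a weekly billing pass might look like this (dates illustrative):

report = await github_work_summary(since="2025-12-16", group_by="client")
for group in report["summary"]:
    print(group["group"], group["total"], group["closed_count"])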
@mcp.tool()
async def github_repo_overview(
repo: str,
) -> dict:
"""
Get an overview of a GitHub repository's issues and PRs.
Shows counts, status breakdown, top assignees, and labels.
Args:
repo: Repository path (e.g., "EquiStamp/equistamp" or "owner/name")
Returns: Repository statistics including counts, status breakdown, top assignees, labels
"""
logger.info(f"github_repo_overview called: repo={repo}")
with make_session() as session:
# Base query for this repo
base_query = session.query(GithubItem).filter(
GithubItem.repo_path == repo,
GithubItem.kind.in_(["issue", "pr"]),
)
# Get total counts
counts_query = session.query(
func.count(GithubItem.id).label("total"),
func.count(case((GithubItem.kind == "issue", 1))).label("total_issues"),
func.count(
case(((GithubItem.kind == "issue") & (GithubItem.state == "open"), 1))
).label("open_issues"),
func.count(
case(((GithubItem.kind == "issue") & (GithubItem.state == "closed"), 1))
).label("closed_issues"),
func.count(case((GithubItem.kind == "pr", 1))).label("total_prs"),
func.count(
case(((GithubItem.kind == "pr") & (GithubItem.state == "open"), 1))
).label("open_prs"),
func.count(
case(((GithubItem.kind == "pr") & (GithubItem.merged_at.isnot(None)), 1))
).label("merged_prs"),
func.max(GithubItem.github_updated_at).label("last_updated"),
).filter(
GithubItem.repo_path == repo,
GithubItem.kind.in_(["issue", "pr"]),
)
counts = counts_query.first()
# Status breakdown (for project_status)
status_query = (
session.query(
GithubItem.project_status.label("status"),
func.count(GithubItem.id).label("count"),
)
.filter(
GithubItem.repo_path == repo,
GithubItem.kind.in_(["issue", "pr"]),
GithubItem.project_status.isnot(None),
)
.group_by(GithubItem.project_status)
.order_by(desc("count"))
)
status_breakdown = {row.status: row.count for row in status_query.all()}
# Top assignees (open issues only)
assignee_query = (
session.query(
func.unnest(GithubItem.assignees).label("assignee"),
func.count(GithubItem.id).label("count"),
)
.filter(
GithubItem.repo_path == repo,
GithubItem.kind.in_(["issue", "pr"]),
GithubItem.state == "open",
)
.group_by("assignee")
.order_by(desc("count"))
.limit(10)
)
top_assignees = [
{"username": row.assignee, "open_count": row.count}
for row in assignee_query.all()
]
# Label counts
label_query = (
session.query(
func.unnest(GithubItem.labels).label("label"),
func.count(GithubItem.id).label("count"),
)
.filter(
GithubItem.repo_path == repo,
GithubItem.kind.in_(["issue", "pr"]),
)
.group_by("label")
.order_by(desc("count"))
.limit(20)
)
labels = {row.label: row.count for row in label_query.all()}
return {
"repo_path": repo,
"counts": {
"total": counts.total if counts else 0,
"total_issues": counts.total_issues if counts else 0,
"open_issues": counts.open_issues if counts else 0,
"closed_issues": counts.closed_issues if counts else 0,
"total_prs": counts.total_prs if counts else 0,
"open_prs": counts.open_prs if counts else 0,
"merged_prs": counts.merged_prs if counts else 0,
},
"status_breakdown": status_breakdown,
"top_assignees": top_assignees,
"labels": labels,
"last_updated": (
counts.last_updated.isoformat()
if counts and counts.last_updated
else None
),
}

View File

@@ -0,0 +1,257 @@
"""
MCP tools for tracking people.
"""
import logging
from typing import Any
from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
from memory.api.MCP.base import mcp
from memory.common.db.connection import make_session
from memory.common.db.models import Person
from memory.common.celery_app import SYNC_PERSON, UPDATE_PERSON
from memory.common.celery_app import app as celery_app
from memory.common import settings
logger = logging.getLogger(__name__)
def _person_to_dict(person: Person) -> dict[str, Any]:
"""Convert a Person model to a dictionary for API responses."""
return {
"identifier": person.identifier,
"display_name": person.display_name,
"aliases": list(person.aliases or []),
"contact_info": dict(person.contact_info or {}),
"tags": list(person.tags or []),
"notes": person.content,
"created_at": person.inserted_at.isoformat() if person.inserted_at else None,
}
@mcp.tool()
async def add_person(
identifier: str,
display_name: str,
aliases: list[str] | None = None,
contact_info: dict | None = None,
tags: list[str] | None = None,
notes: str | None = None,
) -> dict:
"""
Add a new person to track.
Args:
identifier: Unique slug for the person (e.g., "alice_chen")
display_name: Human-readable name (e.g., "Alice Chen")
aliases: Alternative names/handles (e.g., ["@alice_c", "alice.chen@work.com"])
contact_info: Contact information as a dict (e.g., {"email": "...", "phone": "..."})
tags: Categorization tags (e.g., ["work", "friend", "climbing"])
notes: Free-form notes about the person
Returns:
Task status with task_id
Example:
add_person(
identifier="alice_chen",
display_name="Alice Chen",
aliases=["@alice_c"],
contact_info={"email": "alice@example.com"},
tags=["work", "engineering"],
notes="Tech lead on Platform team"
)
"""
logger.info(f"MCP: Adding person: {identifier}")
# Check if person already exists
with make_session() as session:
existing = session.query(Person).filter(Person.identifier == identifier).first()
if existing:
raise ValueError(f"Person with identifier '{identifier}' already exists")
task = celery_app.send_task(
SYNC_PERSON,
queue=f"{settings.CELERY_QUEUE_PREFIX}-people",
kwargs={
"identifier": identifier,
"display_name": display_name,
"aliases": aliases,
"contact_info": contact_info,
"tags": tags,
"notes": notes,
},
)
return {
"task_id": task.id,
"status": "queued",
"identifier": identifier,
}
@mcp.tool()
async def update_person_info(
identifier: str,
display_name: str | None = None,
aliases: list[str] | None = None,
contact_info: dict | None = None,
tags: list[str] | None = None,
notes: str | None = None,
replace_notes: bool = False,
) -> dict:
"""
Update information about a person with merge semantics.
This tool MERGES new information with existing data rather than replacing it:
- display_name: Replaces existing value
- aliases: Adds new aliases (union with existing)
- contact_info: Deep merges (adds new keys, updates existing keys, never deletes)
- tags: Adds new tags (union with existing)
- notes: Appends to existing notes (or replaces if replace_notes=True)
Args:
identifier: The person's unique identifier
display_name: New display name (replaces existing)
aliases: Additional aliases to add
contact_info: Additional contact info to merge
tags: Additional tags to add
notes: Notes to append (or replace if replace_notes=True)
replace_notes: If True, replace notes instead of appending
Returns:
Task status with task_id
Example:
# Add new contact info without losing existing data
update_person_info(
identifier="alice_chen",
contact_info={"phone": "555-1234"}, # Added to existing
notes="Enjoys rock climbing" # Appended to existing notes
)
"""
logger.info(f"MCP: Updating person: {identifier}")
# Verify person exists
with make_session() as session:
person = session.query(Person).filter(Person.identifier == identifier).first()
if not person:
raise ValueError(f"Person with identifier '{identifier}' not found")
task = celery_app.send_task(
UPDATE_PERSON,
queue=f"{settings.CELERY_QUEUE_PREFIX}-people",
kwargs={
"identifier": identifier,
"display_name": display_name,
"aliases": aliases,
"contact_info": contact_info,
"tags": tags,
"notes": notes,
"replace_notes": replace_notes,
},
)
return {
"task_id": task.id,
"status": "queued",
"identifier": identifier,
}
@mcp.tool()
async def get_person(identifier: str) -> dict | None:
"""
Get a person by their identifier.
Args:
identifier: The person's unique identifier
Returns:
The person record, or None if not found
"""
logger.info(f"MCP: Getting person: {identifier}")
with make_session() as session:
person = session.query(Person).filter(Person.identifier == identifier).first()
if not person:
return None
return _person_to_dict(person)
@mcp.tool()
async def list_people(
tags: list[str] | None = None,
search: str | None = None,
limit: int = 50,
) -> list[dict]:
"""
List all tracked people, optionally filtered by tags or search term.
Args:
tags: Filter to people with at least one of these tags
search: Search term to match against name, aliases, or notes
limit: Maximum number of results (default 50, max 200)
Returns:
List of person records matching the filters
"""
logger.info(f"MCP: Listing people (tags={tags}, search={search})")
limit = min(limit, 200)
with make_session() as session:
query = session.query(Person)
if tags:
query = query.filter(
Person.tags.op("&&")(sql_cast(tags, ARRAY(Text)))
)
if search:
search_term = f"%{search.lower()}%"
query = query.filter(
(Person.display_name.ilike(search_term))
| (Person.content.ilike(search_term))
| (Person.identifier.ilike(search_term))
)
query = query.order_by(Person.display_name).limit(limit)
people = query.all()
return [_person_to_dict(p) for p in people]
@mcp.tool()
async def delete_person(identifier: str) -> dict:
"""
Delete a person by their identifier.
This permanently removes the person and all associated data.
Observations about this person (with subject "person:<identifier>") will remain.
Args:
identifier: The person's unique identifier
Returns:
Confirmation of deletion
"""
logger.info(f"MCP: Deleting person: {identifier}")
with make_session() as session:
person = session.query(Person).filter(Person.identifier == identifier).first()
if not person:
raise ValueError(f"Person with identifier '{identifier}' not found")
display_name = person.display_name
session.delete(person)
session.commit()
return {
"deleted": True,
"identifier": identifier,
"display_name": display_name,
}

View File

@@ -16,6 +16,7 @@ SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
DISCORD_ROOT = "memory.workers.tasks.discord"
BACKUP_ROOT = "memory.workers.tasks.backup"
GITHUB_ROOT = "memory.workers.tasks.github"
PEOPLE_ROOT = "memory.workers.tasks.people"
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
@@ -67,6 +68,10 @@ SYNC_GITHUB_REPO = f"{GITHUB_ROOT}.sync_github_repo"
SYNC_ALL_GITHUB_REPOS = f"{GITHUB_ROOT}.sync_all_github_repos"
SYNC_GITHUB_ITEM = f"{GITHUB_ROOT}.sync_github_item"
# People tasks
SYNC_PERSON = f"{PEOPLE_ROOT}.sync_person"
UPDATE_PERSON = f"{PEOPLE_ROOT}.update_person"
def get_broker_url() -> str:
protocol = settings.CELERY_BROKER_TYPE
@@ -123,6 +128,7 @@ app.conf.update(
},
f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"},
f"{GITHUB_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-github"},
f"{PEOPLE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-people"},
},
beat_schedule={
"sync-github-repos-hourly": {

View File

@@ -17,6 +17,7 @@ from memory.common.db.models.source_items import (
BookSection,
ForumPost,
GithubItem,
GithubPRData,
GitCommit,
Photo,
MiscDoc,
@@ -46,6 +47,10 @@ from memory.common.db.models.observations import (
BeliefCluster,
ConversationMetrics,
)
from memory.common.db.models.people import (
Person,
PersonPayload,
)
from memory.common.db.models.sources import (
Book,
ArticleFeed,
@@ -77,6 +82,7 @@ Payload = (
| ForumPostPayload
| EmailAttachmentPayload
| MailMessagePayload
| PersonPayload
)
__all__ = [
@@ -95,6 +101,7 @@ __all__ = [
"BookSection",
"ForumPost",
"GithubItem",
"GithubPRData",
"GitCommit",
"Photo",
"MiscDoc",
@@ -105,6 +112,9 @@ __all__ = [
"ObservationPattern",
"BeliefCluster",
"ConversationMetrics",
# People
"Person",
"PersonPayload",
# Sources
"Book",
"ArticleFeed",

View File

@@ -0,0 +1,95 @@
"""
Database models for tracking people.
"""
from typing import Annotated, Sequence, cast
from sqlalchemy import (
ARRAY,
BigInteger,
Column,
ForeignKey,
Index,
Text,
)
from sqlalchemy.dialects.postgresql import JSONB
import memory.common.extract as extract
from memory.common.db.models.source_item import (
SourceItem,
SourceItemPayload,
)
class PersonPayload(SourceItemPayload):
identifier: Annotated[str, "Unique identifier/slug for the person"]
display_name: Annotated[str, "Display name of the person"]
aliases: Annotated[list[str], "Alternative names/handles for the person"]
contact_info: Annotated[dict, "Contact information (email, phone, etc.)"]
class Person(SourceItem):
"""A person you know or want to track."""
__tablename__ = "people"
id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
identifier = Column(Text, unique=True, nullable=False, index=True)
display_name = Column(Text, nullable=False)
aliases = Column(ARRAY(Text), server_default="{}", nullable=False)
contact_info = Column(JSONB, server_default="{}", nullable=False)
__mapper_args__ = {
"polymorphic_identity": "person",
}
__table_args__ = (
Index("person_identifier_idx", "identifier"),
Index("person_display_name_idx", "display_name"),
Index("person_aliases_idx", "aliases", postgresql_using="gin"),
)
def as_payload(self) -> PersonPayload:
return PersonPayload(
**super().as_payload(),
identifier=cast(str, self.identifier),
display_name=cast(str, self.display_name),
aliases=cast(list[str], self.aliases) or [],
contact_info=cast(dict, self.contact_info) or {},
)
@property
def display_contents(self) -> dict:
return {
"identifier": self.identifier,
"display_name": self.display_name,
"aliases": self.aliases,
"contact_info": self.contact_info,
"notes": self.content,
"tags": self.tags,
}
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
"""Create searchable chunks from person data."""
parts = [f"# {self.display_name}"]
if self.aliases:
aliases_str = ", ".join(cast(list[str], self.aliases))
parts.append(f"Also known as: {aliases_str}")
if self.tags:
tags_str = ", ".join(cast(list[str], self.tags))
parts.append(f"Tags: {tags_str}")
if self.content:
parts.append(f"\n{self.content}")
text = "\n".join(parts)
return extract.extract_text(text, modality="person")
@classmethod
def get_collections(cls) -> list[str]:
return ["person"]

View File

@@ -9,15 +9,19 @@ from collections.abc import Collection
from typing import Any, Annotated, Sequence, cast
from PIL import Image
import zlib
from sqlalchemy import (
ARRAY,
BigInteger,
Boolean,
CheckConstraint,
Column,
DateTime,
ForeignKey,
Index,
Integer,
LargeBinary,
Numeric,
Text,
func,
@@ -30,6 +34,7 @@ import memory.common.extract as extract
import memory.common.summarizer as summarizer
import memory.common.formatters.observation as observation
from memory.common.db.models.base import Base
from memory.common.db.models.source_item import (
SourceItem,
SourceItemPayload,
@@ -858,6 +863,14 @@ class GithubItem(SourceItem):
milestone = Column(Text, nullable=True)
comment_count = Column(Integer, nullable=True)
# Relationship to PR-specific data
pr_data = relationship(
"GithubPRData",
back_populates="github_item",
uselist=False,
cascade="all, delete-orphan",
)
__mapper_args__ = {
"polymorphic_identity": "github_item",
}
@@ -902,6 +915,59 @@
return []
class GithubPRData(Base):
"""PR-specific data linked to GithubItem. Not a SourceItem - not indexed separately."""
__tablename__ = "github_pr_data"
id = Column(BigInteger, primary_key=True)
github_item_id = Column(
BigInteger,
ForeignKey("github_item.id", ondelete="CASCADE"),
unique=True,
nullable=False,
index=True,
)
# Diff (compressed with zlib)
diff_compressed = Column(LargeBinary, nullable=True)
# File changes as structured data
# [{filename, status, additions, deletions, patch?}]
files = Column(JSONB, nullable=True)
# Stats
additions = Column(Integer, nullable=True)
deletions = Column(Integer, nullable=True)
changed_files_count = Column(Integer, nullable=True)
# Reviews (structured)
# [{user, state, body, submitted_at}]
reviews = Column(JSONB, nullable=True)
# Review comments (line-by-line code comments)
# [{user, body, path, line, diff_hunk, created_at}]
review_comments = Column(JSONB, nullable=True)
# Relationship back to GithubItem
github_item = relationship("GithubItem", back_populates="pr_data")
@property
def diff(self) -> str | None:
"""Decompress and return the full diff text."""
if self.diff_compressed:
return zlib.decompress(self.diff_compressed).decode("utf-8")
return None
@diff.setter
def diff(self, value: str | None) -> None:
"""Compress and store the diff text."""
if value:
self.diff_compressed = zlib.compress(value.encode("utf-8"))
else:
self.diff_compressed = None
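Callers never touch the compressed bytes; the property pair makes the round trip transparent. A quick sketch with a made-up diff:

pr_data = GithubPRData()
pr_data.diff = "diff --git a/app.py b/app.py\n+print('hi')\n"
assert pr_data.diff_compressed is not None   # stored zlib-compressed
assert pr_data.diff.startswith("diff --git")  # decompressed on read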
class NotePayload(SourceItemPayload):
note_type: Annotated[str | None, "Category of the note"]
subject: Annotated[str | None, "What the note is about"]

View File

@@ -42,6 +42,51 @@ class GithubComment(TypedDict):
updated_at: str
class GithubReviewComment(TypedDict):
"""A line-by-line code review comment on a PR."""
id: int
user: str
body: str
path: str
line: int | None
side: str # "LEFT" or "RIGHT"
diff_hunk: str
created_at: str
class GithubReview(TypedDict):
"""A PR review (approval, request changes, etc.)."""
id: int
user: str
state: str # "approved", "changes_requested", "commented", "dismissed"
body: str | None
submitted_at: str
class GithubFileChange(TypedDict):
"""A file changed in a PR."""
filename: str
status: str # "added", "modified", "removed", "renamed"
additions: int
deletions: int
patch: str | None # Diff patch for this file
class GithubPRDataDict(TypedDict):
"""PR-specific data for storage in GithubPRData model."""
diff: str | None # Full diff text
files: list[GithubFileChange]
additions: int
deletions: int
changed_files_count: int
reviews: list[GithubReview]
review_comments: list[GithubReviewComment]
class GithubIssueData(TypedDict):
"""Parsed issue/PR data ready for storage."""
@@ -60,9 +105,11 @@ class GithubIssueData(TypedDict):
github_updated_at: datetime
comment_count: int
comments: list[GithubComment]
-diff_summary: str | None # PRs only
diff_summary: str | None # PRs only (truncated, for backward compat)
project_fields: dict[str, Any] | None
content_hash: str
# PR-specific extended data (None for issues)
pr_data: GithubPRDataDict | None
def parse_github_date(date_str: str | None) -> datetime | None:
@@ -267,6 +314,145 @@ class GithubClient:
return comments
def fetch_review_comments(
self,
owner: str,
repo: str,
pr_number: int,
) -> list[GithubReviewComment]:
"""Fetch all line-by-line review comments for a PR."""
comments: list[GithubReviewComment] = []
page = 1
while True:
response = self.session.get(
f"{GITHUB_API_URL}/repos/{owner}/{repo}/pulls/{pr_number}/comments",
params={"page": page, "per_page": 100},
timeout=30,
)
response.raise_for_status()
self._handle_rate_limit(response)
page_comments = response.json()
if not page_comments:
break
comments.extend(
[
GithubReviewComment(
id=c["id"],
user=c["user"]["login"] if c.get("user") else "ghost",
body=c.get("body", ""),
path=c.get("path", ""),
line=c.get("line"),
side=c.get("side", "RIGHT"),
diff_hunk=c.get("diff_hunk", ""),
created_at=c["created_at"],
)
for c in page_comments
]
)
page += 1
return comments
def fetch_reviews(
self,
owner: str,
repo: str,
pr_number: int,
) -> list[GithubReview]:
"""Fetch all reviews (approvals, change requests) for a PR."""
reviews: list[GithubReview] = []
page = 1
while True:
response = self.session.get(
f"{GITHUB_API_URL}/repos/{owner}/{repo}/pulls/{pr_number}/reviews",
params={"page": page, "per_page": 100},
timeout=30,
)
response.raise_for_status()
self._handle_rate_limit(response)
page_reviews = response.json()
if not page_reviews:
break
reviews.extend(
[
GithubReview(
id=r["id"],
user=r["user"]["login"] if r.get("user") else "ghost",
state=r.get("state", "COMMENTED").lower(),
body=r.get("body"),
submitted_at=r.get("submitted_at", ""),
)
for r in page_reviews
]
)
page += 1
return reviews
def fetch_pr_files(
self,
owner: str,
repo: str,
pr_number: int,
) -> list[GithubFileChange]:
"""Fetch list of files changed in a PR with patches."""
files: list[GithubFileChange] = []
page = 1
while True:
response = self.session.get(
f"{GITHUB_API_URL}/repos/{owner}/{repo}/pulls/{pr_number}/files",
params={"page": page, "per_page": 100},
timeout=30,
)
response.raise_for_status()
self._handle_rate_limit(response)
page_files = response.json()
if not page_files:
break
files.extend(
[
GithubFileChange(
filename=f["filename"],
status=f.get("status", "modified"),
additions=f.get("additions", 0),
deletions=f.get("deletions", 0),
patch=f.get("patch"), # May be None for binary files
)
for f in page_files
]
)
page += 1
return files
def fetch_pr_diff(
self,
owner: str,
repo: str,
pr_number: int,
) -> str | None:
"""Fetch the full diff for a PR (not truncated)."""
try:
response = self.session.get(
f"{GITHUB_API_URL}/repos/{owner}/{repo}/pulls/{pr_number}",
headers={"Accept": "application/vnd.github.diff"},
timeout=60, # Longer timeout for large diffs
)
if response.ok:
return response.text
except Exception as e:
logger.warning(f"Failed to fetch PR diff: {e}")
return None
def fetch_project_fields(
self,
owner: str,
@@ -490,28 +676,44 @@ class GithubClient:
diff_summary=None,
project_fields=None, # Fetched separately if enabled
content_hash=compute_content_hash(body, comments),
pr_data=None, # Issues don't have PR data
)
def _parse_pr(
self, owner: str, repo: str, pr: dict[str, Any]
) -> GithubIssueData:
"""Parse raw PR data into structured format."""
-comments = self.fetch_comments(owner, repo, pr["number"])
pr_number = pr["number"]
comments = self.fetch_comments(owner, repo, pr_number)
body = pr.get("body") or ""
-# Get diff summary (truncated)
-diff_summary = None
-if diff_url := pr.get("diff_url"):
-    try:
-        diff_response = self.session.get(diff_url, timeout=30)
-        if diff_response.ok:
-            diff_summary = diff_response.text[:5000]  # Truncate large diffs
-    except Exception as e:
-        logger.warning(f"Failed to fetch diff: {e}")
# Fetch PR-specific data
review_comments = self.fetch_review_comments(owner, repo, pr_number)
reviews = self.fetch_reviews(owner, repo, pr_number)
files = self.fetch_pr_files(owner, repo, pr_number)
full_diff = self.fetch_pr_diff(owner, repo, pr_number)
# Calculate stats from files
additions = sum(f["additions"] for f in files)
deletions = sum(f["deletions"] for f in files)
# Get diff summary (truncated, for backward compatibility)
diff_summary = full_diff[:5000] if full_diff else None
# Build PR data dict
pr_data = GithubPRDataDict(
diff=full_diff,
files=files,
additions=additions,
deletions=deletions,
changed_files_count=len(files),
reviews=reviews,
review_comments=review_comments,
)
return GithubIssueData(
kind="pr",
number=pr["number"],
number=pr_number,
title=pr["title"],
body=body,
state=pr["state"],
@@ -528,4 +730,5 @@ class GithubClient:
diff_summary=diff_summary,
project_fields=None, # Fetched separately if enabled
content_hash=compute_content_hash(body, comments),
pr_data=pr_data,
)

View File

@@ -12,12 +12,13 @@ from memory.common.celery_app import (
SYNC_GITHUB_ITEM,
)
from memory.common.db.connection import make_session
-from memory.common.db.models import GithubItem
from memory.common.db.models import GithubItem, GithubPRData
from memory.common.db.models.sources import GithubAccount, GithubRepo
from memory.parsers.github import (
GithubClient,
GithubCredentials,
GithubIssueData,
GithubPRDataDict,
)
from memory.workers.tasks.content_processing import (
create_content_hash,
@@ -32,11 +33,42 @@ logger = logging.getLogger(__name__)
def _build_content(issue_data: GithubIssueData) -> str:
"""Build searchable content from issue/PR data."""
content_parts = [f"# {issue_data['title']}", issue_data["body"]]
# Add regular comments
for comment in issue_data["comments"]:
content_parts.append(f"\n---\n**{comment['author']}**: {comment['body']}")
# Add review comments for PRs (makes them searchable)
pr_data = issue_data.get("pr_data")
if pr_data and pr_data.get("review_comments"):
content_parts.append("\n---\n## Code Review Comments\n")
for rc in pr_data["review_comments"]:
content_parts.append(
f"**{rc['user']}** on `{rc['path']}`: {rc['body']}"
)
return "\n\n".join(content_parts)
def _create_pr_data(issue_data: GithubIssueData) -> GithubPRData | None:
"""Create GithubPRData from PR-specific data if available."""
pr_data_dict = issue_data.get("pr_data")
if not pr_data_dict:
return None
pr_data = GithubPRData(
additions=pr_data_dict.get("additions"),
deletions=pr_data_dict.get("deletions"),
changed_files_count=pr_data_dict.get("changed_files_count"),
files=pr_data_dict.get("files"),
reviews=pr_data_dict.get("reviews"),
review_comments=pr_data_dict.get("review_comments"),
)
# Use the setter to compress the diff
pr_data.diff = pr_data_dict.get("diff")
return pr_data
def _create_github_item(
repo: GithubRepo,
issue_data: GithubIssueData,
@@ -57,7 +89,7 @@
repo_tags = cast(list[str], repo.tags) or []
-return GithubItem(
github_item = GithubItem(
modality="github",
sha256=create_content_hash(content),
content=content,
@@ -86,6 +118,12 @@
mime_type="text/markdown",
)
# Create PR data if this is a PR
if issue_data["kind"] == "pr":
github_item.pr_data = _create_pr_data(issue_data)
return github_item
def _needs_reindex(existing: GithubItem, new_data: GithubIssueData) -> bool:
"""Check if an existing item needs reindexing based on content changes."""
@@ -165,6 +203,23 @@
repo_tags = cast(list[str], repo.tags) or []
existing.tags = repo_tags + issue_data["labels"] # type: ignore
# Update PR data if this is a PR
if issue_data["kind"] == "pr":
pr_data_dict = issue_data.get("pr_data")
if pr_data_dict:
if existing.pr_data:
# Update existing pr_data
existing.pr_data.additions = pr_data_dict.get("additions")
existing.pr_data.deletions = pr_data_dict.get("deletions")
existing.pr_data.changed_files_count = pr_data_dict.get("changed_files_count")
existing.pr_data.files = pr_data_dict.get("files")
existing.pr_data.reviews = pr_data_dict.get("reviews")
existing.pr_data.review_comments = pr_data_dict.get("review_comments")
existing.pr_data.diff = pr_data_dict.get("diff")
else:
# Create new pr_data
existing.pr_data = _create_pr_data(issue_data)
session.flush()
# Re-embed and push to Qdrant
@@ -193,6 +248,8 @@ def _serialize_issue_data(data: GithubIssueData) -> dict[str, Any]:
}
for c in data["comments"]
],
# pr_data is already JSON-serializable (TypedDict)
"pr_data": data.get("pr_data"),
}
@@ -200,6 +257,11 @@ def _deserialize_issue_data(data: dict[str, Any]) -> GithubIssueData:
"""Deserialize issue data from Celery task."""
from memory.parsers.github import parse_github_date
# Reconstruct pr_data if present
pr_data: GithubPRDataDict | None = None
if data.get("pr_data"):
pr_data = cast(GithubPRDataDict, data["pr_data"])
return GithubIssueData(
kind=data["kind"],
number=data["number"],
@@ -219,6 +281,7 @@
diff_summary=data.get("diff_summary"),
project_fields=data.get("project_fields"),
content_hash=data["content_hash"],
pr_data=pr_data,
)

View File

@@ -0,0 +1,145 @@
"""
Celery tasks for tracking people.
"""
import logging
from memory.common.db.connection import make_session
from memory.common.db.models import Person
from memory.common.celery_app import app, SYNC_PERSON, UPDATE_PERSON
from memory.workers.tasks.content_processing import (
check_content_exists,
create_content_hash,
create_task_result,
process_content_item,
safe_task_execution,
)
logger = logging.getLogger(__name__)
def _deep_merge(base: dict, updates: dict) -> dict:
"""Deep merge two dictionaries, with updates taking precedence."""
result = dict(base)
for key, value in updates.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = _deep_merge(result[key], value)
else:
result[key] = value
return result
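For instance, merging new contact info never drops existing nested keys:

merged = _deep_merge(
    {"email": {"work": "alice@work.com"}},
    {"email": {"home": "alice@home.net"}, "phone": "555-1234"},
)
# -> {"email": {"work": "alice@work.com", "home": "alice@home.net"}, "phone": "555-1234"}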
@app.task(name=SYNC_PERSON)
@safe_task_execution
def sync_person(
identifier: str,
display_name: str,
aliases: list[str] | None = None,
contact_info: dict | None = None,
tags: list[str] | None = None,
notes: str | None = None,
):
"""
Create or update a person in the knowledge base.
Args:
identifier: Unique slug for the person
display_name: Human-readable name
aliases: Alternative names/handles
contact_info: Contact information dict
tags: Categorization tags
notes: Free-form notes about the person
"""
logger.info(f"Syncing person: {identifier}")
# Create hash from identifier for deduplication
sha256 = create_content_hash(f"person:{identifier}")
with make_session() as session:
# Check if person already exists by identifier
existing = session.query(Person).filter(Person.identifier == identifier).first()
if existing:
logger.info(f"Person already exists: {identifier}")
return create_task_result(existing, "already_exists")
# Also check by sha256 (defensive)
existing_by_hash = check_content_exists(session, Person, sha256=sha256)
if existing_by_hash:
logger.info(f"Person already exists (by hash): {identifier}")
return create_task_result(existing_by_hash, "already_exists")
person = Person(
identifier=identifier,
display_name=display_name,
aliases=aliases or [],
contact_info=contact_info or {},
tags=tags or [],
content=notes,
modality="person",
mime_type="text/plain",
sha256=sha256,
size=len(notes or ""),
)
return process_content_item(person, session)
@app.task(name=UPDATE_PERSON)
@safe_task_execution
def update_person(
identifier: str,
display_name: str | None = None,
aliases: list[str] | None = None,
contact_info: dict | None = None,
tags: list[str] | None = None,
notes: str | None = None,
replace_notes: bool = False,
):
"""
Update a person with merge semantics.
Merge behavior:
- display_name: Replaces if provided
- aliases: Union with existing
- contact_info: Deep merge with existing
- tags: Union with existing
- notes: Append to existing (or replace if replace_notes=True)
"""
logger.info(f"Updating person: {identifier}")
with make_session() as session:
person = session.query(Person).filter(Person.identifier == identifier).first()
if not person:
logger.warning(f"Person not found: {identifier}")
return {"status": "not_found", "identifier": identifier}
if display_name is not None:
person.display_name = display_name
if aliases is not None:
existing_aliases = set(person.aliases or [])
new_aliases = existing_aliases | set(aliases)
person.aliases = list(new_aliases)
if contact_info is not None:
existing_contact = dict(person.contact_info or {})
person.contact_info = _deep_merge(existing_contact, contact_info)
if tags is not None:
existing_tags = set(person.tags or [])
new_tags = existing_tags | set(tags)
person.tags = list(new_tags)
if notes is not None:
if replace_notes or not person.content:
person.content = notes
else:
person.content = f"{person.content}\n\n---\n\n{notes}"
# The hash is identifier-based, so it stays stable across updates
person.sha256 = create_content_hash(f"person:{identifier}")
person.size = len(person.content or "")
person.embed_status = "RAW" # Re-embed with updated content
return process_content_item(person, session)

View File

File diff suppressed because it is too large.

View File

@@ -0,0 +1,482 @@
"""Tests for People MCP tools."""
import sys
import pytest
from unittest.mock import AsyncMock, Mock, MagicMock, patch
# Mock the mcp module and all its submodules before importing anything that uses it
_mock_mcp = MagicMock()
_mock_mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator
sys.modules["mcp"] = _mock_mcp
sys.modules["mcp.server"] = MagicMock()
sys.modules["mcp.server.auth"] = MagicMock()
sys.modules["mcp.server.auth.handlers"] = MagicMock()
sys.modules["mcp.server.auth.handlers.authorize"] = MagicMock()
sys.modules["mcp.server.auth.handlers.token"] = MagicMock()
sys.modules["mcp.server.auth.provider"] = MagicMock()
sys.modules["mcp.server.fastmcp"] = MagicMock()
sys.modules["mcp.server.fastmcp.server"] = MagicMock()
# Also mock the memory.api.MCP.base module to avoid MCP imports
_mock_base = MagicMock()
_mock_base.mcp = MagicMock()
_mock_base.mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator
sys.modules["memory.api.MCP.base"] = _mock_base
from memory.common.db.models import Person
from memory.common.db import connection as db_connection
from memory.workers.tasks.content_processing import create_content_hash
@pytest.fixture(autouse=True)
def reset_db_cache():
"""Reset the cached database engine between tests."""
db_connection._engine = None
db_connection._session_factory = None
db_connection._scoped_session = None
yield
db_connection._engine = None
db_connection._session_factory = None
db_connection._scoped_session = None
@pytest.fixture
def sample_people(db_session):
"""Create sample people for testing."""
people = [
Person(
identifier="alice_chen",
display_name="Alice Chen",
aliases=["@alice_c", "alice.chen@work.com"],
contact_info={"email": "alice@example.com", "phone": "555-1234"},
tags=["work", "engineering"],
content="Tech lead on Platform team. Very thorough in code reviews.",
modality="person",
sha256=create_content_hash("person:alice_chen"),
size=100,
),
Person(
identifier="bob_smith",
display_name="Bob Smith",
aliases=["@bobsmith"],
contact_info={"email": "bob@example.com"},
tags=["work", "design"],
content="UX designer. Prefers visual communication.",
modality="person",
sha256=create_content_hash("person:bob_smith"),
size=50,
),
Person(
identifier="charlie_jones",
display_name="Charlie Jones",
aliases=[],
contact_info={"twitter": "@charlie_j"},
tags=["friend", "climbing"],
content="Met at climbing gym. Very reliable.",
modality="person",
sha256=create_content_hash("person:charlie_jones"),
size=30,
),
]
for person in people:
db_session.add(person)
db_session.commit()
for person in people:
db_session.refresh(person)
return people
# =============================================================================
# Tests for add_person
# =============================================================================
@pytest.mark.asyncio
async def test_add_person_success(db_session):
"""Test adding a new person."""
from memory.api.MCP.people import add_person
mock_task = Mock()
mock_task.id = "task-123"
with patch("memory.api.MCP.people.make_session", return_value=db_session):
with patch("memory.api.MCP.people.celery_app") as mock_celery:
mock_celery.send_task.return_value = mock_task
result = await add_person(
identifier="new_person",
display_name="New Person",
aliases=["@newperson"],
contact_info={"email": "new@example.com"},
tags=["friend"],
notes="A new friend.",
)
assert result["status"] == "queued"
assert result["task_id"] == "task-123"
assert result["identifier"] == "new_person"
# Verify Celery task was called
mock_celery.send_task.assert_called_once()
call_kwargs = mock_celery.send_task.call_args[1]
assert call_kwargs["kwargs"]["identifier"] == "new_person"
assert call_kwargs["kwargs"]["display_name"] == "New Person"
@pytest.mark.asyncio
async def test_add_person_already_exists(db_session, sample_people):
"""Test adding a person that already exists."""
from memory.api.MCP.people import add_person
with patch("memory.api.MCP.people.make_session", return_value=db_session):
with pytest.raises(ValueError, match="already exists"):
await add_person(
identifier="alice_chen", # Already exists
display_name="Alice Chen Duplicate",
)
@pytest.mark.asyncio
async def test_add_person_minimal(db_session):
"""Test adding a person with minimal data."""
from memory.api.MCP.people import add_person
mock_task = Mock()
mock_task.id = "task-456"
with patch("memory.api.MCP.people.make_session", return_value=db_session):
with patch("memory.api.MCP.people.celery_app") as mock_celery:
mock_celery.send_task.return_value = mock_task
result = await add_person(
identifier="minimal_person",
display_name="Minimal Person",
)
assert result["status"] == "queued"
assert result["identifier"] == "minimal_person"
# =============================================================================
# Tests for update_person_info
# =============================================================================
@pytest.mark.asyncio
async def test_update_person_info_success(db_session, sample_people):
"""Test updating a person's info."""
from memory.api.MCP.people import update_person_info
mock_task = Mock()
mock_task.id = "task-789"
with patch("memory.api.MCP.people.make_session", return_value=db_session):
with patch("memory.api.MCP.people.celery_app") as mock_celery:
mock_celery.send_task.return_value = mock_task
result = await update_person_info(
identifier="alice_chen",
display_name="Alice M. Chen",
notes="Added middle initial",
)
assert result["status"] == "queued"
assert result["task_id"] == "task-789"
assert result["identifier"] == "alice_chen"
@pytest.mark.asyncio
async def test_update_person_info_not_found(db_session, sample_people):
"""Test updating a person that doesn't exist."""
from memory.api.MCP.people import update_person_info
with patch("memory.api.MCP.people.make_session", return_value=db_session):
with pytest.raises(ValueError, match="not found"):
await update_person_info(
identifier="nonexistent_person",
display_name="New Name",
)
@pytest.mark.asyncio
async def test_update_person_info_with_merge_params(db_session, sample_people):
"""Test that update passes all merge parameters."""
from memory.api.MCP.people import update_person_info
mock_task = Mock()
mock_task.id = "task-merge"
with patch("memory.api.MCP.people.make_session", return_value=db_session):
with patch("memory.api.MCP.people.celery_app") as mock_celery:
mock_celery.send_task.return_value = mock_task
result = await update_person_info(
identifier="alice_chen",
aliases=["@alice_new"],
contact_info={"slack": "@alice"},
tags=["leadership"],
notes="Promoted to senior",
replace_notes=False,
)
call_kwargs = mock_celery.send_task.call_args[1]["kwargs"]
assert call_kwargs["aliases"] == ["@alice_new"]
assert call_kwargs["contact_info"] == {"slack": "@alice"}
assert call_kwargs["tags"] == ["leadership"]
assert call_kwargs["notes"] == "Promoted to senior"
assert call_kwargs["replace_notes"] is False
# =============================================================================
# Tests for get_person
# =============================================================================
@pytest.mark.asyncio
async def test_get_person_found(db_session, sample_people):
"""Test getting a person that exists."""
from memory.api.MCP.people import get_person
with patch("memory.api.MCP.people.make_session", return_value=db_session):
result = await get_person(identifier="alice_chen")
assert result is not None
assert result["identifier"] == "alice_chen"
assert result["display_name"] == "Alice Chen"
assert result["aliases"] == ["@alice_c", "alice.chen@work.com"]
assert result["contact_info"] == {"email": "alice@example.com", "phone": "555-1234"}
assert result["tags"] == ["work", "engineering"]
assert "Tech lead" in result["notes"]
@pytest.mark.asyncio
async def test_get_person_not_found(db_session, sample_people):
"""Test getting a person that doesn't exist."""
from memory.api.MCP.people import get_person
with patch("memory.api.MCP.people.make_session", return_value=db_session):
result = await get_person(identifier="nonexistent_person")
assert result is None
# =============================================================================
# Tests for list_people
# =============================================================================
@pytest.mark.asyncio
async def test_list_people_no_filters(db_session, sample_people):
"""Test listing all people without filters."""
from memory.api.MCP.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people()
assert len(results) == 3
# Should be ordered by display_name
assert results[0]["display_name"] == "Alice Chen"
assert results[1]["display_name"] == "Bob Smith"
assert results[2]["display_name"] == "Charlie Jones"
@pytest.mark.asyncio
async def test_list_people_filter_by_tags(db_session, sample_people):
"""Test filtering by tags."""
from memory.api.MCP.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(tags=["work"])
assert len(results) == 2
assert all("work" in r["tags"] for r in results)
@pytest.mark.asyncio
async def test_list_people_filter_by_search(db_session, sample_people):
"""Test filtering by search term."""
from memory.api.MCP.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(search="alice")
assert len(results) == 1
assert results[0]["identifier"] == "alice_chen"
@pytest.mark.asyncio
async def test_list_people_search_in_notes(db_session, sample_people):
"""Test that search works on notes content."""
from memory.api.MCP.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(search="climbing")
assert len(results) == 1
assert results[0]["identifier"] == "charlie_jones"
@pytest.mark.asyncio
async def test_list_people_limit(db_session, sample_people):
"""Test limiting results."""
from memory.api.MCP.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(limit=1)
assert len(results) == 1
@pytest.mark.asyncio
async def test_list_people_limit_max_enforced(db_session, sample_people):
"""Test that limit is capped at 200."""
from memory.api.MCP.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
# Request 500; the implementation should cap the limit at 200
results = await list_people(limit=500)
# With only 3 people seeded this cannot observe the cap directly,
# but it exercises the capping code path without error
assert len(results) <= 200
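# The cap asserted above amounts to a one-line clamp. A minimal sketch of the
# assumed logic inside list_people; the function and constant names here are
# illustrative assumptions, not the verified implementation.
def clamp_limit(requested: int, maximum: int = 200) -> int:
    """Clamp a requested result count to the server-side maximum."""
    return min(requested, maximum)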
@pytest.mark.asyncio
async def test_list_people_combined_filters(db_session, sample_people):
"""Test combining tag and search filters."""
from memory.api.MCP.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(tags=["work"], search="chen")
assert len(results) == 1
assert results[0]["identifier"] == "alice_chen"
# =============================================================================
# Tests for delete_person
# =============================================================================
@pytest.mark.asyncio
async def test_delete_person_success(db_session, sample_people):
"""Test deleting a person."""
from memory.api.MCP.people import delete_person
with patch("memory.api.MCP.people.make_session", return_value=db_session):
result = await delete_person(identifier="bob_smith")
assert result["deleted"] is True
assert result["identifier"] == "bob_smith"
assert result["display_name"] == "Bob Smith"
# Verify person was deleted
remaining = db_session.query(Person).filter_by(identifier="bob_smith").first()
assert remaining is None
@pytest.mark.asyncio
async def test_delete_person_not_found(db_session, sample_people):
"""Test deleting a person that doesn't exist."""
from memory.api.MCP.people import delete_person
with patch("memory.api.MCP.people.make_session", return_value=db_session):
with pytest.raises(ValueError, match="not found"):
await delete_person(identifier="nonexistent_person")
# =============================================================================
# Tests for _person_to_dict helper
# =============================================================================
def test_person_to_dict(sample_people):
"""Test the _person_to_dict helper function."""
from memory.api.MCP.people import _person_to_dict
person = sample_people[0]
result = _person_to_dict(person)
assert result["identifier"] == "alice_chen"
assert result["display_name"] == "Alice Chen"
assert result["aliases"] == ["@alice_c", "alice.chen@work.com"]
assert result["contact_info"] == {"email": "alice@example.com", "phone": "555-1234"}
assert result["tags"] == ["work", "engineering"]
assert result["notes"] == "Tech lead on Platform team. Very thorough in code reviews."
assert result["created_at"] is not None
def test_person_to_dict_empty_fields(db_session):
"""Test _person_to_dict with empty optional fields."""
from memory.api.MCP.people import _person_to_dict
person = Person(
identifier="empty_person",
display_name="Empty Person",
aliases=[],
contact_info={},
tags=[],
content=None,
modality="person",
sha256=create_content_hash("person:empty_person"),
size=0,
)
result = _person_to_dict(person)
assert result["identifier"] == "empty_person"
assert result["aliases"] == []
assert result["contact_info"] == {}
assert result["tags"] == []
assert result["notes"] is None
# =============================================================================
# Parametrized tests
# =============================================================================
@pytest.mark.asyncio
@pytest.mark.parametrize(
"tag,expected_count",
[
("work", 2),
("engineering", 1),
("design", 1),
("friend", 1),
("climbing", 1),
("nonexistent", 0),
],
)
async def test_list_people_various_tags(db_session, sample_people, tag, expected_count):
"""Test filtering by various tags."""
from memory.api.MCP.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(tags=[tag])
assert len(results) == expected_count
@pytest.mark.asyncio
@pytest.mark.parametrize(
"search_term,expected_identifiers",
[
("alice", ["alice_chen"]),
("bob", ["bob_smith"]),
("smith", ["bob_smith"]),
("chen", ["alice_chen"]),
("jones", ["charlie_jones"]),
("example.com", []), # Not searching in contact_info
("UX", ["bob_smith"]), # Case insensitive search in notes
],
)
async def test_list_people_various_searches(
db_session, sample_people, search_term, expected_identifiers
):
"""Test search with various terms."""
from memory.api.MCP.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(search=search_term)
result_identifiers = [r["identifier"] for r in results]
assert result_identifiers == expected_identifiers

View File

@ -0,0 +1,249 @@
"""Tests for the Person model."""
import pytest
from sqlalchemy.exc import IntegrityError
from memory.common.db.models import Person
from memory.workers.tasks.content_processing import create_content_hash
@pytest.fixture
def person_data():
"""Standard person test data."""
return {
"identifier": "alice_chen",
"display_name": "Alice Chen",
"aliases": ["@alice_c", "alice.chen@work.com"],
"contact_info": {"email": "alice@example.com", "phone": "555-1234"},
"tags": ["work", "engineering"],
"content": "Tech lead on Platform team. Very thorough in code reviews.",
"modality": "person",
"mime_type": "text/plain",
}
@pytest.fixture
def minimal_person_data():
"""Minimal person test data."""
return {
"identifier": "bob_smith",
"display_name": "Bob Smith",
"modality": "person",
}
def test_person_creation(person_data):
"""Test creating a Person with all fields."""
sha256 = create_content_hash(f"person:{person_data['identifier']}")
person = Person(**person_data, sha256=sha256, size=100)
assert person.identifier == "alice_chen"
assert person.display_name == "Alice Chen"
assert person.aliases == ["@alice_c", "alice.chen@work.com"]
assert person.contact_info == {"email": "alice@example.com", "phone": "555-1234"}
assert person.tags == ["work", "engineering"]
assert person.content == "Tech lead on Platform team. Very thorough in code reviews."
assert person.modality == "person"
def test_person_creation_minimal(minimal_person_data):
"""Test creating a Person with minimal fields."""
sha256 = create_content_hash(f"person:{minimal_person_data['identifier']}")
person = Person(**minimal_person_data, sha256=sha256, size=0)
assert person.identifier == "bob_smith"
assert person.display_name == "Bob Smith"
# Server-side defaults only apply on flush, so unflushed fields may be None
assert person.aliases == [] or person.aliases is None
assert person.contact_info == {} or person.contact_info is None
assert person.tags == [] or person.tags is None
assert person.content is None
def test_person_as_payload(person_data):
"""Test the as_payload method."""
sha256 = create_content_hash(f"person:{person_data['identifier']}")
person = Person(**person_data, sha256=sha256, size=100)
payload = person.as_payload()
assert payload["identifier"] == "alice_chen"
assert payload["display_name"] == "Alice Chen"
assert payload["aliases"] == ["@alice_c", "alice.chen@work.com"]
assert payload["contact_info"] == {"email": "alice@example.com", "phone": "555-1234"}
def test_person_display_contents(person_data):
"""Test the display_contents property."""
sha256 = create_content_hash(f"person:{person_data['identifier']}")
person = Person(**person_data, sha256=sha256, size=100)
contents = person.display_contents
assert contents["identifier"] == "alice_chen"
assert contents["display_name"] == "Alice Chen"
assert contents["aliases"] == ["@alice_c", "alice.chen@work.com"]
assert contents["contact_info"] == {"email": "alice@example.com", "phone": "555-1234"}
assert contents["notes"] == "Tech lead on Platform team. Very thorough in code reviews."
assert contents["tags"] == ["work", "engineering"]
def test_person_chunk_contents(person_data):
"""Test the _chunk_contents method generates searchable chunks."""
sha256 = create_content_hash(f"person:{person_data['identifier']}")
person = Person(**person_data, sha256=sha256, size=100)
chunks = person._chunk_contents()
assert len(chunks) > 0
chunk_text = chunks[0].data[0]
# Should include display name
assert "Alice Chen" in chunk_text
# Should include aliases
assert "@alice_c" in chunk_text
# Should include tags
assert "work" in chunk_text
# Should include notes/content
assert "Tech lead" in chunk_text
def test_person_chunk_contents_minimal(minimal_person_data):
"""Test _chunk_contents with minimal data."""
sha256 = create_content_hash(f"person:{minimal_person_data['identifier']}")
person = Person(**minimal_person_data, sha256=sha256, size=0)
chunks = person._chunk_contents()
assert len(chunks) > 0
chunk_text = chunks[0].data[0]
assert "Bob Smith" in chunk_text
def test_person_get_collections():
"""Test that Person returns correct collections."""
collections = Person.get_collections()
assert collections == ["person"]
def test_person_polymorphic_identity():
"""Test that Person has correct polymorphic identity."""
assert Person.__mapper_args__["polymorphic_identity"] == "person"
@pytest.mark.parametrize(
"identifier,display_name,aliases,tags",
[
("john_doe", "John Doe", [], []),
("jane_smith", "Jane Smith", ["@jane"], ["friend"]),
("bob_jones", "Bob Jones", ["@bob", "bobby"], ["work", "climbing", "london"]),
(
"alice_wong",
"Alice Wong",
["@alice", "alice@work.com", "Alice W."],
["family", "close"],
),
],
)
def test_person_various_configurations(identifier, display_name, aliases, tags):
"""Test Person creation with various configurations."""
sha256 = create_content_hash(f"person:{identifier}")
person = Person(
identifier=identifier,
display_name=display_name,
aliases=aliases,
tags=tags,
modality="person",
sha256=sha256,
size=0,
)
assert person.identifier == identifier
assert person.display_name == display_name
assert person.aliases == aliases
assert person.tags == tags
def test_person_contact_info_flexible():
"""Test that contact_info can hold various structures."""
contact_info = {
"email": "test@example.com",
"phone": "+1-555-1234",
"twitter": "@testuser",
"linkedin": "linkedin.com/in/testuser",
"address": {
"street": "123 Main St",
"city": "San Francisco",
"country": "USA",
},
}
sha256 = create_content_hash("person:test_user")
person = Person(
identifier="test_user",
display_name="Test User",
contact_info=contact_info,
modality="person",
sha256=sha256,
size=0,
)
assert person.contact_info == contact_info
assert person.contact_info["address"]["city"] == "San Francisco"
def test_person_in_db(db_session, qdrant):
"""Test Person persistence in database."""
sha256 = create_content_hash("person:db_test_user")
person = Person(
identifier="db_test_user",
display_name="DB Test User",
aliases=["@dbtest"],
contact_info={"email": "dbtest@example.com"},
tags=["test"],
content="Test notes",
modality="person",
mime_type="text/plain",
sha256=sha256,
size=10,
)
db_session.add(person)
db_session.commit()
# Query it back
retrieved = db_session.query(Person).filter_by(identifier="db_test_user").first()
assert retrieved is not None
assert retrieved.display_name == "DB Test User"
assert retrieved.aliases == ["@dbtest"]
assert retrieved.contact_info == {"email": "dbtest@example.com"}
assert retrieved.tags == ["test"]
assert retrieved.content == "Test notes"
def test_person_unique_identifier(db_session, qdrant):
"""Test that identifier must be unique."""
sha256 = create_content_hash("person:unique_test")
person1 = Person(
identifier="unique_test",
display_name="Person 1",
modality="person",
sha256=sha256,
size=0,
)
db_session.add(person1)
db_session.commit()
# Try to add another with same identifier
person2 = Person(
identifier="unique_test",
display_name="Person 2",
modality="person",
sha256=create_content_hash("person:unique_test_2"),
size=0,
)
db_session.add(person2)
with pytest.raises(IntegrityError):
db_session.commit()

View File

@ -333,7 +333,13 @@ def test_fetch_prs_basic():
page = kwargs.get("params", {}).get("page", 1)
if "/pulls" in url and "/comments" not in url:
# PR list endpoint
if "/pulls" in url and "/comments" not in url and "/reviews" not in url and "/files" not in url:
# Check if this is the diff request
if kwargs.get("headers", {}).get("Accept") == "application/vnd.github.diff":
response.ok = True
response.text = "+100 lines added\n-50 lines removed"
return response
if page == 1:
response.json.return_value = [
{
@ -355,10 +361,22 @@ def test_fetch_prs_basic():
]
else:
response.json.return_value = []
elif ".diff" in url:
response.ok = True
response.text = "+100 lines added\n-50 lines removed"
elif "/comments" in url:
elif "/pulls/" in url and "/comments" in url:
# Review comments endpoint
response.json.return_value = []
elif "/pulls/" in url and "/reviews" in url:
# Reviews endpoint
response.json.return_value = []
elif "/pulls/" in url and "/files" in url:
# Files endpoint
if page == 1:
response.json.return_value = [
{"filename": "test.py", "status": "added", "additions": 100, "deletions": 50, "patch": "+code"}
]
else:
response.json.return_value = []
elif "/issues/" in url and "/comments" in url:
# Regular comments endpoint
response.json.return_value = []
else:
response.json.return_value = []
@ -376,14 +394,32 @@ def test_fetch_prs_basic():
assert pr["kind"] == "pr"
assert pr["diff_summary"] is not None
assert "100 lines added" in pr["diff_summary"]
# Verify pr_data is populated
assert pr["pr_data"] is not None
assert pr["pr_data"]["additions"] == 100
assert pr["pr_data"]["deletions"] == 50
def test_fetch_prs_merged():
"""Test fetching merged PR."""
credentials = GithubCredentials(auth_type="pat", access_token="token")
mock_response = Mock()
mock_response.json.return_value = [
def mock_get(url, **kwargs):
"""Route mock responses based on URL."""
response = Mock()
response.headers = {"X-RateLimit-Remaining": "4999"}
response.raise_for_status = Mock()
page = kwargs.get("params", {}).get("page", 1)
# PR list endpoint
if "/pulls" in url and "/comments" not in url and "/reviews" not in url and "/files" not in url:
if kwargs.get("headers", {}).get("Accept") == "application/vnd.github.diff":
response.ok = True
response.text = ""
return response
if page == 1:
response.json.return_value = [
{
"number": 20,
"title": "Merged PR",
@ -402,17 +438,22 @@ def test_fetch_prs_merged():
"comments": 0,
}
]
mock_response.headers = {"X-RateLimit-Remaining": "4999"}
mock_response.raise_for_status = Mock()
else:
response.json.return_value = []
elif "/pulls/" in url and "/comments" in url:
response.json.return_value = []
elif "/pulls/" in url and "/reviews" in url:
response.json.return_value = []
elif "/pulls/" in url and "/files" in url:
response.json.return_value = []
elif "/issues/" in url and "/comments" in url:
response.json.return_value = []
else:
response.json.return_value = []
mock_empty = Mock()
mock_empty.json.return_value = []
mock_empty.headers = {"X-RateLimit-Remaining": "4998"}
mock_empty.raise_for_status = Mock()
with patch.object(requests.Session, "get") as mock_get:
mock_get.side_effect = [mock_response, mock_empty, mock_empty]
return response
with patch.object(requests.Session, "get", side_effect=mock_get):
client = GithubClient(credentials)
prs = list(client.fetch_prs("owner", "repo"))
@ -425,8 +466,22 @@ def test_fetch_prs_stops_at_since():
"""Test that PR fetching stops when reaching older items."""
credentials = GithubCredentials(auth_type="pat", access_token="token")
mock_response = Mock()
mock_response.json.return_value = [
def mock_get(url, **kwargs):
"""Route mock responses based on URL."""
response = Mock()
response.headers = {"X-RateLimit-Remaining": "4999"}
response.raise_for_status = Mock()
page = kwargs.get("params", {}).get("page", 1)
# PR list endpoint
if "/pulls" in url and "/comments" not in url and "/reviews" not in url and "/files" not in url:
if kwargs.get("headers", {}).get("Accept") == "application/vnd.github.diff":
response.ok = True
response.text = ""
return response
if page == 1:
response.json.return_value = [
{
"number": 30,
"title": "Recent PR",
@ -462,17 +517,22 @@ def test_fetch_prs_stops_at_since():
"comments": 0,
},
]
mock_response.headers = {"X-RateLimit-Remaining": "4999"}
mock_response.raise_for_status = Mock()
else:
response.json.return_value = []
elif "/pulls/" in url and "/comments" in url:
response.json.return_value = []
elif "/pulls/" in url and "/reviews" in url:
response.json.return_value = []
elif "/pulls/" in url and "/files" in url:
response.json.return_value = []
elif "/issues/" in url and "/comments" in url:
response.json.return_value = []
else:
response.json.return_value = []
mock_empty = Mock()
mock_empty.json.return_value = []
mock_empty.headers = {"X-RateLimit-Remaining": "4998"}
mock_empty.raise_for_status = Mock()
with patch.object(requests.Session, "get") as mock_get:
mock_get.side_effect = [mock_response, mock_empty]
return response
with patch.object(requests.Session, "get", side_effect=mock_get):
client = GithubClient(credentials)
since = datetime(2024, 1, 15, tzinfo=timezone.utc)
prs = list(client.fetch_prs("owner", "repo", since=since))
@ -703,3 +763,552 @@ def test_fetch_issues_handles_api_error():
with pytest.raises(requests.HTTPError):
list(client.fetch_issues("owner", "nonexistent"))
# =============================================================================
# Tests for fetch_review_comments
# =============================================================================
def test_fetch_review_comments_basic():
"""Test fetching PR review comments."""
credentials = GithubCredentials(auth_type="pat", access_token="token")
mock_response = Mock()
mock_response.json.return_value = [
{
"id": 1001,
"user": {"login": "reviewer1"},
"body": "This needs a test",
"path": "src/main.py",
"line": 42,
"side": "RIGHT",
"diff_hunk": "@@ -40,3 +40,5 @@",
"created_at": "2024-01-01T12:00:00Z",
},
{
"id": 1002,
"user": {"login": "reviewer2"},
"body": "Good refactoring",
"path": "src/utils.py",
"line": 10,
"side": "LEFT",
"diff_hunk": "@@ -8,5 +8,5 @@",
"created_at": "2024-01-02T10:00:00Z",
},
]
mock_response.headers = {"X-RateLimit-Remaining": "4999"}
mock_response.raise_for_status = Mock()
mock_empty = Mock()
mock_empty.json.return_value = []
mock_empty.headers = {"X-RateLimit-Remaining": "4998"}
mock_empty.raise_for_status = Mock()
with patch.object(requests.Session, "get") as mock_get:
mock_get.side_effect = [mock_response, mock_empty]
client = GithubClient(credentials)
comments = client.fetch_review_comments("owner", "repo", 10)
assert len(comments) == 2
assert comments[0]["user"] == "reviewer1"
assert comments[0]["body"] == "This needs a test"
assert comments[0]["path"] == "src/main.py"
assert comments[0]["line"] == 42
assert comments[0]["side"] == "RIGHT"
assert comments[0]["diff_hunk"] == "@@ -40,3 +40,5 @@"
assert comments[1]["user"] == "reviewer2"
def test_fetch_review_comments_ghost_user():
"""Test review comments with deleted user."""
credentials = GithubCredentials(auth_type="pat", access_token="token")
mock_response = Mock()
mock_response.json.return_value = [
{
"id": 1001,
"user": None, # Deleted user
"body": "Legacy comment",
"path": "file.py",
"line": None, # Line might be None for outdated comments
"side": "RIGHT",
"diff_hunk": "",
"created_at": "2024-01-01T00:00:00Z",
}
]
mock_response.headers = {"X-RateLimit-Remaining": "4999"}
mock_response.raise_for_status = Mock()
mock_empty = Mock()
mock_empty.json.return_value = []
mock_empty.headers = {"X-RateLimit-Remaining": "4998"}
mock_empty.raise_for_status = Mock()
with patch.object(requests.Session, "get") as mock_get:
mock_get.side_effect = [mock_response, mock_empty]
client = GithubClient(credentials)
comments = client.fetch_review_comments("owner", "repo", 10)
assert len(comments) == 1
assert comments[0]["user"] == "ghost"
assert comments[0]["line"] is None
def test_fetch_review_comments_pagination():
"""Test review comment fetching with pagination."""
credentials = GithubCredentials(auth_type="pat", access_token="token")
mock_page1 = Mock()
mock_page1.json.return_value = [
{
"id": i,
"user": {"login": f"user{i}"},
"body": f"Comment {i}",
"path": "file.py",
"line": i,
"side": "RIGHT",
"diff_hunk": "",
"created_at": "2024-01-01T00:00:00Z",
}
for i in range(100)
]
mock_page1.headers = {"X-RateLimit-Remaining": "4999"}
mock_page1.raise_for_status = Mock()
mock_page2 = Mock()
mock_page2.json.return_value = [
{
"id": 100,
"user": {"login": "user100"},
"body": "Final comment",
"path": "file.py",
"line": 100,
"side": "RIGHT",
"diff_hunk": "",
"created_at": "2024-01-01T00:00:00Z",
}
]
mock_page2.headers = {"X-RateLimit-Remaining": "4998"}
mock_page2.raise_for_status = Mock()
mock_empty = Mock()
mock_empty.json.return_value = []
mock_empty.headers = {"X-RateLimit-Remaining": "4997"}
mock_empty.raise_for_status = Mock()
with patch.object(requests.Session, "get") as mock_get:
mock_get.side_effect = [mock_page1, mock_page2, mock_empty]
client = GithubClient(credentials)
comments = client.fetch_review_comments("owner", "repo", 10)
assert len(comments) == 101
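# The three mocks above (full page, partial page, empty page) encode the usual
# GitHub pagination contract: keep fetching pages until an empty one comes
# back. A minimal sketch of that loop; GithubClient's internal helper and
# parameter names are assumptions.
def paginate_sketch(session, url: str, per_page: int = 100):
    """Yield items from a GitHub list endpoint until an empty page is returned."""
    page = 1
    while True:
        resp = session.get(url, params={"page": page, "per_page": per_page})
        resp.raise_for_status()
        items = resp.json()
        if not items:
            return
        yield from items
        page += 1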
# =============================================================================
# Tests for fetch_reviews
# =============================================================================
def test_fetch_reviews_basic():
"""Test fetching PR reviews."""
credentials = GithubCredentials(auth_type="pat", access_token="token")
mock_response = Mock()
mock_response.json.return_value = [
{
"id": 2001,
"user": {"login": "lead_dev"},
"state": "APPROVED",
"body": "LGTM!",
"submitted_at": "2024-01-05T15:00:00Z",
},
{
"id": 2002,
"user": {"login": "qa_engineer"},
"state": "CHANGES_REQUESTED",
"body": "Please add tests",
"submitted_at": "2024-01-04T10:00:00Z",
},
{
"id": 2003,
"user": {"login": "observer"},
"state": "COMMENTED",
"body": None, # Some reviews have no body
"submitted_at": "2024-01-03T08:00:00Z",
},
]
mock_response.headers = {"X-RateLimit-Remaining": "4999"}
mock_response.raise_for_status = Mock()
mock_empty = Mock()
mock_empty.json.return_value = []
mock_empty.headers = {"X-RateLimit-Remaining": "4998"}
mock_empty.raise_for_status = Mock()
with patch.object(requests.Session, "get") as mock_get:
mock_get.side_effect = [mock_response, mock_empty]
client = GithubClient(credentials)
reviews = client.fetch_reviews("owner", "repo", 10)
assert len(reviews) == 3
assert reviews[0]["user"] == "lead_dev"
assert reviews[0]["state"] == "approved" # Lowercased
assert reviews[0]["body"] == "LGTM!"
assert reviews[1]["state"] == "changes_requested"
assert reviews[2]["body"] is None
def test_fetch_reviews_ghost_user():
"""Test reviews with deleted user."""
credentials = GithubCredentials(auth_type="pat", access_token="token")
mock_response = Mock()
mock_response.json.return_value = [
{
"id": 2001,
"user": None,
"state": "APPROVED",
"body": "Approved by former employee",
"submitted_at": "2024-01-01T00:00:00Z",
}
]
mock_response.headers = {"X-RateLimit-Remaining": "4999"}
mock_response.raise_for_status = Mock()
mock_empty = Mock()
mock_empty.json.return_value = []
mock_empty.headers = {"X-RateLimit-Remaining": "4998"}
mock_empty.raise_for_status = Mock()
with patch.object(requests.Session, "get") as mock_get:
mock_get.side_effect = [mock_response, mock_empty]
client = GithubClient(credentials)
reviews = client.fetch_reviews("owner", "repo", 10)
assert len(reviews) == 1
assert reviews[0]["user"] == "ghost"
# =============================================================================
# Tests for fetch_pr_files
# =============================================================================
def test_fetch_pr_files_basic():
"""Test fetching PR file changes."""
credentials = GithubCredentials(auth_type="pat", access_token="token")
mock_response = Mock()
mock_response.json.return_value = [
{
"filename": "src/main.py",
"status": "modified",
"additions": 10,
"deletions": 5,
"patch": "@@ -1,5 +1,10 @@\n+new code\n-old code",
},
{
"filename": "src/new_feature.py",
"status": "added",
"additions": 100,
"deletions": 0,
"patch": "@@ -0,0 +1,100 @@\n+entire new file",
},
{
"filename": "old_file.py",
"status": "removed",
"additions": 0,
"deletions": 50,
"patch": "@@ -1,50 +0,0 @@\n-entire old file",
},
{
"filename": "image.png",
"status": "added",
"additions": 0,
"deletions": 0,
# No patch for binary files
},
]
mock_response.headers = {"X-RateLimit-Remaining": "4999"}
mock_response.raise_for_status = Mock()
mock_empty = Mock()
mock_empty.json.return_value = []
mock_empty.headers = {"X-RateLimit-Remaining": "4998"}
mock_empty.raise_for_status = Mock()
with patch.object(requests.Session, "get") as mock_get:
mock_get.side_effect = [mock_response, mock_empty]
client = GithubClient(credentials)
files = client.fetch_pr_files("owner", "repo", 10)
assert len(files) == 4
assert files[0]["filename"] == "src/main.py"
assert files[0]["status"] == "modified"
assert files[0]["additions"] == 10
assert files[0]["deletions"] == 5
assert files[0]["patch"] is not None
assert files[1]["status"] == "added"
assert files[2]["status"] == "removed"
assert files[3]["patch"] is None # Binary file
def test_fetch_pr_files_renamed():
"""Test PR with renamed files."""
credentials = GithubCredentials(auth_type="pat", access_token="token")
mock_response = Mock()
mock_response.json.return_value = [
{
"filename": "new_name.py",
"status": "renamed",
"additions": 0,
"deletions": 0,
"patch": None,
}
]
mock_response.headers = {"X-RateLimit-Remaining": "4999"}
mock_response.raise_for_status = Mock()
mock_empty = Mock()
mock_empty.json.return_value = []
mock_empty.headers = {"X-RateLimit-Remaining": "4998"}
mock_empty.raise_for_status = Mock()
with patch.object(requests.Session, "get") as mock_get:
mock_get.side_effect = [mock_response, mock_empty]
client = GithubClient(credentials)
files = client.fetch_pr_files("owner", "repo", 10)
assert len(files) == 1
assert files[0]["status"] == "renamed"
# =============================================================================
# Tests for fetch_pr_diff
# =============================================================================
def test_fetch_pr_diff_success():
"""Test fetching full PR diff."""
credentials = GithubCredentials(auth_type="pat", access_token="token")
diff_text = """diff --git a/file.py b/file.py
index abc123..def456 100644
--- a/file.py
+++ b/file.py
@@ -1,5 +1,10 @@
+import os
+
def main():
- print("old")
+ print("new")
"""
mock_response = Mock()
mock_response.ok = True
mock_response.text = diff_text
with patch.object(requests.Session, "get") as mock_get:
mock_get.return_value = mock_response
client = GithubClient(credentials)
diff = client.fetch_pr_diff("owner", "repo", 10)
assert diff == diff_text
# Verify Accept header was set for diff format
call_kwargs = mock_get.call_args.kwargs
assert call_kwargs["headers"]["Accept"] == "application/vnd.github.diff"
def test_fetch_pr_diff_failure():
"""Test handling diff fetch failure gracefully."""
credentials = GithubCredentials(auth_type="pat", access_token="token")
mock_response = Mock()
mock_response.ok = False
with patch.object(requests.Session, "get") as mock_get:
mock_get.return_value = mock_response
client = GithubClient(credentials)
diff = client.fetch_pr_diff("owner", "repo", 10)
assert diff is None
def test_fetch_pr_diff_exception():
"""Test handling exceptions during diff fetch."""
credentials = GithubCredentials(auth_type="pat", access_token="token")
with patch.object(requests.Session, "get") as mock_get:
mock_get.side_effect = requests.RequestException("Network error")
client = GithubClient(credentials)
diff = client.fetch_pr_diff("owner", "repo", 10)
assert diff is None
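# The three tests above fix the diff-fetch contract: request the raw diff via
# the Accept header, return the text on success, and None on a non-OK response
# or a network error. Sketch assuming the standard GitHub pulls endpoint.
import requests

def fetch_pr_diff_sketch(session, owner: str, repo: str, number: int):
    """Return the raw diff text, or None when it cannot be fetched."""
    url = f"https://api.github.com/repos/{owner}/{repo}/pulls/{number}"
    try:
        resp = session.get(url, headers={"Accept": "application/vnd.github.diff"})
    except requests.RequestException:
        return None
    return resp.text if resp.ok else None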
# =============================================================================
# Tests for _parse_pr with pr_data
# =============================================================================
def test_parse_pr_fetches_all_pr_data():
"""Test that _parse_pr fetches and includes all PR-specific data."""
credentials = GithubCredentials(auth_type="pat", access_token="token")
pr_raw = {
"number": 42,
"title": "Feature PR",
"body": "PR description",
"state": "open",
"user": {"login": "contributor"},
"labels": [{"name": "enhancement"}],
"assignees": [{"login": "reviewer"}],
"milestone": {"title": "v2.0"},
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-02T00:00:00Z",
"closed_at": None,
"merged_at": None,
}
# Mock responses for all the fetch methods
def mock_get(url, **kwargs):
response = Mock()
response.headers = {"X-RateLimit-Remaining": "4999"}
response.raise_for_status = Mock()
if "/issues/42/comments" in url:
# Regular comments
page = kwargs.get("params", {}).get("page", 1)
if page == 1:
response.json.return_value = [
{
"id": 1,
"user": {"login": "user1"},
"body": "Regular comment",
"created_at": "2024-01-01T10:00:00Z",
"updated_at": "2024-01-01T10:00:00Z",
}
]
else:
response.json.return_value = []
elif "/pulls/42/comments" in url:
# Review comments
page = kwargs.get("params", {}).get("page", 1)
if page == 1:
response.json.return_value = [
{
"id": 101,
"user": {"login": "reviewer1"},
"body": "Review comment",
"path": "src/main.py",
"line": 10,
"side": "RIGHT",
"diff_hunk": "@@ -1,5 +1,10 @@",
"created_at": "2024-01-01T12:00:00Z",
}
]
else:
response.json.return_value = []
elif "/pulls/42/reviews" in url:
# Reviews
page = kwargs.get("params", {}).get("page", 1)
if page == 1:
response.json.return_value = [
{
"id": 201,
"user": {"login": "lead"},
"state": "APPROVED",
"body": "LGTM",
"submitted_at": "2024-01-02T08:00:00Z",
}
]
else:
response.json.return_value = []
elif "/pulls/42/files" in url:
# Files
page = kwargs.get("params", {}).get("page", 1)
if page == 1:
response.json.return_value = [
{
"filename": "src/main.py",
"status": "modified",
"additions": 50,
"deletions": 10,
"patch": "+new\n-old",
},
{
"filename": "tests/test_main.py",
"status": "added",
"additions": 30,
"deletions": 0,
"patch": "+tests",
},
]
else:
response.json.return_value = []
elif "/pulls/42" in url and "diff" in kwargs.get("headers", {}).get(
"Accept", ""
):
# Full diff
response.ok = True
response.text = "diff --git a/src/main.py\n+new code\n-old code"
return response
else:
response.json.return_value = []
return response
with patch.object(requests.Session, "get", side_effect=mock_get):
client = GithubClient(credentials)
result = client._parse_pr("owner", "repo", pr_raw)
# Verify basic fields
assert result["kind"] == "pr"
assert result["number"] == 42
assert result["title"] == "Feature PR"
assert result["author"] == "contributor"
assert len(result["comments"]) == 1
# Verify pr_data
pr_data = result["pr_data"]
assert pr_data is not None
# Verify diff
assert pr_data["diff"] is not None
assert "new code" in pr_data["diff"]
# Verify files
assert len(pr_data["files"]) == 2
assert pr_data["files"][0]["filename"] == "src/main.py"
assert pr_data["files"][0]["additions"] == 50
# Verify stats calculated from files
assert pr_data["additions"] == 80 # 50 + 30
assert pr_data["deletions"] == 10
assert pr_data["changed_files_count"] == 2
# Verify reviews
assert len(pr_data["reviews"]) == 1
assert pr_data["reviews"][0]["user"] == "lead"
assert pr_data["reviews"][0]["state"] == "approved"
# Verify review comments
assert len(pr_data["review_comments"]) == 1
assert pr_data["review_comments"][0]["user"] == "reviewer1"
assert pr_data["review_comments"][0]["path"] == "src/main.py"
# Verify diff_summary is truncated version of full diff
assert result["diff_summary"] == pr_data["diff"][:5000]

View File

@ -1123,3 +1123,521 @@ def test_tag_merging(repo_tags, issue_labels, expected_tags, github_account, db_
item = db_session.query(GithubItem).filter_by(number=60).first()
assert item.tags == expected_tags
# =============================================================================
# Tests for PR data handling
# =============================================================================
@pytest.fixture
def mock_pr_data_with_extended() -> GithubIssueData:
"""Mock PR data with full pr_data dict."""
from memory.parsers.github import GithubPRDataDict
pr_data: GithubPRDataDict = {
"diff": "diff --git a/file.py\n+new line\n-old line",
"files": [
{
"filename": "src/main.py",
"status": "modified",
"additions": 50,
"deletions": 10,
"patch": "+new\n-old",
},
{
"filename": "tests/test_main.py",
"status": "added",
"additions": 30,
"deletions": 0,
"patch": "+test code",
},
],
"additions": 80,
"deletions": 10,
"changed_files_count": 2,
"reviews": [
{
"id": 1001,
"user": "lead_reviewer",
"state": "approved",
"body": "LGTM!",
"submitted_at": "2024-01-02T10:00:00Z",
}
],
"review_comments": [
{
"id": 2001,
"user": "reviewer1",
"body": "Please add a docstring here",
"path": "src/main.py",
"line": 42,
"side": "RIGHT",
"diff_hunk": "@@ -40,3 +40,5 @@",
"created_at": "2024-01-01T15:00:00Z",
}
],
}
return GithubIssueData(
kind="pr",
number=200,
title="Feature: Add new capability",
body="This PR adds a new capability to the system.",
state="open",
author="contributor",
labels=["enhancement", "needs-review"],
assignees=["reviewer1"],
milestone="v2.0",
created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
closed_at=None,
merged_at=None,
github_updated_at=datetime(2024, 1, 2, 10, 0, 0, tzinfo=timezone.utc),
comment_count=1,
comments=[
{
"id": 5001,
"author": "maintainer",
"body": "Thanks for the PR!",
"created_at": "2024-01-01T14:00:00Z",
"updated_at": "2024-01-01T14:00:00Z",
}
],
diff_summary="+new\n-old",
project_fields=None,
content_hash="pr_extended_hash",
pr_data=pr_data,
)
def test_build_content_with_review_comments(mock_pr_data_with_extended):
"""Test _build_content includes review comments for PRs."""
content = _build_content(mock_pr_data_with_extended)
# Basic content
assert "# Feature: Add new capability" in content
assert "This PR adds a new capability" in content
# Regular comment
assert "**maintainer**: Thanks for the PR!" in content
# Review comments section
assert "## Code Review Comments" in content
assert "**reviewer1**" in content
assert "Please add a docstring here" in content
assert "`src/main.py`" in content
def test_build_content_pr_without_review_comments():
"""Test _build_content for PR with no review comments."""
data = GithubIssueData(
kind="pr",
number=201,
title="Simple PR",
body="Body",
state="open",
author="user",
labels=[],
assignees=[],
milestone=None,
created_at=datetime.now(timezone.utc),
closed_at=None,
merged_at=None,
github_updated_at=datetime.now(timezone.utc),
comment_count=0,
comments=[],
diff_summary=None,
project_fields=None,
content_hash="hash",
pr_data={
"diff": None,
"files": [],
"additions": 0,
"deletions": 0,
"changed_files_count": 0,
"reviews": [],
"review_comments": [], # Empty
},
)
content = _build_content(data)
assert "# Simple PR" in content
assert "## Code Review Comments" not in content
def test_build_content_issue_no_pr_data():
"""Test _build_content for issue (no pr_data)."""
data = GithubIssueData(
kind="issue",
number=100,
title="Bug Report",
body="There's a bug",
state="open",
author="reporter",
labels=["bug"],
assignees=[],
milestone=None,
created_at=datetime.now(timezone.utc),
closed_at=None,
merged_at=None,
github_updated_at=datetime.now(timezone.utc),
comment_count=0,
comments=[],
diff_summary=None,
project_fields=None,
content_hash="hash",
pr_data=None, # Issues don't have pr_data
)
content = _build_content(data)
assert "# Bug Report" in content
assert "There's a bug" in content
assert "## Code Review Comments" not in content
def test_create_pr_data_function(mock_pr_data_with_extended):
"""Test _create_pr_data creates GithubPRData correctly."""
from memory.workers.tasks.github import _create_pr_data
result = _create_pr_data(mock_pr_data_with_extended)
assert result is not None
assert result.additions == 80
assert result.deletions == 10
assert result.changed_files_count == 2
# Files are stored as JSONB
assert len(result.files) == 2
assert result.files[0]["filename"] == "src/main.py"
# Reviews
assert len(result.reviews) == 1
assert result.reviews[0]["user"] == "lead_reviewer"
assert result.reviews[0]["state"] == "approved"
# Review comments
assert len(result.review_comments) == 1
assert result.review_comments[0]["path"] == "src/main.py"
# Diff is compressed - test the property getter
assert result.diff is not None
assert "new line" in result.diff
def test_create_pr_data_none_for_issue():
"""Test _create_pr_data returns None for issues."""
from memory.workers.tasks.github import _create_pr_data
data = GithubIssueData(
kind="issue",
number=100,
title="Issue",
body="Body",
state="open",
author="user",
labels=[],
assignees=[],
milestone=None,
created_at=datetime.now(timezone.utc),
closed_at=None,
merged_at=None,
github_updated_at=datetime.now(timezone.utc),
comment_count=0,
comments=[],
diff_summary=None,
project_fields=None,
content_hash="hash",
pr_data=None,
)
result = _create_pr_data(data)
assert result is None
def test_serialize_deserialize_with_pr_data(mock_pr_data_with_extended):
"""Test serialization roundtrip preserves pr_data."""
serialized = _serialize_issue_data(mock_pr_data_with_extended)
# Verify pr_data is included in serialized form
assert "pr_data" in serialized
assert serialized["pr_data"]["additions"] == 80
assert len(serialized["pr_data"]["files"]) == 2
assert len(serialized["pr_data"]["reviews"]) == 1
assert len(serialized["pr_data"]["review_comments"]) == 1
# Deserialize and verify
deserialized = _deserialize_issue_data(serialized)
assert deserialized["pr_data"] is not None
assert deserialized["pr_data"]["additions"] == 80
assert deserialized["pr_data"]["deletions"] == 10
assert len(deserialized["pr_data"]["files"]) == 2
assert deserialized["pr_data"]["diff"] == mock_pr_data_with_extended["pr_data"]["diff"]
def test_serialize_deserialize_without_pr_data(mock_issue_data):
"""Test serialization roundtrip for issue without pr_data."""
# Add pr_data=None to the mock (issues don't have it)
issue_with_none = dict(mock_issue_data)
issue_with_none["pr_data"] = None
serialized = _serialize_issue_data(issue_with_none)
assert serialized.get("pr_data") is None
deserialized = _deserialize_issue_data(serialized)
assert deserialized.get("pr_data") is None
def test_sync_github_item_creates_pr_data(
mock_pr_data_with_extended, github_repo, db_session, qdrant
):
"""Test that syncing a PR creates associated GithubPRData."""
serialized = _serialize_issue_data(mock_pr_data_with_extended)
result = github.sync_github_item(github_repo.id, serialized)
assert result["status"] == "processed"
# Query the created item
item = (
db_session.query(GithubItem)
.filter_by(repo_path="testorg/testrepo", number=200, kind="pr")
.first()
)
assert item is not None
assert item.kind == "pr"
# Check pr_data relationship
assert item.pr_data is not None
assert item.pr_data.additions == 80
assert item.pr_data.deletions == 10
assert item.pr_data.changed_files_count == 2
assert len(item.pr_data.files) == 2
assert len(item.pr_data.reviews) == 1
assert len(item.pr_data.review_comments) == 1
# Verify diff decompression works
assert item.pr_data.diff is not None
assert "new line" in item.pr_data.diff
def test_sync_github_item_pr_without_pr_data(github_repo, db_session, qdrant):
"""Test syncing a PR that doesn't have extended pr_data."""
data = GithubIssueData(
kind="pr",
number=202,
title="Legacy PR",
body="PR without extended data",
state="merged",
author="old_contributor",
labels=[],
assignees=[],
milestone=None,
created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
closed_at=datetime(2024, 1, 5, tzinfo=timezone.utc),
merged_at=datetime(2024, 1, 5, tzinfo=timezone.utc),
github_updated_at=datetime(2024, 1, 5, tzinfo=timezone.utc),
comment_count=0,
comments=[],
diff_summary="+10 -5",
project_fields=None,
content_hash="legacy_hash",
pr_data=None, # No extended PR data
)
serialized = _serialize_issue_data(data)
result = github.sync_github_item(github_repo.id, serialized)
assert result["status"] == "processed"
item = db_session.query(GithubItem).filter_by(number=202).first()
assert item is not None
assert item.kind == "pr"
assert item.pr_data is None # No pr_data created
def test_sync_github_item_updates_existing_pr_data(github_repo, db_session, qdrant):
"""Test updating an existing PR with new pr_data."""
from memory.common.db.models import GithubPRData
from memory.workers.tasks.content_processing import create_content_hash
# Create initial PR with pr_data
initial_content = "# Initial PR\n\nOriginal body"
existing_item = GithubItem(
repo_path="testorg/testrepo",
repo_id=github_repo.id,
number=300,
kind="pr",
title="Initial PR",
content=initial_content,
state="open",
author="user",
labels=[],
assignees=[],
milestone=None,
created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
github_updated_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
comment_count=0,
content_hash="initial_hash",
diff_summary="+5 -2",
modality="github",
mime_type="text/markdown",
sha256=create_content_hash(initial_content),
size=len(initial_content),
tags=["github", "test"],
)
# Create initial pr_data
initial_pr_data = GithubPRData(
additions=5,
deletions=2,
changed_files_count=1,
files=[{"filename": "old.py", "status": "modified", "additions": 5, "deletions": 2, "patch": None}],
reviews=[],
review_comments=[],
)
initial_pr_data.diff = "old diff"
existing_item.pr_data = initial_pr_data
db_session.add(existing_item)
db_session.commit()
# Now update with new data
updated_data = GithubIssueData(
kind="pr",
number=300,
title="Updated PR",
body="Updated body with more changes",
state="open",
author="user",
labels=["ready-for-review"],
assignees=["reviewer"],
milestone=None,
created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
closed_at=None,
merged_at=None,
github_updated_at=datetime(2024, 1, 5, tzinfo=timezone.utc), # Newer
comment_count=2,
comments=[
{"id": 1, "author": "reviewer", "body": "LGTM", "created_at": "", "updated_at": ""}
],
diff_summary="+50 -10",
project_fields=None,
content_hash="updated_hash", # Different hash triggers update
pr_data={
"diff": "new diff with lots of changes",
"files": [
{"filename": "new.py", "status": "added", "additions": 50, "deletions": 0, "patch": "+code"},
{"filename": "old.py", "status": "modified", "additions": 0, "deletions": 10, "patch": "-code"},
],
"additions": 50,
"deletions": 10,
"changed_files_count": 2,
"reviews": [
{"id": 1, "user": "reviewer", "state": "approved", "body": "Approved!", "submitted_at": ""}
],
"review_comments": [
{"id": 1, "user": "reviewer", "body": "Nice!", "path": "new.py", "line": 10, "side": "RIGHT", "diff_hunk": "", "created_at": ""}
],
},
)
serialized = _serialize_issue_data(updated_data)
result = github.sync_github_item(github_repo.id, serialized)
assert result["status"] == "processed"
# Refresh from DB
db_session.expire_all()
item = db_session.query(GithubItem).filter_by(number=300).first()
assert item.title == "Updated PR"
assert item.pr_data is not None
assert item.pr_data.additions == 50
assert item.pr_data.deletions == 10
assert item.pr_data.changed_files_count == 2
assert len(item.pr_data.files) == 2
assert len(item.pr_data.reviews) == 1
assert len(item.pr_data.review_comments) == 1
assert "new diff" in item.pr_data.diff
def test_sync_github_item_creates_pr_data_for_existing_pr_without(
github_repo, db_session, qdrant
):
"""Test updating a PR that didn't have pr_data to add it."""
from memory.workers.tasks.content_processing import create_content_hash
# Create existing PR without pr_data (legacy data)
initial_content = "# Legacy PR\n\nOriginal"
existing_item = GithubItem(
repo_path="testorg/testrepo",
repo_id=github_repo.id,
number=301,
kind="pr",
title="Legacy PR",
content=initial_content,
state="open",
author="user",
labels=[],
assignees=[],
milestone=None,
created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
github_updated_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
comment_count=0,
content_hash="legacy_hash",
diff_summary=None,
modality="github",
mime_type="text/markdown",
sha256=create_content_hash(initial_content),
size=len(initial_content),
tags=["github"],
pr_data=None, # No pr_data initially
)
db_session.add(existing_item)
db_session.commit()
# Update with pr_data
updated_data = GithubIssueData(
kind="pr",
number=301,
title="Legacy PR",
body="Original with new review",
state="open",
author="user",
labels=[],
assignees=[],
milestone=None,
created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
closed_at=None,
merged_at=None,
github_updated_at=datetime(2024, 1, 2, tzinfo=timezone.utc),
comment_count=0,
comments=[],
diff_summary="+10 -0",
project_fields=None,
content_hash="new_hash", # Different
pr_data={
"diff": "the full diff",
"files": [{"filename": "new.py", "status": "added", "additions": 10, "deletions": 0, "patch": None}],
"additions": 10,
"deletions": 0,
"changed_files_count": 1,
"reviews": [],
"review_comments": [],
},
)
serialized = _serialize_issue_data(updated_data)
result = github.sync_github_item(github_repo.id, serialized)
assert result["status"] == "processed"
db_session.expire_all()
item = db_session.query(GithubItem).filter_by(number=301).first()
# Now should have pr_data
assert item.pr_data is not None
assert item.pr_data.additions == 10
assert item.pr_data.diff == "the full diff"

View File

@ -0,0 +1,404 @@
"""Tests for people Celery tasks."""
import uuid
from contextlib import contextmanager
from unittest.mock import patch, MagicMock
import pytest
from memory.common.db.models import Person
from memory.common.db.models.source_item import Chunk
from memory.workers.tasks import people
from memory.workers.tasks.content_processing import create_content_hash
def _make_mock_chunk(source_id: int) -> Chunk:
"""Create a mock chunk for testing with a unique ID."""
return Chunk(
id=str(uuid.uuid4()),
content="test chunk content",
embedding_model="test-model",
vector=[0.1] * 1024,
item_metadata={"source_id": source_id, "tags": ["test"]},
collection_name="person",
)
@pytest.fixture
def mock_make_session(db_session):
"""Mock make_session and embedding functions for task tests."""
@contextmanager
def _mock_session():
yield db_session
with patch("memory.workers.tasks.people.make_session", _mock_session):
# Mock embedding to return a fake chunk
with patch(
"memory.common.embedding.embed_source_item",
side_effect=lambda item: [_make_mock_chunk(item.id or 1)],
):
# Mock push_to_qdrant to do nothing
with patch("memory.workers.tasks.content_processing.push_to_qdrant"):
yield db_session
@pytest.fixture
def person_data():
"""Standard person test data."""
return {
"identifier": "alice_chen",
"display_name": "Alice Chen",
"aliases": ["@alice_c", "alice.chen@work.com"],
"contact_info": {"email": "alice@example.com", "phone": "555-1234"},
"tags": ["work", "engineering"],
"notes": "Tech lead on Platform team.",
}
@pytest.fixture
def minimal_person_data():
"""Minimal person test data."""
return {
"identifier": "bob_smith",
"display_name": "Bob Smith",
}
def test_sync_person_success(person_data, mock_make_session, qdrant):
"""Test successful person sync."""
result = people.sync_person(**person_data)
# Verify the Person was created in the database
person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first()
assert person is not None
assert person.identifier == "alice_chen"
assert person.display_name == "Alice Chen"
assert person.aliases == ["@alice_c", "alice.chen@work.com"]
assert person.contact_info == {"email": "alice@example.com", "phone": "555-1234"}
assert person.tags == ["work", "engineering"]
assert person.content == "Tech lead on Platform team."
assert person.modality == "person"
# Verify the result
assert result["status"] == "processed"
assert "person_id" in result
def test_sync_person_minimal_data(minimal_person_data, mock_make_session, qdrant):
"""Test person sync with minimal required data."""
result = people.sync_person(**minimal_person_data)
person = mock_make_session.query(Person).filter_by(identifier="bob_smith").first()
assert person is not None
assert person.identifier == "bob_smith"
assert person.display_name == "Bob Smith"
assert person.aliases == []
assert person.contact_info == {}
assert person.tags == []
assert person.content is None
assert result["status"] == "processed"
def test_sync_person_already_exists(person_data, mock_make_session, qdrant):
"""Test sync when person already exists."""
# Create the person first
sha256 = create_content_hash(f"person:{person_data['identifier']}")
existing_person = Person(
identifier=person_data["identifier"],
display_name=person_data["display_name"],
aliases=person_data["aliases"],
contact_info=person_data["contact_info"],
tags=person_data["tags"],
content=person_data["notes"],
modality="person",
mime_type="text/plain",
sha256=sha256,
size=len(person_data["notes"]),
)
mock_make_session.add(existing_person)
mock_make_session.commit()
# Try to sync again
result = people.sync_person(**person_data)
assert result["status"] == "already_exists"
assert result["person_id"] == existing_person.id
# Verify no duplicate was created
count = mock_make_session.query(Person).filter_by(identifier="alice_chen").count()
assert count == 1
def test_update_person_display_name(person_data, mock_make_session, qdrant):
"""Test updating display name."""
# Create person first
people.sync_person(**person_data)
# Update display name
result = people.update_person(
identifier="alice_chen",
display_name="Alice M. Chen",
)
assert result["status"] == "processed"
person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first()
assert person.display_name == "Alice M. Chen"
# Other fields unchanged
assert person.aliases == ["@alice_c", "alice.chen@work.com"]
def test_update_person_merge_aliases(person_data, mock_make_session, qdrant):
"""Test that aliases are merged, not replaced."""
# Create person first
people.sync_person(**person_data)
# Update with new aliases
result = people.update_person(
identifier="alice_chen",
aliases=["@alice_chen", "alice@company.com"],
)
assert result["status"] == "processed"
person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first()
# Should be union of old and new
assert set(person.aliases) == {
"@alice_c",
"alice.chen@work.com",
"@alice_chen",
"alice@company.com",
}
def test_update_person_merge_contact_info(person_data, mock_make_session, qdrant):
"""Test that contact_info is deep merged."""
# Create person first
people.sync_person(**person_data)
# Update with new contact info
result = people.update_person(
identifier="alice_chen",
contact_info={"twitter": "@alice_c", "phone": "555-5678"}, # Update existing
)
assert result["status"] == "processed"
person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first()
# Should have all keys
assert person.contact_info["email"] == "alice@example.com" # Original
assert person.contact_info["phone"] == "555-5678" # Updated
assert person.contact_info["twitter"] == "@alice_c" # New
def test_update_person_merge_tags(person_data, mock_make_session, qdrant):
"""Test that tags are merged, not replaced."""
# Create person first
people.sync_person(**person_data)
# Update with new tags
result = people.update_person(
identifier="alice_chen",
tags=["climbing", "london"],
)
assert result["status"] == "processed"
person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first()
# Should be union of old and new
assert set(person.tags) == {"work", "engineering", "climbing", "london"}
def test_update_person_append_notes(person_data, mock_make_session, qdrant):
"""Test that notes are appended by default."""
# Create person first
people.sync_person(**person_data)
# Update with new notes
result = people.update_person(
identifier="alice_chen",
notes="Also enjoys rock climbing.",
)
assert result["status"] == "processed"
person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first()
# Should be appended with separator
assert "Tech lead on Platform team." in person.content
assert "Also enjoys rock climbing." in person.content
assert "---" in person.content
def test_update_person_replace_notes(person_data, mock_make_session, qdrant):
"""Test replacing notes instead of appending."""
# Create person first
people.sync_person(**person_data)
# Replace notes
result = people.update_person(
identifier="alice_chen",
notes="Completely new notes.",
replace_notes=True,
)
assert result["status"] == "processed"
person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first()
assert person.content == "Completely new notes."
assert "Tech lead" not in person.content
def test_update_person_not_found(mock_make_session, qdrant):
"""Test updating a person that doesn't exist."""
result = people.update_person(
identifier="nonexistent_person",
display_name="New Name",
)
assert result["status"] == "not_found"
assert result["identifier"] == "nonexistent_person"
def test_update_person_no_changes(person_data, mock_make_session, qdrant):
"""Test update with no actual changes."""
# Create person first
people.sync_person(**person_data)
# Update with nothing
result = people.update_person(identifier="alice_chen")
assert result["status"] == "processed"
person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first()
# Should be unchanged
assert person.display_name == "Alice Chen"
@pytest.mark.parametrize(
"identifier,display_name,tags",
[
("john_doe", "John Doe", []),
("jane_smith", "Jane Smith", ["friend"]),
("bob_jones", "Bob Jones", ["work", "climbing", "london"]),
],
)
def test_sync_person_various_configurations(identifier, display_name, tags, mock_make_session, qdrant):
"""Test sync_person with various configurations."""
result = people.sync_person(
identifier=identifier,
display_name=display_name,
tags=tags,
)
assert result["status"] == "processed"
person = mock_make_session.query(Person).filter_by(identifier=identifier).first()
assert person is not None
assert person.display_name == display_name
assert person.tags == tags
def test_deep_merge_helper():
"""Test the _deep_merge helper function."""
base = {
"a": 1,
"b": {"c": 2, "d": 3},
"e": 4,
}
updates = {
"b": {"c": 5, "f": 6},
"g": 7,
}
result = people._deep_merge(base, updates)
assert result == {
"a": 1,
"b": {"c": 5, "d": 3, "f": 6},
"e": 4,
"g": 7,
}
def test_deep_merge_nested():
"""Test deep merge with deeply nested structures."""
base = {
"level1": {
"level2": {
"level3": {"a": 1},
},
},
}
updates = {
"level1": {
"level2": {
"level3": {"b": 2},
},
},
}
result = people._deep_merge(base, updates)
assert result == {
"level1": {
"level2": {
"level3": {"a": 1, "b": 2},
},
},
}
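# A reference implementation matching the two tests above: nested dicts merge
# recursively and scalar values from updates win. Behaviourally equivalent
# sketch of people._deep_merge, not its verified source.
def deep_merge_sketch(base: dict, updates: dict) -> dict:
    """Recursively merge updates into a copy of base."""
    merged = dict(base)
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge_sketch(merged[key], value)
        else:
            merged[key] = value
    return merged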
def test_sync_person_unicode(mock_make_session, qdrant):
"""Test sync_person with unicode content."""
result = people.sync_person(
identifier="unicode_person",
display_name="日本語 名前",
notes="Привет мир 🌍",
tags=["日本語", "emoji"],
)
assert result["status"] == "processed"
person = mock_make_session.query(Person).filter_by(identifier="unicode_person").first()
assert person is not None
assert person.display_name == "日本語 名前"
assert person.content == "Привет мир 🌍"
def test_sync_person_empty_notes(mock_make_session, qdrant):
"""Test sync_person with empty notes."""
result = people.sync_person(
identifier="empty_notes_person",
display_name="Empty Notes Person",
notes="",
)
assert result["status"] == "processed"
person = mock_make_session.query(Person).filter_by(identifier="empty_notes_person").first()
assert person is not None
assert person.content == ""
def test_update_person_first_notes(mock_make_session, qdrant):
"""Test adding notes to a person who had no notes."""
# Create person without notes
people.sync_person(
identifier="no_notes_person",
display_name="No Notes Person",
)
# Add notes
result = people.update_person(
identifier="no_notes_person",
notes="First notes!",
)
assert result["status"] == "processed"
person = mock_make_session.query(Person).filter_by(identifier="no_notes_person").first()
assert person.content == "First notes!"
# Should not have separator when there were no previous notes
assert "---" not in person.content