From 47629fc5fbbc5399c23bd5a45a8cb8dca0d6da5d Mon Sep 17 00:00:00 2001 From: mruwnik Date: Wed, 24 Dec 2025 13:25:34 +0000 Subject: [PATCH] add PRs and People --- .gitignore | 4 + .../versions/20251224_120000_add_people.py | 56 + .../20251224_150000_add_github_pr_data.py | 57 + docker-compose.yaml | 2 +- docker/workers/Dockerfile | 2 +- src/memory/api/MCP/__init__.py | 2 + src/memory/api/MCP/github.py | 518 ++++++++ src/memory/api/MCP/people.py | 257 ++++ src/memory/common/celery_app.py | 6 + src/memory/common/db/models/__init__.py | 10 + src/memory/common/db/models/people.py | 95 ++ src/memory/common/db/models/source_items.py | 66 + src/memory/parsers/github.py | 227 +++- src/memory/workers/tasks/github.py | 67 +- src/memory/workers/tasks/people.py | 145 +++ tests/memory/api/MCP/__init__.py | 0 tests/memory/api/MCP/test_github.py | 1064 +++++++++++++++++ tests/memory/api/MCP/test_people.py | 482 ++++++++ tests/memory/api/__init__.py | 0 tests/memory/common/db/models/test_people.py | 249 ++++ tests/memory/parsers/test_github.py | 765 ++++++++++-- .../memory/workers/tasks/test_github_tasks.py | 518 ++++++++ .../memory/workers/tasks/test_people_tasks.py | 404 +++++++ 23 files changed, 4902 insertions(+), 94 deletions(-) create mode 100644 db/migrations/versions/20251224_120000_add_people.py create mode 100644 db/migrations/versions/20251224_150000_add_github_pr_data.py create mode 100644 src/memory/api/MCP/github.py create mode 100644 src/memory/api/MCP/people.py create mode 100644 src/memory/common/db/models/people.py create mode 100644 src/memory/workers/tasks/people.py create mode 100644 tests/memory/api/MCP/__init__.py create mode 100644 tests/memory/api/MCP/test_github.py create mode 100644 tests/memory/api/MCP/test_people.py create mode 100644 tests/memory/api/__init__.py create mode 100644 tests/memory/common/db/models/test_people.py create mode 100644 tests/memory/workers/tasks/test_people_tasks.py diff --git a/.gitignore b/.gitignore index 72cb856..5649f8e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,9 @@ Books +books.md +clean_books +scripts + CLAUDE.md memory_files venv diff --git a/db/migrations/versions/20251224_120000_add_people.py b/db/migrations/versions/20251224_120000_add_people.py new file mode 100644 index 0000000..a22f276 --- /dev/null +++ b/db/migrations/versions/20251224_120000_add_people.py @@ -0,0 +1,56 @@ +"""Add people tracking + +Revision ID: c9d0e1f2a3b4 +Revises: b7c8d9e0f1a2 +Create Date: 2025-12-24 12:00:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. 
+revision: str = "c9d0e1f2a3b4" +down_revision: Union[str, None] = "b7c8d9e0f1a2" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "people", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("identifier", sa.Text(), nullable=False), + sa.Column("display_name", sa.Text(), nullable=False), + sa.Column( + "aliases", + postgresql.ARRAY(sa.Text()), + server_default="{}", + nullable=False, + ), + sa.Column( + "contact_info", + postgresql.JSONB(astext_type=sa.Text()), + server_default="{}", + nullable=False, + ), + sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("identifier"), + ) + op.create_index("person_identifier_idx", "people", ["identifier"], unique=False) + op.create_index("person_display_name_idx", "people", ["display_name"], unique=False) + op.create_index( + "person_aliases_idx", "people", ["aliases"], unique=False, postgresql_using="gin" + ) + + +def downgrade() -> None: + op.drop_index("person_aliases_idx", table_name="people") + op.drop_index("person_display_name_idx", table_name="people") + op.drop_index("person_identifier_idx", table_name="people") + op.drop_table("people") diff --git a/db/migrations/versions/20251224_150000_add_github_pr_data.py b/db/migrations/versions/20251224_150000_add_github_pr_data.py new file mode 100644 index 0000000..5b27f27 --- /dev/null +++ b/db/migrations/versions/20251224_150000_add_github_pr_data.py @@ -0,0 +1,57 @@ +"""Add github_pr_data table for PR-specific data + +Revision ID: d0e1f2a3b4c5 +Revises: c9d0e1f2a3b4 +Create Date: 2025-12-24 15:00:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. 
+revision: str = "d0e1f2a3b4c5" +down_revision: Union[str, None] = "c9d0e1f2a3b4" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "github_pr_data", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("github_item_id", sa.BigInteger(), nullable=False), + # Diff stored compressed with zlib + sa.Column("diff_compressed", sa.LargeBinary(), nullable=True), + # File changes as structured data + # [{filename, status, additions, deletions, patch?}] + sa.Column("files", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + # Stats + sa.Column("additions", sa.Integer(), nullable=True), + sa.Column("deletions", sa.Integer(), nullable=True), + sa.Column("changed_files_count", sa.Integer(), nullable=True), + # Reviews - [{user, state, body, submitted_at}] + sa.Column("reviews", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + # Review comments (line-by-line code comments) + # [{user, body, path, line, diff_hunk, created_at}] + sa.Column( + "review_comments", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + sa.ForeignKeyConstraint( + ["github_item_id"], ["github_item.id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("github_item_id"), + ) + op.create_index( + "github_pr_data_item_idx", "github_pr_data", ["github_item_id"], unique=True + ) + + +def downgrade() -> None: + op.drop_index("github_pr_data_item_idx", table_name="github_pr_data") + op.drop_table("github_pr_data") diff --git a/docker-compose.yaml b/docker-compose.yaml index f9ca190..5545432 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -206,7 +206,7 @@ services: <<: *worker-base environment: <<: *worker-env - QUEUES: "backup,blogs,comic,discord,ebooks,email,forums,github,photo_embed,maintenance,notes,scheduler" + QUEUES: "backup,blogs,comic,discord,ebooks,email,forums,github,people,photo_embed,maintenance,notes,scheduler" ingest-hub: <<: *worker-base diff --git a/docker/workers/Dockerfile b/docker/workers/Dockerfile index fa64a97..525ae81 100644 --- a/docker/workers/Dockerfile +++ b/docker/workers/Dockerfile @@ -44,7 +44,7 @@ RUN git config --global user.email "${GIT_USER_EMAIL}" && \ git config --global user.name "${GIT_USER_NAME}" # Default queues to process -ENV QUEUES="backup,blogs,comic,discord,ebooks,email,forums,github,photo_embed,maintenance" +ENV QUEUES="backup,blogs,comic,discord,ebooks,email,forums,github,people,photo_embed,maintenance" ENV PYTHONPATH="/app" ENTRYPOINT ["./entry.sh"] \ No newline at end of file diff --git a/src/memory/api/MCP/__init__.py b/src/memory/api/MCP/__init__.py index 081b81c..07f0efa 100644 --- a/src/memory/api/MCP/__init__.py +++ b/src/memory/api/MCP/__init__.py @@ -4,3 +4,5 @@ import memory.api.MCP.metadata import memory.api.MCP.schedules import memory.api.MCP.books import memory.api.MCP.manifest +import memory.api.MCP.github +import memory.api.MCP.people diff --git a/src/memory/api/MCP/github.py b/src/memory/api/MCP/github.py new file mode 100644 index 0000000..859c224 --- /dev/null +++ b/src/memory/api/MCP/github.py @@ -0,0 +1,518 @@ +"""MCP tools for GitHub issue tracking and management.""" + +import logging +from datetime import datetime, timezone +from typing import Any + +from sqlalchemy import Text, case, desc, func, asc +from sqlalchemy import cast as sql_cast +from sqlalchemy.dialects.postgresql import ARRAY + +from memory.api.MCP.base import mcp +from memory.api.search.search import search +from 
memory.api.search.types import SearchConfig, SearchFilters +from memory.common import extract +from memory.common.db.connection import make_session +from memory.common.db.models import GithubItem + +logger = logging.getLogger(__name__) + + +def _build_github_url(repo_path: str, number: int | None, kind: str) -> str: + """Build GitHub URL from repo path and issue/PR number.""" + if number is None: + return f"https://github.com/{repo_path}" + url_type = "pull" if kind == "pr" else "issues" + return f"https://github.com/{repo_path}/{url_type}/{number}" + + +def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[str, Any]: + """Serialize a GithubItem to a dict for API response.""" + result = { + "id": item.id, + "number": item.number, + "kind": item.kind, + "repo_path": item.repo_path, + "title": item.title, + "state": item.state, + "author": item.author, + "assignees": item.assignees or [], + "labels": item.labels or [], + "milestone": item.milestone, + "project_status": item.project_status, + "project_priority": item.project_priority, + "project_fields": item.project_fields, + "comment_count": item.comment_count, + "created_at": item.created_at.isoformat() if item.created_at else None, + "closed_at": item.closed_at.isoformat() if item.closed_at else None, + "merged_at": item.merged_at.isoformat() if item.merged_at else None, + "github_updated_at": ( + item.github_updated_at.isoformat() if item.github_updated_at else None + ), + "url": _build_github_url(item.repo_path, item.number, item.kind), + } + if include_content: + result["content"] = item.content + # Include PR-specific data if available + if item.kind == "pr" and item.pr_data: + result["pr_data"] = { + "additions": item.pr_data.additions, + "deletions": item.pr_data.deletions, + "changed_files_count": item.pr_data.changed_files_count, + "files": item.pr_data.files, + "reviews": item.pr_data.reviews, + "review_comments": item.pr_data.review_comments, + "diff": item.pr_data.diff, # Decompressed via property + } + return result + + +@mcp.tool() +async def list_github_issues( + repo: str | None = None, + assignee: str | None = None, + author: str | None = None, + state: str | None = None, + kind: str | None = None, + labels: list[str] | None = None, + project_status: str | None = None, + project_field: dict[str, str] | None = None, + updated_since: str | None = None, + updated_before: str | None = None, + limit: int = 50, + order_by: str = "updated", +) -> list[dict]: + """ + List GitHub issues and PRs with flexible filtering. + Use for daily triage, finding assigned issues, tracking stale issues, etc. 
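+
+    For example (illustrative values):
+        list_github_issues(repo="owner/name", state="open", assignee="alice")
+        list_github_issues(state="open", updated_before="2025-11-01")  # stale, untouched items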
+ + Args: + repo: Filter by repository path (e.g., "owner/name") + assignee: Filter by assignee username + author: Filter by author username + state: Filter by state: "open", "closed", "merged" (default: all) + kind: Filter by type: "issue" or "pr" (default: both) + labels: Filter by GitHub labels (matches ANY label in list) + project_status: Filter by project status (e.g., "In Progress", "Backlog") + project_field: Filter by project field values (e.g., {"EquiStamp.Client": "Redwood"}) + updated_since: ISO date - only issues updated after this time + updated_before: ISO date - only issues updated before this (for finding stale issues) + limit: Maximum results (default 50, max 200) + order_by: Sort order: "updated", "created", or "number" (default: "updated" descending) + + Returns: List of issues with id, number, title, state, assignees, labels, project_fields, timestamps, url + """ + logger.info(f"list_github_issues called: repo={repo}, assignee={assignee}, state={state}") + + limit = min(limit, 200) + + with make_session() as session: + query = session.query(GithubItem) + + # Apply filters + if repo: + query = query.filter(GithubItem.repo_path == repo) + + if assignee: + query = query.filter(GithubItem.assignees.any(assignee)) + + if author: + query = query.filter(GithubItem.author == author) + + if state: + query = query.filter(GithubItem.state == state) + + if kind: + query = query.filter(GithubItem.kind == kind) + else: + # Exclude comments by default, only show issues and PRs + query = query.filter(GithubItem.kind.in_(["issue", "pr"])) + + if labels: + # Match any label in the list using PostgreSQL array overlap + query = query.filter( + GithubItem.labels.op("&&")(sql_cast(labels, ARRAY(Text))) + ) + + if project_status: + query = query.filter(GithubItem.project_status == project_status) + + if project_field: + for key, value in project_field.items(): + query = query.filter( + GithubItem.project_fields[key].astext == value + ) + + if updated_since: + since_dt = datetime.fromisoformat(updated_since.replace("Z", "+00:00")) + query = query.filter(GithubItem.github_updated_at >= since_dt) + + if updated_before: + before_dt = datetime.fromisoformat(updated_before.replace("Z", "+00:00")) + query = query.filter(GithubItem.github_updated_at <= before_dt) + + # Apply ordering + if order_by == "created": + query = query.order_by(desc(GithubItem.created_at)) + elif order_by == "number": + query = query.order_by(desc(GithubItem.number)) + else: # default: updated + query = query.order_by(desc(GithubItem.github_updated_at)) + + query = query.limit(limit) + + items = query.all() + + return [_serialize_issue(item) for item in items] + + +@mcp.tool() +async def search_github_issues( + query: str, + repo: str | None = None, + state: str | None = None, + kind: str | None = None, + limit: int = 20, +) -> list[dict]: + """ + Search GitHub issues using natural language. + Searches across issue titles, bodies, and comments. 
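+
+    For example (illustrative): search_github_issues("authentication bug", repo="owner/name")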
+ + Args: + query: Natural language search query (e.g., "authentication bug", "database migration") + repo: Optional filter by repository path + state: Optional filter: "open", "closed", "merged" + kind: Optional filter: "issue" or "pr" + limit: Maximum results (default 20, max 100) + + Returns: List of matching issues with search score + """ + logger.info(f"search_github_issues called: query={query}, repo={repo}") + + limit = min(limit, 100) + + # Pre-filter source_ids if repo/state/kind filters are specified + source_ids = None + if repo or state or kind: + with make_session() as session: + q = session.query(GithubItem.id) + if repo: + q = q.filter(GithubItem.repo_path == repo) + if state: + q = q.filter(GithubItem.state == state) + if kind: + q = q.filter(GithubItem.kind == kind) + else: + q = q.filter(GithubItem.kind.in_(["issue", "pr"])) + source_ids = [item.id for item in q.all()] + + # Use the existing search infrastructure + data = extract.extract_text(query, skip_summary=True) + config = SearchConfig(limit=limit, previews=True) + filters = SearchFilters() + if source_ids is not None: + filters["source_ids"] = source_ids + + results = await search( + data, + modalities={"github"}, + filters=filters, + config=config, + ) + + # Fetch full issue details for the results + output = [] + with make_session() as session: + for result in results: + item = session.get(GithubItem, result.id) + if item: + serialized = _serialize_issue(item) + serialized["search_score"] = result.score + output.append(serialized) + + return output + + +@mcp.tool() +async def github_issue_details( + repo: str, + number: int, +) -> dict: + """ + Get full details of a specific GitHub issue or PR including all comments. + + Args: + repo: Repository path (e.g., "owner/name") + number: Issue or PR number + + Returns: Full issue details including content (body + comments), project fields, timestamps. + For PRs, also includes pr_data with: diff (full), files changed, reviews, review comments. + """ + logger.info(f"github_issue_details called: repo={repo}, number={number}") + + with make_session() as session: + item = ( + session.query(GithubItem) + .filter( + GithubItem.repo_path == repo, + GithubItem.number == number, + GithubItem.kind.in_(["issue", "pr"]), + ) + .first() + ) + + if not item: + raise ValueError(f"Issue #{number} not found in {repo}") + + return _serialize_issue(item, include_content=True) + + +@mcp.tool() +async def github_work_summary( + since: str, + until: str | None = None, + group_by: str = "client", + repo: str | None = None, +) -> dict: + """ + Summarize GitHub work activity for billing and time tracking. + Groups issues by client, author, status, or repository. 
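+
+    For example (illustrative), github_work_summary(since="2025-12-16", group_by="client")
+    returns per-client issue/PR counts plus sample issues for that period.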
+ + Args: + since: ISO date - start of period (e.g., "2025-12-16") + until: ISO date - end of period (default: now) + group_by: How to group results: "client", "status", "author", "repo", "task_type" + repo: Optional filter by repository path + + Returns: Summary with grouped counts and sample issues for each group + """ + logger.info(f"github_work_summary called: since={since}, group_by={group_by}") + + since_dt = datetime.fromisoformat(since.replace("Z", "+00:00")) + if until: + until_dt = datetime.fromisoformat(until.replace("Z", "+00:00")) + else: + until_dt = datetime.now(timezone.utc) + + # Map group_by to SQL expression + group_mappings = { + "client": GithubItem.project_fields["EquiStamp.Client"].astext, + "status": GithubItem.project_status, + "author": GithubItem.author, + "repo": GithubItem.repo_path, + "task_type": GithubItem.project_fields["EquiStamp.Task Type"].astext, + } + + if group_by not in group_mappings: + raise ValueError( + f"Invalid group_by: {group_by}. Must be one of: {list(group_mappings.keys())}" + ) + + group_col = group_mappings[group_by] + + with make_session() as session: + # Build base query for the period + base_query = session.query(GithubItem).filter( + GithubItem.github_updated_at >= since_dt, + GithubItem.github_updated_at <= until_dt, + GithubItem.kind.in_(["issue", "pr"]), + ) + + if repo: + base_query = base_query.filter(GithubItem.repo_path == repo) + + # Get aggregated counts by group + agg_query = ( + session.query( + group_col.label("group_name"), + func.count(GithubItem.id).label("total"), + func.count(case((GithubItem.kind == "issue", 1))).label("issue_count"), + func.count(case((GithubItem.kind == "pr", 1))).label("pr_count"), + func.count( + case((GithubItem.state.in_(["closed", "merged"]), 1)) + ).label("closed_count"), + ) + .filter( + GithubItem.github_updated_at >= since_dt, + GithubItem.github_updated_at <= until_dt, + GithubItem.kind.in_(["issue", "pr"]), + ) + .group_by(group_col) + .order_by(desc("total")) + ) + + if repo: + agg_query = agg_query.filter(GithubItem.repo_path == repo) + + groups = agg_query.all() + + # Build summary with sample issues for each group + summary = [] + total_issues = 0 + total_prs = 0 + + for group_name, total, issue_count, pr_count, closed_count in groups: + if group_name is None: + group_name = "(unset)" + + total_issues += issue_count + total_prs += pr_count + + # Get sample issues for this group + sample_query = base_query.filter(group_col == group_name).limit(5) + samples = [ + { + "number": item.number, + "title": item.title, + "repo_path": item.repo_path, + "state": item.state, + "url": _build_github_url(item.repo_path, item.number, item.kind), + } + for item in sample_query.all() + ] + + summary.append( + { + "group": group_name, + "total": total, + "issue_count": issue_count, + "pr_count": pr_count, + "closed_count": closed_count, + "issues": samples, + } + ) + + return { + "period": { + "since": since_dt.isoformat(), + "until": until_dt.isoformat(), + }, + "group_by": group_by, + "summary": summary, + "total_issues": total_issues, + "total_prs": total_prs, + } + + +@mcp.tool() +async def github_repo_overview( + repo: str, +) -> dict: + """ + Get an overview of a GitHub repository's issues and PRs. + Shows counts, status breakdown, top assignees, and labels. 
+ + Args: + repo: Repository path (e.g., "EquiStamp/equistamp" or "owner/name") + + Returns: Repository statistics including counts, status breakdown, top assignees, labels + """ + logger.info(f"github_repo_overview called: repo={repo}") + + with make_session() as session: + # Base query for this repo + base_query = session.query(GithubItem).filter( + GithubItem.repo_path == repo, + GithubItem.kind.in_(["issue", "pr"]), + ) + + # Get total counts + counts_query = session.query( + func.count(GithubItem.id).label("total"), + func.count(case((GithubItem.kind == "issue", 1))).label("total_issues"), + func.count( + case(((GithubItem.kind == "issue") & (GithubItem.state == "open"), 1)) + ).label("open_issues"), + func.count( + case(((GithubItem.kind == "issue") & (GithubItem.state == "closed"), 1)) + ).label("closed_issues"), + func.count(case((GithubItem.kind == "pr", 1))).label("total_prs"), + func.count( + case(((GithubItem.kind == "pr") & (GithubItem.state == "open"), 1)) + ).label("open_prs"), + func.count( + case(((GithubItem.kind == "pr") & (GithubItem.merged_at.isnot(None)), 1)) + ).label("merged_prs"), + func.max(GithubItem.github_updated_at).label("last_updated"), + ).filter( + GithubItem.repo_path == repo, + GithubItem.kind.in_(["issue", "pr"]), + ) + + counts = counts_query.first() + + # Status breakdown (for project_status) + status_query = ( + session.query( + GithubItem.project_status.label("status"), + func.count(GithubItem.id).label("count"), + ) + .filter( + GithubItem.repo_path == repo, + GithubItem.kind.in_(["issue", "pr"]), + GithubItem.project_status.isnot(None), + ) + .group_by(GithubItem.project_status) + .order_by(desc("count")) + ) + + status_breakdown = {row.status: row.count for row in status_query.all()} + + # Top assignees (open issues only) + assignee_query = ( + session.query( + func.unnest(GithubItem.assignees).label("assignee"), + func.count(GithubItem.id).label("count"), + ) + .filter( + GithubItem.repo_path == repo, + GithubItem.kind.in_(["issue", "pr"]), + GithubItem.state == "open", + ) + .group_by("assignee") + .order_by(desc("count")) + .limit(10) + ) + + top_assignees = [ + {"username": row.assignee, "open_count": row.count} + for row in assignee_query.all() + ] + + # Label counts + label_query = ( + session.query( + func.unnest(GithubItem.labels).label("label"), + func.count(GithubItem.id).label("count"), + ) + .filter( + GithubItem.repo_path == repo, + GithubItem.kind.in_(["issue", "pr"]), + ) + .group_by("label") + .order_by(desc("count")) + .limit(20) + ) + + labels = {row.label: row.count for row in label_query.all()} + + return { + "repo_path": repo, + "counts": { + "total": counts.total if counts else 0, + "total_issues": counts.total_issues if counts else 0, + "open_issues": counts.open_issues if counts else 0, + "closed_issues": counts.closed_issues if counts else 0, + "total_prs": counts.total_prs if counts else 0, + "open_prs": counts.open_prs if counts else 0, + "merged_prs": counts.merged_prs if counts else 0, + }, + "status_breakdown": status_breakdown, + "top_assignees": top_assignees, + "labels": labels, + "last_updated": ( + counts.last_updated.isoformat() + if counts and counts.last_updated + else None + ), + } diff --git a/src/memory/api/MCP/people.py b/src/memory/api/MCP/people.py new file mode 100644 index 0000000..9982317 --- /dev/null +++ b/src/memory/api/MCP/people.py @@ -0,0 +1,257 @@ +""" +MCP tools for tracking people. 
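+
+add_person and update_person_info queue Celery tasks on the people queue;
+get_person, list_people and delete_person operate on the database directly.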
+""" + +import logging +from typing import Any + +from sqlalchemy import Text +from sqlalchemy import cast as sql_cast +from sqlalchemy.dialects.postgresql import ARRAY + +from memory.api.MCP.base import mcp +from memory.common.db.connection import make_session +from memory.common.db.models import Person +from memory.common.celery_app import SYNC_PERSON, UPDATE_PERSON +from memory.common.celery_app import app as celery_app +from memory.common import settings + +logger = logging.getLogger(__name__) + + +def _person_to_dict(person: Person) -> dict[str, Any]: + """Convert a Person model to a dictionary for API responses.""" + return { + "identifier": person.identifier, + "display_name": person.display_name, + "aliases": list(person.aliases or []), + "contact_info": dict(person.contact_info or {}), + "tags": list(person.tags or []), + "notes": person.content, + "created_at": person.inserted_at.isoformat() if person.inserted_at else None, + } + + +@mcp.tool() +async def add_person( + identifier: str, + display_name: str, + aliases: list[str] | None = None, + contact_info: dict | None = None, + tags: list[str] | None = None, + notes: str | None = None, +) -> dict: + """ + Add a new person to track. + + Args: + identifier: Unique slug for the person (e.g., "alice_chen") + display_name: Human-readable name (e.g., "Alice Chen") + aliases: Alternative names/handles (e.g., ["@alice_c", "alice.chen@work.com"]) + contact_info: Contact information as a dict (e.g., {"email": "...", "phone": "..."}) + tags: Categorization tags (e.g., ["work", "friend", "climbing"]) + notes: Free-form notes about the person + + Returns: + Task status with task_id + + Example: + add_person( + identifier="alice_chen", + display_name="Alice Chen", + aliases=["@alice_c"], + contact_info={"email": "alice@example.com"}, + tags=["work", "engineering"], + notes="Tech lead on Platform team" + ) + """ + logger.info(f"MCP: Adding person: {identifier}") + + # Check if person already exists + with make_session() as session: + existing = session.query(Person).filter(Person.identifier == identifier).first() + if existing: + raise ValueError(f"Person with identifier '{identifier}' already exists") + + task = celery_app.send_task( + SYNC_PERSON, + queue=f"{settings.CELERY_QUEUE_PREFIX}-people", + kwargs={ + "identifier": identifier, + "display_name": display_name, + "aliases": aliases, + "contact_info": contact_info, + "tags": tags, + "notes": notes, + }, + ) + + return { + "task_id": task.id, + "status": "queued", + "identifier": identifier, + } + + +@mcp.tool() +async def update_person_info( + identifier: str, + display_name: str | None = None, + aliases: list[str] | None = None, + contact_info: dict | None = None, + tags: list[str] | None = None, + notes: str | None = None, + replace_notes: bool = False, +) -> dict: + """ + Update information about a person with merge semantics. 
+ + This tool MERGES new information with existing data rather than replacing it: + - display_name: Replaces existing value + - aliases: Adds new aliases (union with existing) + - contact_info: Deep merges (adds new keys, updates existing keys, never deletes) + - tags: Adds new tags (union with existing) + - notes: Appends to existing notes (or replaces if replace_notes=True) + + Args: + identifier: The person's unique identifier + display_name: New display name (replaces existing) + aliases: Additional aliases to add + contact_info: Additional contact info to merge + tags: Additional tags to add + notes: Notes to append (or replace if replace_notes=True) + replace_notes: If True, replace notes instead of appending + + Returns: + Task status with task_id + + Example: + # Add new contact info without losing existing data + update_person_info( + identifier="alice_chen", + contact_info={"phone": "555-1234"}, # Added to existing + notes="Enjoys rock climbing" # Appended to existing notes + ) + """ + logger.info(f"MCP: Updating person: {identifier}") + + # Verify person exists + with make_session() as session: + person = session.query(Person).filter(Person.identifier == identifier).first() + if not person: + raise ValueError(f"Person with identifier '{identifier}' not found") + + task = celery_app.send_task( + UPDATE_PERSON, + queue=f"{settings.CELERY_QUEUE_PREFIX}-people", + kwargs={ + "identifier": identifier, + "display_name": display_name, + "aliases": aliases, + "contact_info": contact_info, + "tags": tags, + "notes": notes, + "replace_notes": replace_notes, + }, + ) + + return { + "task_id": task.id, + "status": "queued", + "identifier": identifier, + } + + +@mcp.tool() +async def get_person(identifier: str) -> dict | None: + """ + Get a person by their identifier. + + Args: + identifier: The person's unique identifier + + Returns: + The person record, or None if not found + """ + logger.info(f"MCP: Getting person: {identifier}") + + with make_session() as session: + person = session.query(Person).filter(Person.identifier == identifier).first() + if not person: + return None + return _person_to_dict(person) + + +@mcp.tool() +async def list_people( + tags: list[str] | None = None, + search: str | None = None, + limit: int = 50, +) -> list[dict]: + """ + List all tracked people, optionally filtered by tags or search term. + + Args: + tags: Filter to people with at least one of these tags + search: Search term to match against name, aliases, or notes + limit: Maximum number of results (default 50, max 200) + + Returns: + List of person records matching the filters + """ + logger.info(f"MCP: Listing people (tags={tags}, search={search})") + + limit = min(limit, 200) + + with make_session() as session: + query = session.query(Person) + + if tags: + query = query.filter( + Person.tags.op("&&")(sql_cast(tags, ARRAY(Text))) + ) + + if search: + search_term = f"%{search.lower()}%" + query = query.filter( + (Person.display_name.ilike(search_term)) + | (Person.content.ilike(search_term)) + | (Person.identifier.ilike(search_term)) + ) + + query = query.order_by(Person.display_name).limit(limit) + people = query.all() + + return [_person_to_dict(p) for p in people] + + +@mcp.tool() +async def delete_person(identifier: str) -> dict: + """ + Delete a person by their identifier. + + This permanently removes the person and all associated data. + Observations about this person (with subject "person:") will remain. 
+ + Args: + identifier: The person's unique identifier + + Returns: + Confirmation of deletion + """ + logger.info(f"MCP: Deleting person: {identifier}") + + with make_session() as session: + person = session.query(Person).filter(Person.identifier == identifier).first() + if not person: + raise ValueError(f"Person with identifier '{identifier}' not found") + + display_name = person.display_name + session.delete(person) + session.commit() + + return { + "deleted": True, + "identifier": identifier, + "display_name": display_name, + } diff --git a/src/memory/common/celery_app.py b/src/memory/common/celery_app.py index f4c2c1f..9e76358 100644 --- a/src/memory/common/celery_app.py +++ b/src/memory/common/celery_app.py @@ -16,6 +16,7 @@ SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls" DISCORD_ROOT = "memory.workers.tasks.discord" BACKUP_ROOT = "memory.workers.tasks.backup" GITHUB_ROOT = "memory.workers.tasks.github" +PEOPLE_ROOT = "memory.workers.tasks.people" ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message" EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message" PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message" @@ -67,6 +68,10 @@ SYNC_GITHUB_REPO = f"{GITHUB_ROOT}.sync_github_repo" SYNC_ALL_GITHUB_REPOS = f"{GITHUB_ROOT}.sync_all_github_repos" SYNC_GITHUB_ITEM = f"{GITHUB_ROOT}.sync_github_item" +# People tasks +SYNC_PERSON = f"{PEOPLE_ROOT}.sync_person" +UPDATE_PERSON = f"{PEOPLE_ROOT}.update_person" + def get_broker_url() -> str: protocol = settings.CELERY_BROKER_TYPE @@ -123,6 +128,7 @@ app.conf.update( }, f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"}, f"{GITHUB_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-github"}, + f"{PEOPLE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-people"}, }, beat_schedule={ "sync-github-repos-hourly": { diff --git a/src/memory/common/db/models/__init__.py b/src/memory/common/db/models/__init__.py index 2d0966c..f19da3e 100644 --- a/src/memory/common/db/models/__init__.py +++ b/src/memory/common/db/models/__init__.py @@ -17,6 +17,7 @@ from memory.common.db.models.source_items import ( BookSection, ForumPost, GithubItem, + GithubPRData, GitCommit, Photo, MiscDoc, @@ -46,6 +47,10 @@ from memory.common.db.models.observations import ( BeliefCluster, ConversationMetrics, ) +from memory.common.db.models.people import ( + Person, + PersonPayload, +) from memory.common.db.models.sources import ( Book, ArticleFeed, @@ -77,6 +82,7 @@ Payload = ( | ForumPostPayload | EmailAttachmentPayload | MailMessagePayload + | PersonPayload ) __all__ = [ @@ -95,6 +101,7 @@ __all__ = [ "BookSection", "ForumPost", "GithubItem", + "GithubPRData", "GitCommit", "Photo", "MiscDoc", @@ -105,6 +112,9 @@ __all__ = [ "ObservationPattern", "BeliefCluster", "ConversationMetrics", + # People + "Person", + "PersonPayload", # Sources "Book", "ArticleFeed", diff --git a/src/memory/common/db/models/people.py b/src/memory/common/db/models/people.py new file mode 100644 index 0000000..535016b --- /dev/null +++ b/src/memory/common/db/models/people.py @@ -0,0 +1,95 @@ +""" +Database models for tracking people. 
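+
+Person is a SourceItem subclass: identifier, display name, aliases and contact
+info live in the people table, and search chunks are built from the display
+name, aliases, tags and notes.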
+""" + +from typing import Annotated, Sequence, cast + +from sqlalchemy import ( + ARRAY, + BigInteger, + Column, + ForeignKey, + Index, + Text, +) +from sqlalchemy.dialects.postgresql import JSONB + +import memory.common.extract as extract + +from memory.common.db.models.source_item import ( + SourceItem, + SourceItemPayload, +) + + +class PersonPayload(SourceItemPayload): + identifier: Annotated[str, "Unique identifier/slug for the person"] + display_name: Annotated[str, "Display name of the person"] + aliases: Annotated[list[str], "Alternative names/handles for the person"] + contact_info: Annotated[dict, "Contact information (email, phone, etc.)"] + + +class Person(SourceItem): + """A person you know or want to track.""" + + __tablename__ = "people" + + id = Column( + BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True + ) + identifier = Column(Text, unique=True, nullable=False, index=True) + display_name = Column(Text, nullable=False) + aliases = Column(ARRAY(Text), server_default="{}", nullable=False) + contact_info = Column(JSONB, server_default="{}", nullable=False) + + __mapper_args__ = { + "polymorphic_identity": "person", + } + + __table_args__ = ( + Index("person_identifier_idx", "identifier"), + Index("person_display_name_idx", "display_name"), + Index("person_aliases_idx", "aliases", postgresql_using="gin"), + ) + + def as_payload(self) -> PersonPayload: + return PersonPayload( + **super().as_payload(), + identifier=cast(str, self.identifier), + display_name=cast(str, self.display_name), + aliases=cast(list[str], self.aliases) or [], + contact_info=cast(dict, self.contact_info) or {}, + ) + + @property + def display_contents(self) -> dict: + return { + "identifier": self.identifier, + "display_name": self.display_name, + "aliases": self.aliases, + "contact_info": self.contact_info, + "notes": self.content, + "tags": self.tags, + } + + def _chunk_contents(self) -> Sequence[extract.DataChunk]: + """Create searchable chunks from person data.""" + parts = [f"# {self.display_name}"] + + if self.aliases: + aliases_str = ", ".join(cast(list[str], self.aliases)) + parts.append(f"Also known as: {aliases_str}") + + if self.tags: + tags_str = ", ".join(cast(list[str], self.tags)) + parts.append(f"Tags: {tags_str}") + + if self.content: + parts.append(f"\n{self.content}") + + text = "\n".join(parts) + return extract.extract_text(text, modality="person") + + @classmethod + def get_collections(cls) -> list[str]: + return ["person"] diff --git a/src/memory/common/db/models/source_items.py b/src/memory/common/db/models/source_items.py index 5ac4bc5..58ec143 100644 --- a/src/memory/common/db/models/source_items.py +++ b/src/memory/common/db/models/source_items.py @@ -9,15 +9,19 @@ from collections.abc import Collection from typing import Any, Annotated, Sequence, cast from PIL import Image +import zlib + from sqlalchemy import ( ARRAY, BigInteger, + Boolean, CheckConstraint, Column, DateTime, ForeignKey, Index, Integer, + LargeBinary, Numeric, Text, func, @@ -30,6 +34,7 @@ import memory.common.extract as extract import memory.common.summarizer as summarizer import memory.common.formatters.observation as observation +from memory.common.db.models.base import Base from memory.common.db.models.source_item import ( SourceItem, SourceItemPayload, @@ -858,6 +863,14 @@ class GithubItem(SourceItem): milestone = Column(Text, nullable=True) comment_count = Column(Integer, nullable=True) + # Relationship to PR-specific data + pr_data = relationship( + "GithubPRData", + 
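+        # One GithubPRData row per item (uselist=False), deleted with its parent GithubItem.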
back_populates="github_item", + uselist=False, + cascade="all, delete-orphan", + ) + __mapper_args__ = { "polymorphic_identity": "github_item", } @@ -902,6 +915,59 @@ class GithubItem(SourceItem): return [] +class GithubPRData(Base): + """PR-specific data linked to GithubItem. Not a SourceItem - not indexed separately.""" + + __tablename__ = "github_pr_data" + + id = Column(BigInteger, primary_key=True) + github_item_id = Column( + BigInteger, + ForeignKey("github_item.id", ondelete="CASCADE"), + unique=True, + nullable=False, + index=True, + ) + + # Diff (compressed with zlib) + diff_compressed = Column(LargeBinary, nullable=True) + + # File changes as structured data + # [{filename, status, additions, deletions, patch?}] + files = Column(JSONB, nullable=True) + + # Stats + additions = Column(Integer, nullable=True) + deletions = Column(Integer, nullable=True) + changed_files_count = Column(Integer, nullable=True) + + # Reviews (structured) + # [{user, state, body, submitted_at}] + reviews = Column(JSONB, nullable=True) + + # Review comments (line-by-line code comments) + # [{user, body, path, line, diff_hunk, created_at}] + review_comments = Column(JSONB, nullable=True) + + # Relationship back to GithubItem + github_item = relationship("GithubItem", back_populates="pr_data") + + @property + def diff(self) -> str | None: + """Decompress and return the full diff text.""" + if self.diff_compressed: + return zlib.decompress(self.diff_compressed).decode("utf-8") + return None + + @diff.setter + def diff(self, value: str | None) -> None: + """Compress and store the diff text.""" + if value: + self.diff_compressed = zlib.compress(value.encode("utf-8")) + else: + self.diff_compressed = None + + class NotePayload(SourceItemPayload): note_type: Annotated[str | None, "Category of the note"] subject: Annotated[str | None, "What the note is about"] diff --git a/src/memory/parsers/github.py b/src/memory/parsers/github.py index ae0c8fc..e5ab166 100644 --- a/src/memory/parsers/github.py +++ b/src/memory/parsers/github.py @@ -42,6 +42,51 @@ class GithubComment(TypedDict): updated_at: str +class GithubReviewComment(TypedDict): + """A line-by-line code review comment on a PR.""" + + id: int + user: str + body: str + path: str + line: int | None + side: str # "LEFT" or "RIGHT" + diff_hunk: str + created_at: str + + +class GithubReview(TypedDict): + """A PR review (approval, request changes, etc.).""" + + id: int + user: str + state: str # "approved", "changes_requested", "commented", "dismissed" + body: str | None + submitted_at: str + + +class GithubFileChange(TypedDict): + """A file changed in a PR.""" + + filename: str + status: str # "added", "modified", "removed", "renamed" + additions: int + deletions: int + patch: str | None # Diff patch for this file + + +class GithubPRDataDict(TypedDict): + """PR-specific data for storage in GithubPRData model.""" + + diff: str | None # Full diff text + files: list[GithubFileChange] + additions: int + deletions: int + changed_files_count: int + reviews: list[GithubReview] + review_comments: list[GithubReviewComment] + + class GithubIssueData(TypedDict): """Parsed issue/PR data ready for storage.""" @@ -60,9 +105,11 @@ class GithubIssueData(TypedDict): github_updated_at: datetime comment_count: int comments: list[GithubComment] - diff_summary: str | None # PRs only + diff_summary: str | None # PRs only (truncated, for backward compat) project_fields: dict[str, Any] | None content_hash: str + # PR-specific extended data (None for issues) + pr_data: GithubPRDataDict | 
None def parse_github_date(date_str: str | None) -> datetime | None: @@ -267,6 +314,145 @@ class GithubClient: return comments + def fetch_review_comments( + self, + owner: str, + repo: str, + pr_number: int, + ) -> list[GithubReviewComment]: + """Fetch all line-by-line review comments for a PR.""" + comments: list[GithubReviewComment] = [] + page = 1 + + while True: + response = self.session.get( + f"{GITHUB_API_URL}/repos/{owner}/{repo}/pulls/{pr_number}/comments", + params={"page": page, "per_page": 100}, + timeout=30, + ) + response.raise_for_status() + self._handle_rate_limit(response) + + page_comments = response.json() + if not page_comments: + break + + comments.extend( + [ + GithubReviewComment( + id=c["id"], + user=c["user"]["login"] if c.get("user") else "ghost", + body=c.get("body", ""), + path=c.get("path", ""), + line=c.get("line"), + side=c.get("side", "RIGHT"), + diff_hunk=c.get("diff_hunk", ""), + created_at=c["created_at"], + ) + for c in page_comments + ] + ) + page += 1 + + return comments + + def fetch_reviews( + self, + owner: str, + repo: str, + pr_number: int, + ) -> list[GithubReview]: + """Fetch all reviews (approvals, change requests) for a PR.""" + reviews: list[GithubReview] = [] + page = 1 + + while True: + response = self.session.get( + f"{GITHUB_API_URL}/repos/{owner}/{repo}/pulls/{pr_number}/reviews", + params={"page": page, "per_page": 100}, + timeout=30, + ) + response.raise_for_status() + self._handle_rate_limit(response) + + page_reviews = response.json() + if not page_reviews: + break + + reviews.extend( + [ + GithubReview( + id=r["id"], + user=r["user"]["login"] if r.get("user") else "ghost", + state=r.get("state", "COMMENTED").lower(), + body=r.get("body"), + submitted_at=r.get("submitted_at", ""), + ) + for r in page_reviews + ] + ) + page += 1 + + return reviews + + def fetch_pr_files( + self, + owner: str, + repo: str, + pr_number: int, + ) -> list[GithubFileChange]: + """Fetch list of files changed in a PR with patches.""" + files: list[GithubFileChange] = [] + page = 1 + + while True: + response = self.session.get( + f"{GITHUB_API_URL}/repos/{owner}/{repo}/pulls/{pr_number}/files", + params={"page": page, "per_page": 100}, + timeout=30, + ) + response.raise_for_status() + self._handle_rate_limit(response) + + page_files = response.json() + if not page_files: + break + + files.extend( + [ + GithubFileChange( + filename=f["filename"], + status=f.get("status", "modified"), + additions=f.get("additions", 0), + deletions=f.get("deletions", 0), + patch=f.get("patch"), # May be None for binary files + ) + for f in page_files + ] + ) + page += 1 + + return files + + def fetch_pr_diff( + self, + owner: str, + repo: str, + pr_number: int, + ) -> str | None: + """Fetch the full diff for a PR (not truncated).""" + try: + response = self.session.get( + f"{GITHUB_API_URL}/repos/{owner}/{repo}/pulls/{pr_number}", + headers={"Accept": "application/vnd.github.diff"}, + timeout=60, # Longer timeout for large diffs + ) + if response.ok: + return response.text + except Exception as e: + logger.warning(f"Failed to fetch PR diff: {e}") + return None + def fetch_project_fields( self, owner: str, @@ -490,28 +676,44 @@ class GithubClient: diff_summary=None, project_fields=None, # Fetched separately if enabled content_hash=compute_content_hash(body, comments), + pr_data=None, # Issues don't have PR data ) def _parse_pr( self, owner: str, repo: str, pr: dict[str, Any] ) -> GithubIssueData: """Parse raw PR data into structured format.""" - comments = 
self.fetch_comments(owner, repo, pr["number"]) + pr_number = pr["number"] + comments = self.fetch_comments(owner, repo, pr_number) body = pr.get("body") or "" - # Get diff summary (truncated) - diff_summary = None - if diff_url := pr.get("diff_url"): - try: - diff_response = self.session.get(diff_url, timeout=30) - if diff_response.ok: - diff_summary = diff_response.text[:5000] # Truncate large diffs - except Exception as e: - logger.warning(f"Failed to fetch diff: {e}") + # Fetch PR-specific data + review_comments = self.fetch_review_comments(owner, repo, pr_number) + reviews = self.fetch_reviews(owner, repo, pr_number) + files = self.fetch_pr_files(owner, repo, pr_number) + full_diff = self.fetch_pr_diff(owner, repo, pr_number) + + # Calculate stats from files + additions = sum(f["additions"] for f in files) + deletions = sum(f["deletions"] for f in files) + + # Get diff summary (truncated, for backward compatibility) + diff_summary = full_diff[:5000] if full_diff else None + + # Build PR data dict + pr_data = GithubPRDataDict( + diff=full_diff, + files=files, + additions=additions, + deletions=deletions, + changed_files_count=len(files), + reviews=reviews, + review_comments=review_comments, + ) return GithubIssueData( kind="pr", - number=pr["number"], + number=pr_number, title=pr["title"], body=body, state=pr["state"], @@ -528,4 +730,5 @@ class GithubClient: diff_summary=diff_summary, project_fields=None, # Fetched separately if enabled content_hash=compute_content_hash(body, comments), + pr_data=pr_data, ) diff --git a/src/memory/workers/tasks/github.py b/src/memory/workers/tasks/github.py index 2c69064..6913134 100644 --- a/src/memory/workers/tasks/github.py +++ b/src/memory/workers/tasks/github.py @@ -12,12 +12,13 @@ from memory.common.celery_app import ( SYNC_GITHUB_ITEM, ) from memory.common.db.connection import make_session -from memory.common.db.models import GithubItem +from memory.common.db.models import GithubItem, GithubPRData from memory.common.db.models.sources import GithubAccount, GithubRepo from memory.parsers.github import ( GithubClient, GithubCredentials, GithubIssueData, + GithubPRDataDict, ) from memory.workers.tasks.content_processing import ( create_content_hash, @@ -32,11 +33,42 @@ logger = logging.getLogger(__name__) def _build_content(issue_data: GithubIssueData) -> str: """Build searchable content from issue/PR data.""" content_parts = [f"# {issue_data['title']}", issue_data["body"]] + + # Add regular comments for comment in issue_data["comments"]: content_parts.append(f"\n---\n**{comment['author']}**: {comment['body']}") + + # Add review comments for PRs (makes them searchable) + pr_data = issue_data.get("pr_data") + if pr_data and pr_data.get("review_comments"): + content_parts.append("\n---\n## Code Review Comments\n") + for rc in pr_data["review_comments"]: + content_parts.append( + f"**{rc['user']}** on `{rc['path']}`: {rc['body']}" + ) + return "\n\n".join(content_parts) +def _create_pr_data(issue_data: GithubIssueData) -> GithubPRData | None: + """Create GithubPRData from PR-specific data if available.""" + pr_data_dict = issue_data.get("pr_data") + if not pr_data_dict: + return None + + pr_data = GithubPRData( + additions=pr_data_dict.get("additions"), + deletions=pr_data_dict.get("deletions"), + changed_files_count=pr_data_dict.get("changed_files_count"), + files=pr_data_dict.get("files"), + reviews=pr_data_dict.get("reviews"), + review_comments=pr_data_dict.get("review_comments"), + ) + # Use the setter to compress the diff + pr_data.diff = 
pr_data_dict.get("diff") + return pr_data + + def _create_github_item( repo: GithubRepo, issue_data: GithubIssueData, @@ -57,7 +89,7 @@ def _create_github_item( repo_tags = cast(list[str], repo.tags) or [] - return GithubItem( + github_item = GithubItem( modality="github", sha256=create_content_hash(content), content=content, @@ -86,6 +118,12 @@ def _create_github_item( mime_type="text/markdown", ) + # Create PR data if this is a PR + if issue_data["kind"] == "pr": + github_item.pr_data = _create_pr_data(issue_data) + + return github_item + def _needs_reindex(existing: GithubItem, new_data: GithubIssueData) -> bool: """Check if an existing item needs reindexing based on content changes.""" @@ -165,6 +203,23 @@ def _update_existing_item( repo_tags = cast(list[str], repo.tags) or [] existing.tags = repo_tags + issue_data["labels"] # type: ignore + # Update PR data if this is a PR + if issue_data["kind"] == "pr": + pr_data_dict = issue_data.get("pr_data") + if pr_data_dict: + if existing.pr_data: + # Update existing pr_data + existing.pr_data.additions = pr_data_dict.get("additions") + existing.pr_data.deletions = pr_data_dict.get("deletions") + existing.pr_data.changed_files_count = pr_data_dict.get("changed_files_count") + existing.pr_data.files = pr_data_dict.get("files") + existing.pr_data.reviews = pr_data_dict.get("reviews") + existing.pr_data.review_comments = pr_data_dict.get("review_comments") + existing.pr_data.diff = pr_data_dict.get("diff") + else: + # Create new pr_data + existing.pr_data = _create_pr_data(issue_data) + session.flush() # Re-embed and push to Qdrant @@ -193,6 +248,8 @@ def _serialize_issue_data(data: GithubIssueData) -> dict[str, Any]: } for c in data["comments"] ], + # pr_data is already JSON-serializable (TypedDict) + "pr_data": data.get("pr_data"), } @@ -200,6 +257,11 @@ def _deserialize_issue_data(data: dict[str, Any]) -> GithubIssueData: """Deserialize issue data from Celery task.""" from memory.parsers.github import parse_github_date + # Reconstruct pr_data if present + pr_data: GithubPRDataDict | None = None + if data.get("pr_data"): + pr_data = cast(GithubPRDataDict, data["pr_data"]) + return GithubIssueData( kind=data["kind"], number=data["number"], @@ -219,6 +281,7 @@ def _deserialize_issue_data(data: dict[str, Any]) -> GithubIssueData: diff_summary=data.get("diff_summary"), project_fields=data.get("project_fields"), content_hash=data["content_hash"], + pr_data=pr_data, ) diff --git a/src/memory/workers/tasks/people.py b/src/memory/workers/tasks/people.py new file mode 100644 index 0000000..fbeb463 --- /dev/null +++ b/src/memory/workers/tasks/people.py @@ -0,0 +1,145 @@ +""" +Celery tasks for tracking people. 
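+
+sync_person creates a Person (no-op if the identifier already exists);
+update_person merges new data into an existing record and marks it for re-embedding.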
+""" + +import logging + +from memory.common.db.connection import make_session +from memory.common.db.models import Person +from memory.common.celery_app import app, SYNC_PERSON, UPDATE_PERSON +from memory.workers.tasks.content_processing import ( + check_content_exists, + create_content_hash, + create_task_result, + process_content_item, + safe_task_execution, +) + +logger = logging.getLogger(__name__) + + +def _deep_merge(base: dict, updates: dict) -> dict: + """Deep merge two dictionaries, with updates taking precedence.""" + result = dict(base) + for key, value in updates.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = _deep_merge(result[key], value) + else: + result[key] = value + return result + + +@app.task(name=SYNC_PERSON) +@safe_task_execution +def sync_person( + identifier: str, + display_name: str, + aliases: list[str] | None = None, + contact_info: dict | None = None, + tags: list[str] | None = None, + notes: str | None = None, +): + """ + Create or update a person in the knowledge base. + + Args: + identifier: Unique slug for the person + display_name: Human-readable name + aliases: Alternative names/handles + contact_info: Contact information dict + tags: Categorization tags + notes: Free-form notes about the person + """ + logger.info(f"Syncing person: {identifier}") + + # Create hash from identifier for deduplication + sha256 = create_content_hash(f"person:{identifier}") + + with make_session() as session: + # Check if person already exists by identifier + existing = session.query(Person).filter(Person.identifier == identifier).first() + + if existing: + logger.info(f"Person already exists: {identifier}") + return create_task_result(existing, "already_exists") + + # Also check by sha256 (defensive) + existing_by_hash = check_content_exists(session, Person, sha256=sha256) + if existing_by_hash: + logger.info(f"Person already exists (by hash): {identifier}") + return create_task_result(existing_by_hash, "already_exists") + + person = Person( + identifier=identifier, + display_name=display_name, + aliases=aliases or [], + contact_info=contact_info or {}, + tags=tags or [], + content=notes, + modality="person", + mime_type="text/plain", + sha256=sha256, + size=len(notes or ""), + ) + + return process_content_item(person, session) + + +@app.task(name=UPDATE_PERSON) +@safe_task_execution +def update_person( + identifier: str, + display_name: str | None = None, + aliases: list[str] | None = None, + contact_info: dict | None = None, + tags: list[str] | None = None, + notes: str | None = None, + replace_notes: bool = False, +): + """ + Update a person with merge semantics. 
+ + Merge behavior: + - display_name: Replaces if provided + - aliases: Union with existing + - contact_info: Deep merge with existing + - tags: Union with existing + - notes: Append to existing (or replace if replace_notes=True) + """ + logger.info(f"Updating person: {identifier}") + + with make_session() as session: + person = session.query(Person).filter(Person.identifier == identifier).first() + if not person: + logger.warning(f"Person not found: {identifier}") + return {"status": "not_found", "identifier": identifier} + + if display_name is not None: + person.display_name = display_name + + if aliases is not None: + existing_aliases = set(person.aliases or []) + new_aliases = existing_aliases | set(aliases) + person.aliases = list(new_aliases) + + if contact_info is not None: + existing_contact = dict(person.contact_info or {}) + person.contact_info = _deep_merge(existing_contact, contact_info) + + if tags is not None: + existing_tags = set(person.tags or []) + new_tags = existing_tags | set(tags) + person.tags = list(new_tags) + + if notes is not None: + if replace_notes or not person.content: + person.content = notes + else: + person.content = f"{person.content}\n\n---\n\n{notes}" + + # Update hash based on new content + person.sha256 = create_content_hash(f"person:{identifier}") + person.size = len(person.content or "") + person.embed_status = "RAW" # Re-embed with updated content + + return process_content_item(person, session) diff --git a/tests/memory/api/MCP/__init__.py b/tests/memory/api/MCP/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/memory/api/MCP/test_github.py b/tests/memory/api/MCP/test_github.py new file mode 100644 index 0000000..ee42f6b --- /dev/null +++ b/tests/memory/api/MCP/test_github.py @@ -0,0 +1,1064 @@ +"""Comprehensive tests for GitHub MCP tools.""" + +import sys +import pytest +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, Mock, MagicMock, patch + +# Mock the mcp module and all its submodules before importing anything that uses it +_mock_mcp = MagicMock() +_mock_mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator +sys.modules["mcp"] = _mock_mcp +sys.modules["mcp.server"] = MagicMock() +sys.modules["mcp.server.auth"] = MagicMock() +sys.modules["mcp.server.auth.handlers"] = MagicMock() +sys.modules["mcp.server.auth.handlers.authorize"] = MagicMock() +sys.modules["mcp.server.auth.handlers.token"] = MagicMock() +sys.modules["mcp.server.auth.provider"] = MagicMock() +sys.modules["mcp.server.fastmcp"] = MagicMock() +sys.modules["mcp.server.fastmcp.server"] = MagicMock() + +# Also mock the memory.api.MCP.base module to avoid MCP imports +_mock_base = MagicMock() +_mock_base.mcp = MagicMock() +_mock_base.mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator +sys.modules["memory.api.MCP.base"] = _mock_base + +from memory.common.db.models import GithubItem +from memory.common.db import connection as db_connection + + +@pytest.fixture(autouse=True) +def reset_db_cache(): + """Reset the cached database engine between tests.""" + db_connection._engine = None + db_connection._session_factory = None + db_connection._scoped_session = None + yield + db_connection._engine = None + db_connection._session_factory = None + db_connection._scoped_session = None + + +def _make_sha256(content: str) -> bytes: + """Generate a sha256 hash for test content.""" + import hashlib + return hashlib.sha256(content.encode()).digest() + + +@pytest.fixture +def sample_issues(db_session): + 
"""Create sample GitHub issues for testing.""" + now = datetime.now(timezone.utc) + issues = [ + GithubItem( + kind="issue", + repo_path="owner/repo1", + number=1, + title="Fix authentication bug", + content="There is a bug in the authentication system.\n\n## Comments\n\n**user1**: I can reproduce this.", + state="open", + author="alice", + labels=["bug", "security"], + assignees=["bob", "charlie"], + milestone="v1.0", + project_status="In Progress", + project_priority="High", + project_fields={ + "EquiStamp.Client": "Redwood", + "EquiStamp.Status": "In Progress", + "EquiStamp.Task Type": "Bug Fix", + }, + comment_count=1, + created_at=now - timedelta(days=10), + github_updated_at=now - timedelta(days=1), + modality="github", + sha256=_make_sha256("issue-1-content"), + ), + GithubItem( + kind="issue", + repo_path="owner/repo1", + number=2, + title="Add dark mode support", + content="Users want dark mode for the application.", + state="open", + author="bob", + labels=["enhancement", "ui"], + assignees=["alice"], + milestone="v2.0", + project_status="Backlog", + project_priority="Medium", + project_fields={ + "EquiStamp.Client": "University of Illinois", + "EquiStamp.Status": "Backlog", + "EquiStamp.Task Type": "Feature", + }, + comment_count=0, + created_at=now - timedelta(days=5), + github_updated_at=now - timedelta(days=2), + modality="github", + sha256=_make_sha256("issue-2-content"), + ), + GithubItem( + kind="issue", + repo_path="owner/repo2", + number=10, + title="Database migration issue", + content="Migration fails on PostgreSQL 15.", + state="closed", + author="charlie", + labels=["bug"], + assignees=["alice"], + milestone=None, + project_status="Closed", + project_priority=None, + project_fields={ + "EquiStamp.Client": "Redwood", + "EquiStamp.Status": "Closed", + }, + comment_count=3, + created_at=now - timedelta(days=20), + closed_at=now - timedelta(days=3), + github_updated_at=now - timedelta(days=3), + modality="github", + sha256=_make_sha256("issue-10-content"), + ), + GithubItem( + kind="pr", + repo_path="owner/repo1", + number=50, + title="Refactor user service", + content="This PR refactors the user service for better performance.", + state="merged", + author="alice", + labels=["refactor"], + assignees=["bob"], + milestone="v1.0", + project_status="Approved for Payment", + project_priority="Low", + project_fields={ + "EquiStamp.Client": "Redwood", + "EquiStamp.Status": "Approved for Payment", + "EquiStamp.Hours taken": "5", + }, + comment_count=2, + created_at=now - timedelta(days=15), + merged_at=now - timedelta(days=7), + github_updated_at=now - timedelta(days=7), + modality="github", + sha256=_make_sha256("pr-50-content"), + ), + GithubItem( + kind="issue", + repo_path="owner/repo1", + number=100, + title="Stale issue without updates", + content="This issue has not been updated in a long time.", + state="open", + author="dave", + labels=["stale"], + assignees=[], + milestone=None, + project_status=None, + project_priority=None, + project_fields=None, + comment_count=0, + created_at=now - timedelta(days=60), + github_updated_at=now - timedelta(days=45), + modality="github", + sha256=_make_sha256("issue-100-content"), + ), + ] + + for issue in issues: + db_session.add(issue) + db_session.commit() + + # Refresh to get IDs + for issue in issues: + db_session.refresh(issue) + + return issues + + +# ============================================================================= +# Tests for list_github_issues +# 
============================================================================= + + +@pytest.mark.asyncio +async def test_list_github_issues_no_filters(db_session, sample_issues): + """Test listing all issues without filters.""" + from memory.api.MCP.github import list_github_issues + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + results = await list_github_issues() + + # Should return all issues and PRs (not comments) + assert len(results) == 5 + # Should be ordered by github_updated_at desc + assert results[0]["number"] == 1 # Most recently updated + + +@pytest.mark.asyncio +async def test_list_github_issues_filter_by_repo(db_session, sample_issues): + """Test filtering by repository.""" + from memory.api.MCP.github import list_github_issues + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + results = await list_github_issues(repo="owner/repo1") + + assert len(results) == 4 + assert all(r["repo_path"] == "owner/repo1" for r in results) + + +@pytest.mark.asyncio +async def test_list_github_issues_filter_by_assignee(db_session, sample_issues): + """Test filtering by assignee.""" + from memory.api.MCP.github import list_github_issues + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + results = await list_github_issues(assignee="alice") + + assert len(results) == 2 + assert all("alice" in r["assignees"] for r in results) + + +@pytest.mark.asyncio +async def test_list_github_issues_filter_by_author(db_session, sample_issues): + """Test filtering by author.""" + from memory.api.MCP.github import list_github_issues + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + results = await list_github_issues(author="alice") + + assert len(results) == 2 + assert all(r["author"] == "alice" for r in results) + + +@pytest.mark.asyncio +async def test_list_github_issues_filter_by_state(db_session, sample_issues): + """Test filtering by state.""" + from memory.api.MCP.github import list_github_issues + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + open_results = await list_github_issues(state="open") + closed_results = await list_github_issues(state="closed") + merged_results = await list_github_issues(state="merged") + + assert len(open_results) == 3 + assert all(r["state"] == "open" for r in open_results) + + assert len(closed_results) == 1 + assert closed_results[0]["state"] == "closed" + + assert len(merged_results) == 1 + assert merged_results[0]["state"] == "merged" + + +@pytest.mark.asyncio +async def test_list_github_issues_filter_by_kind(db_session, sample_issues): + """Test filtering by kind (issue vs PR).""" + from memory.api.MCP.github import list_github_issues + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + issues = await list_github_issues(kind="issue") + prs = await list_github_issues(kind="pr") + + assert len(issues) == 4 + assert all(r["kind"] == "issue" for r in issues) + + assert len(prs) == 1 + assert prs[0]["kind"] == "pr" + + +@pytest.mark.asyncio +async def test_list_github_issues_filter_by_labels(db_session, sample_issues): + """Test filtering by labels.""" + from memory.api.MCP.github import list_github_issues + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + results = await list_github_issues(labels=["bug"]) + + assert len(results) == 2 + assert all("bug" in r["labels"] for r in results) + + +@pytest.mark.asyncio +async def 
test_list_github_issues_filter_by_project_status(db_session, sample_issues): + """Test filtering by project status.""" + from memory.api.MCP.github import list_github_issues + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + results = await list_github_issues(project_status="In Progress") + + assert len(results) == 1 + assert results[0]["project_status"] == "In Progress" + assert results[0]["number"] == 1 + + +@pytest.mark.asyncio +async def test_list_github_issues_filter_by_project_field(db_session, sample_issues): + """Test filtering by project field (JSONB).""" + from memory.api.MCP.github import list_github_issues + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + results = await list_github_issues( + project_field={"EquiStamp.Client": "Redwood"} + ) + + assert len(results) == 3 + assert all( + r["project_fields"].get("EquiStamp.Client") == "Redwood" for r in results + ) + + +@pytest.mark.asyncio +async def test_list_github_issues_filter_by_updated_since(db_session, sample_issues): + """Test filtering by updated_since.""" + from memory.api.MCP.github import list_github_issues + + now = datetime.now(timezone.utc) + since = (now - timedelta(days=2)).isoformat() + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + results = await list_github_issues(updated_since=since) + + # Only issue #1 was updated within last 2 days + assert len(results) == 1 + assert results[0]["number"] == 1 + + +@pytest.mark.asyncio +async def test_list_github_issues_filter_by_updated_before(db_session, sample_issues): + """Test filtering by updated_before (stale issues).""" + from memory.api.MCP.github import list_github_issues + + now = datetime.now(timezone.utc) + before = (now - timedelta(days=30)).isoformat() + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + results = await list_github_issues(updated_before=before) + + # Only issue #100 hasn't been updated in 30+ days + assert len(results) == 1 + assert results[0]["number"] == 100 + + +@pytest.mark.asyncio +async def test_list_github_issues_order_by_created(db_session, sample_issues): + """Test ordering by created date.""" + from memory.api.MCP.github import list_github_issues + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + results = await list_github_issues(order_by="created") + + # Should be ordered by created_at desc + assert results[0]["number"] == 2 # Most recently created + + +@pytest.mark.asyncio +async def test_list_github_issues_limit(db_session, sample_issues): + """Test limiting results.""" + from memory.api.MCP.github import list_github_issues + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + results = await list_github_issues(limit=2) + + assert len(results) == 2 + + +@pytest.mark.asyncio +async def test_list_github_issues_limit_max_enforced(db_session, sample_issues): + """Test that limit is capped at 200.""" + from memory.api.MCP.github import list_github_issues + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + # Request 500 but should be capped at 200 + results = await list_github_issues(limit=500) + + # We only have 5 issues, but the limit should be internally capped at 200 + assert len(results) <= 200 + + +@pytest.mark.asyncio +async def test_list_github_issues_combined_filters(db_session, sample_issues): + """Test combining multiple filters.""" + from memory.api.MCP.github import list_github_issues + + with 
patch("memory.api.MCP.github.make_session", return_value=db_session): + results = await list_github_issues( + repo="owner/repo1", + state="open", + kind="issue", + project_field={"EquiStamp.Client": "Redwood"}, + ) + + assert len(results) == 1 + assert results[0]["number"] == 1 + assert results[0]["title"] == "Fix authentication bug" + + +@pytest.mark.asyncio +async def test_list_github_issues_url_construction(db_session, sample_issues): + """Test that URLs are correctly constructed.""" + from memory.api.MCP.github import list_github_issues + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + results = await list_github_issues(kind="issue", limit=1) + + assert "url" in results[0] + assert "github.com" in results[0]["url"] + assert "/issues/" in results[0]["url"] + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + pr_results = await list_github_issues(kind="pr") + + assert "/pull/" in pr_results[0]["url"] + + +# ============================================================================= +# Tests for github_issue_details +# ============================================================================= + + +@pytest.mark.asyncio +async def test_github_issue_details_found(db_session, sample_issues): + """Test getting details for an existing issue.""" + from memory.api.MCP.github import github_issue_details + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_issue_details(repo="owner/repo1", number=1) + + assert result["number"] == 1 + assert result["title"] == "Fix authentication bug" + assert "content" in result + assert "authentication" in result["content"] + assert result["project_fields"]["EquiStamp.Client"] == "Redwood" + + +@pytest.mark.asyncio +async def test_github_issue_details_not_found(db_session, sample_issues): + """Test getting details for a non-existent issue.""" + from memory.api.MCP.github import github_issue_details + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + with pytest.raises(ValueError, match="not found"): + await github_issue_details(repo="owner/repo1", number=999) + + +@pytest.mark.asyncio +async def test_github_issue_details_pr(db_session, sample_issues): + """Test getting details for a PR.""" + from memory.api.MCP.github import github_issue_details + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_issue_details(repo="owner/repo1", number=50) + + assert result["kind"] == "pr" + assert result["state"] == "merged" + assert result["merged_at"] is not None + + +# ============================================================================= +# Tests for github_work_summary +# ============================================================================= + + +@pytest.mark.asyncio +async def test_github_work_summary_by_client(db_session, sample_issues): + """Test work summary grouped by client.""" + from memory.api.MCP.github import github_work_summary + + now = datetime.now(timezone.utc) + since = (now - timedelta(days=30)).isoformat() + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_work_summary(since=since, group_by="client") + + assert "period" in result + assert "summary" in result + assert result["group_by"] == "client" + + # Check Redwood group + redwood = next((g for g in result["summary"] if g["group"] == "Redwood"), None) + assert redwood is not None + assert redwood["total"] >= 1 + + +@pytest.mark.asyncio +async def 
test_github_work_summary_by_status(db_session, sample_issues): + """Test work summary grouped by status.""" + from memory.api.MCP.github import github_work_summary + + now = datetime.now(timezone.utc) + since = (now - timedelta(days=30)).isoformat() + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_work_summary(since=since, group_by="status") + + assert result["group_by"] == "status" + # Check that we have status groups + statuses = [g["group"] for g in result["summary"]] + assert any(s in statuses for s in ["In Progress", "Backlog", "Closed", "(unset)"]) + + +@pytest.mark.asyncio +async def test_github_work_summary_by_author(db_session, sample_issues): + """Test work summary grouped by author.""" + from memory.api.MCP.github import github_work_summary + + now = datetime.now(timezone.utc) + since = (now - timedelta(days=30)).isoformat() + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_work_summary(since=since, group_by="author") + + assert result["group_by"] == "author" + authors = [g["group"] for g in result["summary"]] + assert "alice" in authors + + +@pytest.mark.asyncio +async def test_github_work_summary_by_repo(db_session, sample_issues): + """Test work summary grouped by repository.""" + from memory.api.MCP.github import github_work_summary + + now = datetime.now(timezone.utc) + since = (now - timedelta(days=30)).isoformat() + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_work_summary(since=since, group_by="repo") + + assert result["group_by"] == "repo" + repos = [g["group"] for g in result["summary"]] + assert "owner/repo1" in repos + + +@pytest.mark.asyncio +async def test_github_work_summary_with_until(db_session, sample_issues): + """Test work summary with until date.""" + from memory.api.MCP.github import github_work_summary + + now = datetime.now(timezone.utc) + since = (now - timedelta(days=30)).isoformat() + until = (now - timedelta(days=5)).isoformat() + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_work_summary(since=since, until=until) + + assert result["period"]["until"] is not None + + +@pytest.mark.asyncio +async def test_github_work_summary_with_repo_filter(db_session, sample_issues): + """Test work summary filtered by repository.""" + from memory.api.MCP.github import github_work_summary + + now = datetime.now(timezone.utc) + since = (now - timedelta(days=30)).isoformat() + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_work_summary( + since=since, group_by="client", repo="owner/repo1" + ) + + # Should only include items from repo1 + total = sum(g["total"] for g in result["summary"]) + assert total <= 4 # repo1 has 4 items + + +@pytest.mark.asyncio +async def test_github_work_summary_invalid_group_by(db_session, sample_issues): + """Test work summary with invalid group_by value.""" + from memory.api.MCP.github import github_work_summary + + now = datetime.now(timezone.utc) + since = (now - timedelta(days=30)).isoformat() + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + with pytest.raises(ValueError, match="Invalid group_by"): + await github_work_summary(since=since, group_by="invalid") + + +@pytest.mark.asyncio +async def test_github_work_summary_includes_sample_issues(db_session, sample_issues): + """Test that work summary includes sample issues.""" + from 
memory.api.MCP.github import github_work_summary + + now = datetime.now(timezone.utc) + since = (now - timedelta(days=30)).isoformat() + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_work_summary(since=since, group_by="client") + + for group in result["summary"]: + assert "issues" in group + if group["total"] > 0: + assert len(group["issues"]) <= 5 # Limited to 5 samples + for issue in group["issues"]: + assert "number" in issue + assert "title" in issue + assert "url" in issue + + +# ============================================================================= +# Tests for github_repo_overview +# ============================================================================= + + +@pytest.mark.asyncio +async def test_github_repo_overview_basic(db_session, sample_issues): + """Test basic repo overview.""" + from memory.api.MCP.github import github_repo_overview + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_repo_overview(repo="owner/repo1") + + assert result["repo_path"] == "owner/repo1" + assert "counts" in result + assert result["counts"]["total"] == 4 + assert result["counts"]["total_issues"] == 3 + assert result["counts"]["total_prs"] == 1 + + +@pytest.mark.asyncio +async def test_github_repo_overview_counts(db_session, sample_issues): + """Test repo overview counts are correct.""" + from memory.api.MCP.github import github_repo_overview + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_repo_overview(repo="owner/repo1") + + counts = result["counts"] + assert counts["open_issues"] == 3 # Issues 1, 2, 100 are all open in repo1 + assert counts["merged_prs"] == 1 + + +@pytest.mark.asyncio +async def test_github_repo_overview_status_breakdown(db_session, sample_issues): + """Test repo overview includes status breakdown.""" + from memory.api.MCP.github import github_repo_overview + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_repo_overview(repo="owner/repo1") + + assert "status_breakdown" in result + assert "In Progress" in result["status_breakdown"] + assert "Backlog" in result["status_breakdown"] + + +@pytest.mark.asyncio +async def test_github_repo_overview_top_assignees(db_session, sample_issues): + """Test repo overview includes top assignees.""" + from memory.api.MCP.github import github_repo_overview + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_repo_overview(repo="owner/repo1") + + assert "top_assignees" in result + assert isinstance(result["top_assignees"], list) + + +@pytest.mark.asyncio +async def test_github_repo_overview_labels(db_session, sample_issues): + """Test repo overview includes labels.""" + from memory.api.MCP.github import github_repo_overview + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_repo_overview(repo="owner/repo1") + + assert "labels" in result + assert "bug" in result["labels"] + assert "enhancement" in result["labels"] + + +@pytest.mark.asyncio +async def test_github_repo_overview_last_updated(db_session, sample_issues): + """Test repo overview includes last updated timestamp.""" + from memory.api.MCP.github import github_repo_overview + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_repo_overview(repo="owner/repo1") + + assert "last_updated" in result + assert 
result["last_updated"] is not None + + +@pytest.mark.asyncio +async def test_github_repo_overview_empty_repo(db_session): + """Test repo overview for a repo with no issues.""" + from memory.api.MCP.github import github_repo_overview + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_repo_overview(repo="nonexistent/repo") + + assert result["counts"]["total"] == 0 + + +# ============================================================================= +# Tests for search_github_issues +# ============================================================================= + + +@pytest.mark.asyncio +async def test_search_github_issues_basic(db_session, sample_issues): + """Test basic search functionality.""" + from memory.api.MCP.github import search_github_issues + + mock_search_result = Mock() + mock_search_result.id = sample_issues[0].id + mock_search_result.score = 0.95 + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + with patch("memory.api.MCP.github.search", new_callable=AsyncMock) as mock_search: + mock_search.return_value = [mock_search_result] + results = await search_github_issues(query="authentication bug") + + assert len(results) == 1 + assert "search_score" in results[0] + mock_search.assert_called_once() + + +@pytest.mark.asyncio +async def test_search_github_issues_with_repo_filter(db_session, sample_issues): + """Test search with repository filter.""" + from memory.api.MCP.github import search_github_issues + + mock_search_result = Mock() + mock_search_result.id = sample_issues[0].id + mock_search_result.score = 0.85 + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + with patch("memory.api.MCP.github.search", new_callable=AsyncMock) as mock_search: + mock_search.return_value = [mock_search_result] + results = await search_github_issues( + query="authentication", repo="owner/repo1" + ) + + # Verify search was called with source_ids filter + mock_search.assert_called_once() + call_kwargs = mock_search.call_args[1] + assert "filters" in call_kwargs + + +@pytest.mark.asyncio +async def test_search_github_issues_with_state_filter(db_session, sample_issues): + """Test search with state filter.""" + from memory.api.MCP.github import search_github_issues + + mock_search_result = Mock() + mock_search_result.id = sample_issues[0].id + mock_search_result.score = 0.80 + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + with patch("memory.api.MCP.github.search", new_callable=AsyncMock) as mock_search: + mock_search.return_value = [mock_search_result] + results = await search_github_issues(query="bug", state="open") + + mock_search.assert_called_once() + + +@pytest.mark.asyncio +async def test_search_github_issues_limit(db_session, sample_issues): + """Test search respects limit.""" + from memory.api.MCP.github import search_github_issues + + mock_results = [Mock(id=issue.id, score=0.9 - i * 0.1) for i, issue in enumerate(sample_issues[:3])] + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + with patch("memory.api.MCP.github.search", new_callable=AsyncMock) as mock_search: + mock_search.return_value = mock_results + results = await search_github_issues(query="test", limit=2) + + # The search function should have been called with limit=2 in config + call_kwargs = mock_search.call_args[1] + assert call_kwargs["config"].limit == 2 + + +@pytest.mark.asyncio +async def test_search_github_issues_uses_github_modality(db_session, 
sample_issues): + """Test that search uses github modality.""" + from memory.api.MCP.github import search_github_issues + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + with patch("memory.api.MCP.github.search", new_callable=AsyncMock) as mock_search: + mock_search.return_value = [] + await search_github_issues(query="test") + + call_kwargs = mock_search.call_args[1] + assert call_kwargs["modalities"] == {"github"} + + +# ============================================================================= +# Tests for helper functions +# ============================================================================= + + +def test_build_github_url_issue(): + """Test URL construction for issues.""" + from memory.api.MCP.github import _build_github_url + + url = _build_github_url("owner/repo", 123, "issue") + assert url == "https://github.com/owner/repo/issues/123" + + +def test_build_github_url_pr(): + """Test URL construction for PRs.""" + from memory.api.MCP.github import _build_github_url + + url = _build_github_url("owner/repo", 456, "pr") + assert url == "https://github.com/owner/repo/pull/456" + + +def test_build_github_url_no_number(): + """Test URL construction without number.""" + from memory.api.MCP.github import _build_github_url + + url = _build_github_url("owner/repo", None, "issue") + assert url == "https://github.com/owner/repo" + + +def test_serialize_issue_basic(db_session, sample_issues): + """Test issue serialization.""" + from memory.api.MCP.github import _serialize_issue + + issue = sample_issues[0] + result = _serialize_issue(issue) + + assert result["id"] == issue.id + assert result["number"] == issue.number + assert result["title"] == issue.title + assert result["state"] == issue.state + assert result["author"] == issue.author + assert result["assignees"] == issue.assignees + assert result["labels"] == issue.labels + assert "content" not in result + + +def test_serialize_issue_with_content(db_session, sample_issues): + """Test issue serialization with content.""" + from memory.api.MCP.github import _serialize_issue + + issue = sample_issues[0] + result = _serialize_issue(issue, include_content=True) + + assert "content" in result + assert result["content"] == issue.content + + +# ============================================================================= +# Parametrized tests +# ============================================================================= + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "order_by,expected_first_number", + [ + ("updated", 1), # Most recently updated + ("created", 2), # Most recently created (among repo1) + ("number", 100), # Highest number + ], +) +async def test_list_github_issues_ordering( + db_session, sample_issues, order_by, expected_first_number +): + """Test different ordering options.""" + from memory.api.MCP.github import list_github_issues + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + results = await list_github_issues(order_by=order_by) + + assert results[0]["number"] == expected_first_number + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "group_by", + ["client", "status", "author", "repo", "task_type"], +) +async def test_github_work_summary_all_group_by_options(db_session, sample_issues, group_by): + """Test all valid group_by options.""" + from memory.api.MCP.github import github_work_summary + + now = datetime.now(timezone.utc) + since = (now - timedelta(days=30)).isoformat() + + with patch("memory.api.MCP.github.make_session", 
return_value=db_session): + result = await github_work_summary(since=since, group_by=group_by) + + assert result["group_by"] == group_by + assert "summary" in result + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "labels,expected_count", + [ + (["bug"], 2), + (["enhancement"], 1), + (["bug", "security"], 1), # Only issue 1 has both + (["nonexistent"], 0), + ], +) +async def test_list_github_issues_label_filtering( + db_session, sample_issues, labels, expected_count +): + """Test various label filtering scenarios.""" + from memory.api.MCP.github import list_github_issues + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + results = await list_github_issues(labels=labels) + + # Note: label filtering uses ANY match, so ["bug", "security"] matches + # anything with "bug" OR "security" + assert len(results) >= expected_count + + +# ============================================================================= +# Tests for GithubPRData model and PR-specific functionality +# ============================================================================= + + +def test_github_pr_data_diff_compression(): + """Test that GithubPRData compresses and decompresses diffs correctly.""" + from memory.common.db.models import GithubPRData + + pr_data = GithubPRData() + test_diff = """diff --git a/file.py b/file.py +index 123..456 789 +--- a/file.py ++++ b/file.py +@@ -1,3 +1,4 @@ + def hello(): + print("Hello") ++ print("World") +""" + + # Set diff via property (should compress) + pr_data.diff = test_diff + assert pr_data.diff_compressed is not None + assert len(pr_data.diff_compressed) < len(test_diff.encode("utf-8")) + + # Get diff via property (should decompress) + assert pr_data.diff == test_diff + + +def test_github_pr_data_diff_none(): + """Test GithubPRData handles None diff correctly.""" + from memory.common.db.models import GithubPRData + + pr_data = GithubPRData() + assert pr_data.diff is None + + pr_data.diff = None + assert pr_data.diff_compressed is None + assert pr_data.diff is None + + +@pytest.fixture +def sample_pr_with_data(db_session): + """Create a sample PR with GithubPRData attached.""" + from memory.common.db.models import GithubItem, GithubPRData + + now = datetime.now(timezone.utc) + + pr = GithubItem( + kind="pr", + repo_path="owner/repo1", + number=999, + title="Test PR with data", + content="This PR has full PR data attached.", + state="open", + author="alice", + labels=["feature"], + assignees=["alice"], + milestone=None, + project_status="In Progress", + project_priority=None, + project_fields={"EquiStamp.Client": "Test"}, + comment_count=1, + created_at=now - timedelta(days=1), + github_updated_at=now, + modality="github", + sha256=_make_sha256("pr-999-content"), + ) + + pr_data = GithubPRData( + additions=50, + deletions=10, + changed_files_count=3, + files=[ + {"filename": "src/main.py", "status": "modified", "additions": 30, "deletions": 5, "patch": "@@ -1,3 +1,4 @@"}, + {"filename": "tests/test_main.py", "status": "added", "additions": 20, "deletions": 0, "patch": None}, + {"filename": "README.md", "status": "modified", "additions": 0, "deletions": 5, "patch": "@@ -10,5 +10,0 @@"}, + ], + reviews=[ + {"id": 1, "user": "bob", "state": "approved", "body": "LGTM!", "submitted_at": "2025-12-23T10:00:00Z"}, + ], + review_comments=[ + {"id": 101, "user": "bob", "body": "Nice refactoring here", "path": "src/main.py", "line": 10, "side": "RIGHT", "diff_hunk": "@@ context", "created_at": "2025-12-23T09:00:00Z"}, + ], + ) + pr_data.diff = "diff --git 
a/src/main.py b/src/main.py\n..." + + pr.pr_data = pr_data + db_session.add(pr) + db_session.commit() + db_session.refresh(pr) + + return pr + + +@pytest.mark.asyncio +async def test_github_issue_details_includes_pr_data(db_session, sample_pr_with_data): + """Test that github_issue_details includes PR data for PRs.""" + from memory.api.MCP.github import github_issue_details + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_issue_details(repo="owner/repo1", number=999) + + assert result["kind"] == "pr" + assert "pr_data" in result + assert result["pr_data"]["additions"] == 50 + assert result["pr_data"]["deletions"] == 10 + assert result["pr_data"]["changed_files_count"] == 3 + assert len(result["pr_data"]["files"]) == 3 + assert len(result["pr_data"]["reviews"]) == 1 + assert len(result["pr_data"]["review_comments"]) == 1 + assert result["pr_data"]["diff"] is not None + assert "diff --git" in result["pr_data"]["diff"] + + +@pytest.mark.asyncio +async def test_github_issue_details_no_pr_data_for_issues(db_session, sample_issues): + """Test that github_issue_details does not include pr_data for issues.""" + from memory.api.MCP.github import github_issue_details + + with patch("memory.api.MCP.github.make_session", return_value=db_session): + result = await github_issue_details(repo="owner/repo1", number=1) + + assert result["kind"] == "issue" + assert "pr_data" not in result + + +def test_serialize_issue_includes_pr_data(db_session, sample_pr_with_data): + """Test that _serialize_issue includes pr_data when include_content=True.""" + from memory.api.MCP.github import _serialize_issue + + result = _serialize_issue(sample_pr_with_data, include_content=True) + + assert "pr_data" in result + assert result["pr_data"]["additions"] == 50 + assert result["pr_data"]["reviews"][0]["state"] == "approved" + + +def test_serialize_issue_no_pr_data_without_content(db_session, sample_pr_with_data): + """Test that _serialize_issue excludes pr_data when include_content=False.""" + from memory.api.MCP.github import _serialize_issue + + result = _serialize_issue(sample_pr_with_data, include_content=False) + + assert "pr_data" not in result + assert "content" not in result diff --git a/tests/memory/api/MCP/test_people.py b/tests/memory/api/MCP/test_people.py new file mode 100644 index 0000000..d4ec183 --- /dev/null +++ b/tests/memory/api/MCP/test_people.py @@ -0,0 +1,482 @@ +"""Tests for People MCP tools.""" + +import sys +import pytest +from unittest.mock import AsyncMock, Mock, MagicMock, patch + +# Mock the mcp module and all its submodules before importing anything that uses it +_mock_mcp = MagicMock() +_mock_mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator +sys.modules["mcp"] = _mock_mcp +sys.modules["mcp.server"] = MagicMock() +sys.modules["mcp.server.auth"] = MagicMock() +sys.modules["mcp.server.auth.handlers"] = MagicMock() +sys.modules["mcp.server.auth.handlers.authorize"] = MagicMock() +sys.modules["mcp.server.auth.handlers.token"] = MagicMock() +sys.modules["mcp.server.auth.provider"] = MagicMock() +sys.modules["mcp.server.fastmcp"] = MagicMock() +sys.modules["mcp.server.fastmcp.server"] = MagicMock() + +# Also mock the memory.api.MCP.base module to avoid MCP imports +_mock_base = MagicMock() +_mock_base.mcp = MagicMock() +_mock_base.mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator +sys.modules["memory.api.MCP.base"] = _mock_base + +from memory.common.db.models import Person +from memory.common.db import 
connection as db_connection +from memory.workers.tasks.content_processing import create_content_hash + + +@pytest.fixture(autouse=True) +def reset_db_cache(): + """Reset the cached database engine between tests.""" + db_connection._engine = None + db_connection._session_factory = None + db_connection._scoped_session = None + yield + db_connection._engine = None + db_connection._session_factory = None + db_connection._scoped_session = None + + +@pytest.fixture +def sample_people(db_session): + """Create sample people for testing.""" + people = [ + Person( + identifier="alice_chen", + display_name="Alice Chen", + aliases=["@alice_c", "alice.chen@work.com"], + contact_info={"email": "alice@example.com", "phone": "555-1234"}, + tags=["work", "engineering"], + content="Tech lead on Platform team. Very thorough in code reviews.", + modality="person", + sha256=create_content_hash("person:alice_chen"), + size=100, + ), + Person( + identifier="bob_smith", + display_name="Bob Smith", + aliases=["@bobsmith"], + contact_info={"email": "bob@example.com"}, + tags=["work", "design"], + content="UX designer. Prefers visual communication.", + modality="person", + sha256=create_content_hash("person:bob_smith"), + size=50, + ), + Person( + identifier="charlie_jones", + display_name="Charlie Jones", + aliases=[], + contact_info={"twitter": "@charlie_j"}, + tags=["friend", "climbing"], + content="Met at climbing gym. Very reliable.", + modality="person", + sha256=create_content_hash("person:charlie_jones"), + size=30, + ), + ] + + for person in people: + db_session.add(person) + db_session.commit() + + for person in people: + db_session.refresh(person) + + return people + + +# ============================================================================= +# Tests for add_person +# ============================================================================= + + +@pytest.mark.asyncio +async def test_add_person_success(db_session): + """Test adding a new person.""" + from memory.api.MCP.people import add_person + + mock_task = Mock() + mock_task.id = "task-123" + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + with patch("memory.api.MCP.people.celery_app") as mock_celery: + mock_celery.send_task.return_value = mock_task + result = await add_person( + identifier="new_person", + display_name="New Person", + aliases=["@newperson"], + contact_info={"email": "new@example.com"}, + tags=["friend"], + notes="A new friend.", + ) + + assert result["status"] == "queued" + assert result["task_id"] == "task-123" + assert result["identifier"] == "new_person" + + # Verify Celery task was called + mock_celery.send_task.assert_called_once() + call_kwargs = mock_celery.send_task.call_args[1] + assert call_kwargs["kwargs"]["identifier"] == "new_person" + assert call_kwargs["kwargs"]["display_name"] == "New Person" + + +@pytest.mark.asyncio +async def test_add_person_already_exists(db_session, sample_people): + """Test adding a person that already exists.""" + from memory.api.MCP.people import add_person + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + with pytest.raises(ValueError, match="already exists"): + await add_person( + identifier="alice_chen", # Already exists + display_name="Alice Chen Duplicate", + ) + + +@pytest.mark.asyncio +async def test_add_person_minimal(db_session): + """Test adding a person with minimal data.""" + from memory.api.MCP.people import add_person + + mock_task = Mock() + mock_task.id = "task-456" + + with 
patch("memory.api.MCP.people.make_session", return_value=db_session): + with patch("memory.api.MCP.people.celery_app") as mock_celery: + mock_celery.send_task.return_value = mock_task + result = await add_person( + identifier="minimal_person", + display_name="Minimal Person", + ) + + assert result["status"] == "queued" + assert result["identifier"] == "minimal_person" + + +# ============================================================================= +# Tests for update_person_info +# ============================================================================= + + +@pytest.mark.asyncio +async def test_update_person_info_success(db_session, sample_people): + """Test updating a person's info.""" + from memory.api.MCP.people import update_person_info + + mock_task = Mock() + mock_task.id = "task-789" + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + with patch("memory.api.MCP.people.celery_app") as mock_celery: + mock_celery.send_task.return_value = mock_task + result = await update_person_info( + identifier="alice_chen", + display_name="Alice M. Chen", + notes="Added middle initial", + ) + + assert result["status"] == "queued" + assert result["task_id"] == "task-789" + assert result["identifier"] == "alice_chen" + + +@pytest.mark.asyncio +async def test_update_person_info_not_found(db_session, sample_people): + """Test updating a person that doesn't exist.""" + from memory.api.MCP.people import update_person_info + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + with pytest.raises(ValueError, match="not found"): + await update_person_info( + identifier="nonexistent_person", + display_name="New Name", + ) + + +@pytest.mark.asyncio +async def test_update_person_info_with_merge_params(db_session, sample_people): + """Test that update passes all merge parameters.""" + from memory.api.MCP.people import update_person_info + + mock_task = Mock() + mock_task.id = "task-merge" + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + with patch("memory.api.MCP.people.celery_app") as mock_celery: + mock_celery.send_task.return_value = mock_task + result = await update_person_info( + identifier="alice_chen", + aliases=["@alice_new"], + contact_info={"slack": "@alice"}, + tags=["leadership"], + notes="Promoted to senior", + replace_notes=False, + ) + + call_kwargs = mock_celery.send_task.call_args[1]["kwargs"] + assert call_kwargs["aliases"] == ["@alice_new"] + assert call_kwargs["contact_info"] == {"slack": "@alice"} + assert call_kwargs["tags"] == ["leadership"] + assert call_kwargs["notes"] == "Promoted to senior" + assert call_kwargs["replace_notes"] is False + + +# ============================================================================= +# Tests for get_person +# ============================================================================= + + +@pytest.mark.asyncio +async def test_get_person_found(db_session, sample_people): + """Test getting a person that exists.""" + from memory.api.MCP.people import get_person + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + result = await get_person(identifier="alice_chen") + + assert result is not None + assert result["identifier"] == "alice_chen" + assert result["display_name"] == "Alice Chen" + assert result["aliases"] == ["@alice_c", "alice.chen@work.com"] + assert result["contact_info"] == {"email": "alice@example.com", "phone": "555-1234"} + assert result["tags"] == ["work", "engineering"] + assert "Tech lead" in result["notes"] + + 
+@pytest.mark.asyncio +async def test_get_person_not_found(db_session, sample_people): + """Test getting a person that doesn't exist.""" + from memory.api.MCP.people import get_person + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + result = await get_person(identifier="nonexistent_person") + + assert result is None + + +# ============================================================================= +# Tests for list_people +# ============================================================================= + + +@pytest.mark.asyncio +async def test_list_people_no_filters(db_session, sample_people): + """Test listing all people without filters.""" + from memory.api.MCP.people import list_people + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + results = await list_people() + + assert len(results) == 3 + # Should be ordered by display_name + assert results[0]["display_name"] == "Alice Chen" + assert results[1]["display_name"] == "Bob Smith" + assert results[2]["display_name"] == "Charlie Jones" + + +@pytest.mark.asyncio +async def test_list_people_filter_by_tags(db_session, sample_people): + """Test filtering by tags.""" + from memory.api.MCP.people import list_people + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + results = await list_people(tags=["work"]) + + assert len(results) == 2 + assert all("work" in r["tags"] for r in results) + + +@pytest.mark.asyncio +async def test_list_people_filter_by_search(db_session, sample_people): + """Test filtering by search term.""" + from memory.api.MCP.people import list_people + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + results = await list_people(search="alice") + + assert len(results) == 1 + assert results[0]["identifier"] == "alice_chen" + + +@pytest.mark.asyncio +async def test_list_people_search_in_notes(db_session, sample_people): + """Test that search works on notes content.""" + from memory.api.MCP.people import list_people + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + results = await list_people(search="climbing") + + assert len(results) == 1 + assert results[0]["identifier"] == "charlie_jones" + + +@pytest.mark.asyncio +async def test_list_people_limit(db_session, sample_people): + """Test limiting results.""" + from memory.api.MCP.people import list_people + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + results = await list_people(limit=1) + + assert len(results) == 1 + + +@pytest.mark.asyncio +async def test_list_people_limit_max_enforced(db_session, sample_people): + """Test that limit is capped at 200.""" + from memory.api.MCP.people import list_people + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + # Request 500 but should be capped at 200 + results = await list_people(limit=500) + + # We only have 3 people, but the limit logic should cap at 200 + assert len(results) <= 200 + + +@pytest.mark.asyncio +async def test_list_people_combined_filters(db_session, sample_people): + """Test combining tag and search filters.""" + from memory.api.MCP.people import list_people + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + results = await list_people(tags=["work"], search="chen") + + assert len(results) == 1 + assert results[0]["identifier"] == "alice_chen" + + +# ============================================================================= +# Tests for delete_person +# 
============================================================================= + + +@pytest.mark.asyncio +async def test_delete_person_success(db_session, sample_people): + """Test deleting a person.""" + from memory.api.MCP.people import delete_person + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + result = await delete_person(identifier="bob_smith") + + assert result["deleted"] is True + assert result["identifier"] == "bob_smith" + assert result["display_name"] == "Bob Smith" + + # Verify person was deleted + remaining = db_session.query(Person).filter_by(identifier="bob_smith").first() + assert remaining is None + + +@pytest.mark.asyncio +async def test_delete_person_not_found(db_session, sample_people): + """Test deleting a person that doesn't exist.""" + from memory.api.MCP.people import delete_person + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + with pytest.raises(ValueError, match="not found"): + await delete_person(identifier="nonexistent_person") + + +# ============================================================================= +# Tests for _person_to_dict helper +# ============================================================================= + + +def test_person_to_dict(sample_people): + """Test the _person_to_dict helper function.""" + from memory.api.MCP.people import _person_to_dict + + person = sample_people[0] + result = _person_to_dict(person) + + assert result["identifier"] == "alice_chen" + assert result["display_name"] == "Alice Chen" + assert result["aliases"] == ["@alice_c", "alice.chen@work.com"] + assert result["contact_info"] == {"email": "alice@example.com", "phone": "555-1234"} + assert result["tags"] == ["work", "engineering"] + assert result["notes"] == "Tech lead on Platform team. Very thorough in code reviews." 
+ assert result["created_at"] is not None + + +def test_person_to_dict_empty_fields(db_session): + """Test _person_to_dict with empty optional fields.""" + from memory.api.MCP.people import _person_to_dict + + person = Person( + identifier="empty_person", + display_name="Empty Person", + aliases=[], + contact_info={}, + tags=[], + content=None, + modality="person", + sha256=create_content_hash("person:empty_person"), + size=0, + ) + + result = _person_to_dict(person) + + assert result["identifier"] == "empty_person" + assert result["aliases"] == [] + assert result["contact_info"] == {} + assert result["tags"] == [] + assert result["notes"] is None + + +# ============================================================================= +# Parametrized tests +# ============================================================================= + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "tag,expected_count", + [ + ("work", 2), + ("engineering", 1), + ("design", 1), + ("friend", 1), + ("climbing", 1), + ("nonexistent", 0), + ], +) +async def test_list_people_various_tags(db_session, sample_people, tag, expected_count): + """Test filtering by various tags.""" + from memory.api.MCP.people import list_people + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + results = await list_people(tags=[tag]) + + assert len(results) == expected_count + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "search_term,expected_identifiers", + [ + ("alice", ["alice_chen"]), + ("bob", ["bob_smith"]), + ("smith", ["bob_smith"]), + ("chen", ["alice_chen"]), + ("jones", ["charlie_jones"]), + ("example.com", []), # Not searching in contact_info + ("UX", ["bob_smith"]), # Case insensitive search in notes + ], +) +async def test_list_people_various_searches( + db_session, sample_people, search_term, expected_identifiers +): + """Test search with various terms.""" + from memory.api.MCP.people import list_people + + with patch("memory.api.MCP.people.make_session", return_value=db_session): + results = await list_people(search=search_term) + + result_identifiers = [r["identifier"] for r in results] + assert result_identifiers == expected_identifiers diff --git a/tests/memory/api/__init__.py b/tests/memory/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/memory/common/db/models/test_people.py b/tests/memory/common/db/models/test_people.py new file mode 100644 index 0000000..963c11b --- /dev/null +++ b/tests/memory/common/db/models/test_people.py @@ -0,0 +1,249 @@ +"""Tests for the Person model.""" + +import pytest + +from memory.common.db.models import Person +from memory.workers.tasks.content_processing import create_content_hash + + +@pytest.fixture +def person_data(): + """Standard person test data.""" + return { + "identifier": "alice_chen", + "display_name": "Alice Chen", + "aliases": ["@alice_c", "alice.chen@work.com"], + "contact_info": {"email": "alice@example.com", "phone": "555-1234"}, + "tags": ["work", "engineering"], + "content": "Tech lead on Platform team. 
Very thorough in code reviews.", + "modality": "person", + "mime_type": "text/plain", + } + + +@pytest.fixture +def minimal_person_data(): + """Minimal person test data.""" + return { + "identifier": "bob_smith", + "display_name": "Bob Smith", + "modality": "person", + } + + +def test_person_creation(person_data): + """Test creating a Person with all fields.""" + sha256 = create_content_hash(f"person:{person_data['identifier']}") + person = Person(**person_data, sha256=sha256, size=100) + + assert person.identifier == "alice_chen" + assert person.display_name == "Alice Chen" + assert person.aliases == ["@alice_c", "alice.chen@work.com"] + assert person.contact_info == {"email": "alice@example.com", "phone": "555-1234"} + assert person.tags == ["work", "engineering"] + assert person.content == "Tech lead on Platform team. Very thorough in code reviews." + assert person.modality == "person" + + +def test_person_creation_minimal(minimal_person_data): + """Test creating a Person with minimal fields.""" + sha256 = create_content_hash(f"person:{minimal_person_data['identifier']}") + person = Person(**minimal_person_data, sha256=sha256, size=0) + + assert person.identifier == "bob_smith" + assert person.display_name == "Bob Smith" + assert person.aliases == [] or person.aliases is None + assert person.contact_info == {} or person.contact_info is None + assert person.tags == [] or person.tags is None + assert person.content is None + + +def test_person_as_payload(person_data): + """Test the as_payload method.""" + sha256 = create_content_hash(f"person:{person_data['identifier']}") + person = Person(**person_data, sha256=sha256, size=100) + + payload = person.as_payload() + + assert payload["identifier"] == "alice_chen" + assert payload["display_name"] == "Alice Chen" + assert payload["aliases"] == ["@alice_c", "alice.chen@work.com"] + assert payload["contact_info"] == {"email": "alice@example.com", "phone": "555-1234"} + + +def test_person_display_contents(person_data): + """Test the display_contents property.""" + sha256 = create_content_hash(f"person:{person_data['identifier']}") + person = Person(**person_data, sha256=sha256, size=100) + + contents = person.display_contents + + assert contents["identifier"] == "alice_chen" + assert contents["display_name"] == "Alice Chen" + assert contents["aliases"] == ["@alice_c", "alice.chen@work.com"] + assert contents["contact_info"] == {"email": "alice@example.com", "phone": "555-1234"} + assert contents["notes"] == "Tech lead on Platform team. Very thorough in code reviews." 
+ assert contents["tags"] == ["work", "engineering"] + + +def test_person_chunk_contents(person_data): + """Test the _chunk_contents method generates searchable chunks.""" + sha256 = create_content_hash(f"person:{person_data['identifier']}") + person = Person(**person_data, sha256=sha256, size=100) + + chunks = person._chunk_contents() + + assert len(chunks) > 0 + chunk_text = chunks[0].data[0] + + # Should include display name + assert "Alice Chen" in chunk_text + # Should include aliases + assert "@alice_c" in chunk_text + # Should include tags + assert "work" in chunk_text + # Should include notes/content + assert "Tech lead" in chunk_text + + +def test_person_chunk_contents_minimal(minimal_person_data): + """Test _chunk_contents with minimal data.""" + sha256 = create_content_hash(f"person:{minimal_person_data['identifier']}") + person = Person(**minimal_person_data, sha256=sha256, size=0) + + chunks = person._chunk_contents() + + assert len(chunks) > 0 + chunk_text = chunks[0].data[0] + assert "Bob Smith" in chunk_text + + +def test_person_get_collections(): + """Test that Person returns correct collections.""" + collections = Person.get_collections() + + assert collections == ["person"] + + +def test_person_polymorphic_identity(): + """Test that Person has correct polymorphic identity.""" + assert Person.__mapper_args__["polymorphic_identity"] == "person" + + +@pytest.mark.parametrize( + "identifier,display_name,aliases,tags", + [ + ("john_doe", "John Doe", [], []), + ("jane_smith", "Jane Smith", ["@jane"], ["friend"]), + ("bob_jones", "Bob Jones", ["@bob", "bobby"], ["work", "climbing", "london"]), + ( + "alice_wong", + "Alice Wong", + ["@alice", "alice@work.com", "Alice W."], + ["family", "close"], + ), + ], +) +def test_person_various_configurations(identifier, display_name, aliases, tags): + """Test Person creation with various configurations.""" + sha256 = create_content_hash(f"person:{identifier}") + person = Person( + identifier=identifier, + display_name=display_name, + aliases=aliases, + tags=tags, + modality="person", + sha256=sha256, + size=0, + ) + + assert person.identifier == identifier + assert person.display_name == display_name + assert person.aliases == aliases + assert person.tags == tags + + +def test_person_contact_info_flexible(): + """Test that contact_info can hold various structures.""" + contact_info = { + "email": "test@example.com", + "phone": "+1-555-1234", + "twitter": "@testuser", + "linkedin": "linkedin.com/in/testuser", + "address": { + "street": "123 Main St", + "city": "San Francisco", + "country": "USA", + }, + } + + sha256 = create_content_hash("person:test_user") + person = Person( + identifier="test_user", + display_name="Test User", + contact_info=contact_info, + modality="person", + sha256=sha256, + size=0, + ) + + assert person.contact_info == contact_info + assert person.contact_info["address"]["city"] == "San Francisco" + + +def test_person_in_db(db_session, qdrant): + """Test Person persistence in database.""" + sha256 = create_content_hash("person:db_test_user") + person = Person( + identifier="db_test_user", + display_name="DB Test User", + aliases=["@dbtest"], + contact_info={"email": "dbtest@example.com"}, + tags=["test"], + content="Test notes", + modality="person", + mime_type="text/plain", + sha256=sha256, + size=10, + ) + + db_session.add(person) + db_session.commit() + + # Query it back + retrieved = db_session.query(Person).filter_by(identifier="db_test_user").first() + + assert retrieved is not None + assert 
retrieved.display_name == "DB Test User" + assert retrieved.aliases == ["@dbtest"] + assert retrieved.contact_info == {"email": "dbtest@example.com"} + assert retrieved.tags == ["test"] + assert retrieved.content == "Test notes" + + +def test_person_unique_identifier(db_session, qdrant): + """Test that identifier must be unique.""" + sha256 = create_content_hash("person:unique_test") + + person1 = Person( + identifier="unique_test", + display_name="Person 1", + modality="person", + sha256=sha256, + size=0, + ) + db_session.add(person1) + db_session.commit() + + # Try to add another with same identifier + person2 = Person( + identifier="unique_test", + display_name="Person 2", + modality="person", + sha256=create_content_hash("person:unique_test_2"), + size=0, + ) + db_session.add(person2) + + with pytest.raises(Exception): # Should raise IntegrityError + db_session.commit() diff --git a/tests/memory/parsers/test_github.py b/tests/memory/parsers/test_github.py index bc11eeb..ce37c67 100644 --- a/tests/memory/parsers/test_github.py +++ b/tests/memory/parsers/test_github.py @@ -333,7 +333,13 @@ def test_fetch_prs_basic(): page = kwargs.get("params", {}).get("page", 1) - if "/pulls" in url and "/comments" not in url: + # PR list endpoint + if "/pulls" in url and "/comments" not in url and "/reviews" not in url and "/files" not in url: + # Check if this is the diff request + if kwargs.get("headers", {}).get("Accept") == "application/vnd.github.diff": + response.ok = True + response.text = "+100 lines added\n-50 lines removed" + return response if page == 1: response.json.return_value = [ { @@ -355,10 +361,22 @@ def test_fetch_prs_basic(): ] else: response.json.return_value = [] - elif ".diff" in url: - response.ok = True - response.text = "+100 lines added\n-50 lines removed" - elif "/comments" in url: + elif "/pulls/" in url and "/comments" in url: + # Review comments endpoint + response.json.return_value = [] + elif "/pulls/" in url and "/reviews" in url: + # Reviews endpoint + response.json.return_value = [] + elif "/pulls/" in url and "/files" in url: + # Files endpoint + if page == 1: + response.json.return_value = [ + {"filename": "test.py", "status": "added", "additions": 100, "deletions": 50, "patch": "+code"} + ] + else: + response.json.return_value = [] + elif "/issues/" in url and "/comments" in url: + # Regular comments endpoint response.json.return_value = [] else: response.json.return_value = [] @@ -376,43 +394,66 @@ def test_fetch_prs_basic(): assert pr["kind"] == "pr" assert pr["diff_summary"] is not None assert "100 lines added" in pr["diff_summary"] + # Verify pr_data is populated + assert pr["pr_data"] is not None + assert pr["pr_data"]["additions"] == 100 + assert pr["pr_data"]["deletions"] == 50 def test_fetch_prs_merged(): """Test fetching merged PR.""" credentials = GithubCredentials(auth_type="pat", access_token="token") - mock_response = Mock() - mock_response.json.return_value = [ - { - "number": 20, - "title": "Merged PR", - "body": "Body", - "state": "closed", - "user": {"login": "user"}, - "labels": [], - "assignees": [], - "milestone": None, - "created_at": "2024-01-01T00:00:00Z", - "updated_at": "2024-01-10T00:00:00Z", - "closed_at": "2024-01-10T00:00:00Z", - "merged_at": "2024-01-10T00:00:00Z", - "additions": 10, - "deletions": 5, - "comments": 0, - } - ] - mock_response.headers = {"X-RateLimit-Remaining": "4999"} - mock_response.raise_for_status = Mock() + def mock_get(url, **kwargs): + """Route mock responses based on URL.""" + response = Mock() + 
response.headers = {"X-RateLimit-Remaining": "4999"} + response.raise_for_status = Mock() - mock_empty = Mock() - mock_empty.json.return_value = [] - mock_empty.headers = {"X-RateLimit-Remaining": "4998"} - mock_empty.raise_for_status = Mock() + page = kwargs.get("params", {}).get("page", 1) - with patch.object(requests.Session, "get") as mock_get: - mock_get.side_effect = [mock_response, mock_empty, mock_empty] + # PR list endpoint + if "/pulls" in url and "/comments" not in url and "/reviews" not in url and "/files" not in url: + if kwargs.get("headers", {}).get("Accept") == "application/vnd.github.diff": + response.ok = True + response.text = "" + return response + if page == 1: + response.json.return_value = [ + { + "number": 20, + "title": "Merged PR", + "body": "Body", + "state": "closed", + "user": {"login": "user"}, + "labels": [], + "assignees": [], + "milestone": None, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-10T00:00:00Z", + "closed_at": "2024-01-10T00:00:00Z", + "merged_at": "2024-01-10T00:00:00Z", + "additions": 10, + "deletions": 5, + "comments": 0, + } + ] + else: + response.json.return_value = [] + elif "/pulls/" in url and "/comments" in url: + response.json.return_value = [] + elif "/pulls/" in url and "/reviews" in url: + response.json.return_value = [] + elif "/pulls/" in url and "/files" in url: + response.json.return_value = [] + elif "/issues/" in url and "/comments" in url: + response.json.return_value = [] + else: + response.json.return_value = [] + return response + + with patch.object(requests.Session, "get", side_effect=mock_get): client = GithubClient(credentials) prs = list(client.fetch_prs("owner", "repo")) @@ -425,54 +466,73 @@ def test_fetch_prs_stops_at_since(): """Test that PR fetching stops when reaching older items.""" credentials = GithubCredentials(auth_type="pat", access_token="token") - mock_response = Mock() - mock_response.json.return_value = [ - { - "number": 30, - "title": "Recent PR", - "body": "Body", - "state": "open", - "user": {"login": "user"}, - "labels": [], - "assignees": [], - "milestone": None, - "created_at": "2024-01-20T00:00:00Z", - "updated_at": "2024-01-20T00:00:00Z", - "closed_at": None, - "merged_at": None, - "additions": 1, - "deletions": 1, - "comments": 0, - }, - { - "number": 29, - "title": "Old PR", - "body": "Body", - "state": "open", - "user": {"login": "user"}, - "labels": [], - "assignees": [], - "milestone": None, - "created_at": "2024-01-01T00:00:00Z", - "updated_at": "2024-01-01T00:00:00Z", # Older than since - "closed_at": None, - "merged_at": None, - "additions": 1, - "deletions": 1, - "comments": 0, - }, - ] - mock_response.headers = {"X-RateLimit-Remaining": "4999"} - mock_response.raise_for_status = Mock() + def mock_get(url, **kwargs): + """Route mock responses based on URL.""" + response = Mock() + response.headers = {"X-RateLimit-Remaining": "4999"} + response.raise_for_status = Mock() - mock_empty = Mock() - mock_empty.json.return_value = [] - mock_empty.headers = {"X-RateLimit-Remaining": "4998"} - mock_empty.raise_for_status = Mock() + page = kwargs.get("params", {}).get("page", 1) - with patch.object(requests.Session, "get") as mock_get: - mock_get.side_effect = [mock_response, mock_empty] + # PR list endpoint + if "/pulls" in url and "/comments" not in url and "/reviews" not in url and "/files" not in url: + if kwargs.get("headers", {}).get("Accept") == "application/vnd.github.diff": + response.ok = True + response.text = "" + return response + if page == 1: + 
response.json.return_value = [ + { + "number": 30, + "title": "Recent PR", + "body": "Body", + "state": "open", + "user": {"login": "user"}, + "labels": [], + "assignees": [], + "milestone": None, + "created_at": "2024-01-20T00:00:00Z", + "updated_at": "2024-01-20T00:00:00Z", + "closed_at": None, + "merged_at": None, + "additions": 1, + "deletions": 1, + "comments": 0, + }, + { + "number": 29, + "title": "Old PR", + "body": "Body", + "state": "open", + "user": {"login": "user"}, + "labels": [], + "assignees": [], + "milestone": None, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", # Older than since + "closed_at": None, + "merged_at": None, + "additions": 1, + "deletions": 1, + "comments": 0, + }, + ] + else: + response.json.return_value = [] + elif "/pulls/" in url and "/comments" in url: + response.json.return_value = [] + elif "/pulls/" in url and "/reviews" in url: + response.json.return_value = [] + elif "/pulls/" in url and "/files" in url: + response.json.return_value = [] + elif "/issues/" in url and "/comments" in url: + response.json.return_value = [] + else: + response.json.return_value = [] + return response + + with patch.object(requests.Session, "get", side_effect=mock_get): client = GithubClient(credentials) since = datetime(2024, 1, 15, tzinfo=timezone.utc) prs = list(client.fetch_prs("owner", "repo", since=since)) @@ -703,3 +763,552 @@ def test_fetch_issues_handles_api_error(): with pytest.raises(requests.HTTPError): list(client.fetch_issues("owner", "nonexistent")) + + +# ============================================================================= +# Tests for fetch_review_comments +# ============================================================================= + + +def test_fetch_review_comments_basic(): + """Test fetching PR review comments.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.json.return_value = [ + { + "id": 1001, + "user": {"login": "reviewer1"}, + "body": "This needs a test", + "path": "src/main.py", + "line": 42, + "side": "RIGHT", + "diff_hunk": "@@ -40,3 +40,5 @@", + "created_at": "2024-01-01T12:00:00Z", + }, + { + "id": 1002, + "user": {"login": "reviewer2"}, + "body": "Good refactoring", + "path": "src/utils.py", + "line": 10, + "side": "LEFT", + "diff_hunk": "@@ -8,5 +8,5 @@", + "created_at": "2024-01-02T10:00:00Z", + }, + ] + mock_response.headers = {"X-RateLimit-Remaining": "4999"} + mock_response.raise_for_status = Mock() + + mock_empty = Mock() + mock_empty.json.return_value = [] + mock_empty.headers = {"X-RateLimit-Remaining": "4998"} + mock_empty.raise_for_status = Mock() + + with patch.object(requests.Session, "get") as mock_get: + mock_get.side_effect = [mock_response, mock_empty] + + client = GithubClient(credentials) + comments = client.fetch_review_comments("owner", "repo", 10) + + assert len(comments) == 2 + assert comments[0]["user"] == "reviewer1" + assert comments[0]["body"] == "This needs a test" + assert comments[0]["path"] == "src/main.py" + assert comments[0]["line"] == 42 + assert comments[0]["side"] == "RIGHT" + assert comments[0]["diff_hunk"] == "@@ -40,3 +40,5 @@" + assert comments[1]["user"] == "reviewer2" + + +def test_fetch_review_comments_ghost_user(): + """Test review comments with deleted user.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.json.return_value = [ + { + "id": 1001, + "user": None, # Deleted user + "body": "Legacy comment", + "path": 
"file.py", + "line": None, # Line might be None for outdated comments + "side": "RIGHT", + "diff_hunk": "", + "created_at": "2024-01-01T00:00:00Z", + } + ] + mock_response.headers = {"X-RateLimit-Remaining": "4999"} + mock_response.raise_for_status = Mock() + + mock_empty = Mock() + mock_empty.json.return_value = [] + mock_empty.headers = {"X-RateLimit-Remaining": "4998"} + mock_empty.raise_for_status = Mock() + + with patch.object(requests.Session, "get") as mock_get: + mock_get.side_effect = [mock_response, mock_empty] + + client = GithubClient(credentials) + comments = client.fetch_review_comments("owner", "repo", 10) + + assert len(comments) == 1 + assert comments[0]["user"] == "ghost" + assert comments[0]["line"] is None + + +def test_fetch_review_comments_pagination(): + """Test review comment fetching with pagination.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_page1 = Mock() + mock_page1.json.return_value = [ + { + "id": i, + "user": {"login": f"user{i}"}, + "body": f"Comment {i}", + "path": "file.py", + "line": i, + "side": "RIGHT", + "diff_hunk": "", + "created_at": "2024-01-01T00:00:00Z", + } + for i in range(100) + ] + mock_page1.headers = {"X-RateLimit-Remaining": "4999"} + mock_page1.raise_for_status = Mock() + + mock_page2 = Mock() + mock_page2.json.return_value = [ + { + "id": 100, + "user": {"login": "user100"}, + "body": "Final comment", + "path": "file.py", + "line": 100, + "side": "RIGHT", + "diff_hunk": "", + "created_at": "2024-01-01T00:00:00Z", + } + ] + mock_page2.headers = {"X-RateLimit-Remaining": "4998"} + mock_page2.raise_for_status = Mock() + + mock_empty = Mock() + mock_empty.json.return_value = [] + mock_empty.headers = {"X-RateLimit-Remaining": "4997"} + mock_empty.raise_for_status = Mock() + + with patch.object(requests.Session, "get") as mock_get: + mock_get.side_effect = [mock_page1, mock_page2, mock_empty] + + client = GithubClient(credentials) + comments = client.fetch_review_comments("owner", "repo", 10) + + assert len(comments) == 101 + + +# ============================================================================= +# Tests for fetch_reviews +# ============================================================================= + + +def test_fetch_reviews_basic(): + """Test fetching PR reviews.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.json.return_value = [ + { + "id": 2001, + "user": {"login": "lead_dev"}, + "state": "APPROVED", + "body": "LGTM!", + "submitted_at": "2024-01-05T15:00:00Z", + }, + { + "id": 2002, + "user": {"login": "qa_engineer"}, + "state": "CHANGES_REQUESTED", + "body": "Please add tests", + "submitted_at": "2024-01-04T10:00:00Z", + }, + { + "id": 2003, + "user": {"login": "observer"}, + "state": "COMMENTED", + "body": None, # Some reviews have no body + "submitted_at": "2024-01-03T08:00:00Z", + }, + ] + mock_response.headers = {"X-RateLimit-Remaining": "4999"} + mock_response.raise_for_status = Mock() + + mock_empty = Mock() + mock_empty.json.return_value = [] + mock_empty.headers = {"X-RateLimit-Remaining": "4998"} + mock_empty.raise_for_status = Mock() + + with patch.object(requests.Session, "get") as mock_get: + mock_get.side_effect = [mock_response, mock_empty] + + client = GithubClient(credentials) + reviews = client.fetch_reviews("owner", "repo", 10) + + assert len(reviews) == 3 + assert reviews[0]["user"] == "lead_dev" + assert reviews[0]["state"] == "approved" # Lowercased + assert reviews[0]["body"] == "LGTM!" 
+ assert reviews[1]["state"] == "changes_requested" + assert reviews[2]["body"] is None + + +def test_fetch_reviews_ghost_user(): + """Test reviews with deleted user.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.json.return_value = [ + { + "id": 2001, + "user": None, + "state": "APPROVED", + "body": "Approved by former employee", + "submitted_at": "2024-01-01T00:00:00Z", + } + ] + mock_response.headers = {"X-RateLimit-Remaining": "4999"} + mock_response.raise_for_status = Mock() + + mock_empty = Mock() + mock_empty.json.return_value = [] + mock_empty.headers = {"X-RateLimit-Remaining": "4998"} + mock_empty.raise_for_status = Mock() + + with patch.object(requests.Session, "get") as mock_get: + mock_get.side_effect = [mock_response, mock_empty] + + client = GithubClient(credentials) + reviews = client.fetch_reviews("owner", "repo", 10) + + assert len(reviews) == 1 + assert reviews[0]["user"] == "ghost" + + +# ============================================================================= +# Tests for fetch_pr_files +# ============================================================================= + + +def test_fetch_pr_files_basic(): + """Test fetching PR file changes.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.json.return_value = [ + { + "filename": "src/main.py", + "status": "modified", + "additions": 10, + "deletions": 5, + "patch": "@@ -1,5 +1,10 @@\n+new code\n-old code", + }, + { + "filename": "src/new_feature.py", + "status": "added", + "additions": 100, + "deletions": 0, + "patch": "@@ -0,0 +1,100 @@\n+entire new file", + }, + { + "filename": "old_file.py", + "status": "removed", + "additions": 0, + "deletions": 50, + "patch": "@@ -1,50 +0,0 @@\n-entire old file", + }, + { + "filename": "image.png", + "status": "added", + "additions": 0, + "deletions": 0, + # No patch for binary files + }, + ] + mock_response.headers = {"X-RateLimit-Remaining": "4999"} + mock_response.raise_for_status = Mock() + + mock_empty = Mock() + mock_empty.json.return_value = [] + mock_empty.headers = {"X-RateLimit-Remaining": "4998"} + mock_empty.raise_for_status = Mock() + + with patch.object(requests.Session, "get") as mock_get: + mock_get.side_effect = [mock_response, mock_empty] + + client = GithubClient(credentials) + files = client.fetch_pr_files("owner", "repo", 10) + + assert len(files) == 4 + assert files[0]["filename"] == "src/main.py" + assert files[0]["status"] == "modified" + assert files[0]["additions"] == 10 + assert files[0]["deletions"] == 5 + assert files[0]["patch"] is not None + + assert files[1]["status"] == "added" + assert files[2]["status"] == "removed" + assert files[3]["patch"] is None # Binary file + + +def test_fetch_pr_files_renamed(): + """Test PR with renamed files.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.json.return_value = [ + { + "filename": "new_name.py", + "status": "renamed", + "additions": 0, + "deletions": 0, + "patch": None, + } + ] + mock_response.headers = {"X-RateLimit-Remaining": "4999"} + mock_response.raise_for_status = Mock() + + mock_empty = Mock() + mock_empty.json.return_value = [] + mock_empty.headers = {"X-RateLimit-Remaining": "4998"} + mock_empty.raise_for_status = Mock() + + with patch.object(requests.Session, "get") as mock_get: + mock_get.side_effect = [mock_response, mock_empty] + + client = GithubClient(credentials) + files = 
client.fetch_pr_files("owner", "repo", 10) + + assert len(files) == 1 + assert files[0]["status"] == "renamed" + + +# ============================================================================= +# Tests for fetch_pr_diff +# ============================================================================= + + +def test_fetch_pr_diff_success(): + """Test fetching full PR diff.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + diff_text = """diff --git a/file.py b/file.py +index abc123..def456 100644 +--- a/file.py ++++ b/file.py +@@ -1,5 +1,10 @@ ++import os ++ + def main(): +- print("old") ++ print("new") +""" + + mock_response = Mock() + mock_response.ok = True + mock_response.text = diff_text + + with patch.object(requests.Session, "get") as mock_get: + mock_get.return_value = mock_response + + client = GithubClient(credentials) + diff = client.fetch_pr_diff("owner", "repo", 10) + + assert diff == diff_text + # Verify Accept header was set for diff format + call_kwargs = mock_get.call_args.kwargs + assert call_kwargs["headers"]["Accept"] == "application/vnd.github.diff" + + +def test_fetch_pr_diff_failure(): + """Test handling diff fetch failure gracefully.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.ok = False + + with patch.object(requests.Session, "get") as mock_get: + mock_get.return_value = mock_response + + client = GithubClient(credentials) + diff = client.fetch_pr_diff("owner", "repo", 10) + + assert diff is None + + +def test_fetch_pr_diff_exception(): + """Test handling exceptions during diff fetch.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + with patch.object(requests.Session, "get") as mock_get: + mock_get.side_effect = requests.RequestException("Network error") + + client = GithubClient(credentials) + diff = client.fetch_pr_diff("owner", "repo", 10) + + assert diff is None + + +# ============================================================================= +# Tests for _parse_pr with pr_data +# ============================================================================= + + +def test_parse_pr_fetches_all_pr_data(): + """Test that _parse_pr fetches and includes all PR-specific data.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + pr_raw = { + "number": 42, + "title": "Feature PR", + "body": "PR description", + "state": "open", + "user": {"login": "contributor"}, + "labels": [{"name": "enhancement"}], + "assignees": [{"login": "reviewer"}], + "milestone": {"title": "v2.0"}, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", + "closed_at": None, + "merged_at": None, + } + + # Mock responses for all the fetch methods + def mock_get(url, **kwargs): + response = Mock() + response.headers = {"X-RateLimit-Remaining": "4999"} + response.raise_for_status = Mock() + + if "/issues/42/comments" in url: + # Regular comments + page = kwargs.get("params", {}).get("page", 1) + if page == 1: + response.json.return_value = [ + { + "id": 1, + "user": {"login": "user1"}, + "body": "Regular comment", + "created_at": "2024-01-01T10:00:00Z", + "updated_at": "2024-01-01T10:00:00Z", + } + ] + else: + response.json.return_value = [] + elif "/pulls/42/comments" in url: + # Review comments + page = kwargs.get("params", {}).get("page", 1) + if page == 1: + response.json.return_value = [ + { + "id": 101, + "user": {"login": "reviewer1"}, + "body": "Review comment", + "path": "src/main.py", + "line": 10, + 
"side": "RIGHT", + "diff_hunk": "@@ -1,5 +1,10 @@", + "created_at": "2024-01-01T12:00:00Z", + } + ] + else: + response.json.return_value = [] + elif "/pulls/42/reviews" in url: + # Reviews + page = kwargs.get("params", {}).get("page", 1) + if page == 1: + response.json.return_value = [ + { + "id": 201, + "user": {"login": "lead"}, + "state": "APPROVED", + "body": "LGTM", + "submitted_at": "2024-01-02T08:00:00Z", + } + ] + else: + response.json.return_value = [] + elif "/pulls/42/files" in url: + # Files + page = kwargs.get("params", {}).get("page", 1) + if page == 1: + response.json.return_value = [ + { + "filename": "src/main.py", + "status": "modified", + "additions": 50, + "deletions": 10, + "patch": "+new\n-old", + }, + { + "filename": "tests/test_main.py", + "status": "added", + "additions": 30, + "deletions": 0, + "patch": "+tests", + }, + ] + else: + response.json.return_value = [] + elif "/pulls/42" in url and "diff" in kwargs.get("headers", {}).get( + "Accept", "" + ): + # Full diff + response.ok = True + response.text = "diff --git a/src/main.py\n+new code\n-old code" + return response + else: + response.json.return_value = [] + + return response + + with patch.object(requests.Session, "get", side_effect=mock_get): + client = GithubClient(credentials) + result = client._parse_pr("owner", "repo", pr_raw) + + # Verify basic fields + assert result["kind"] == "pr" + assert result["number"] == 42 + assert result["title"] == "Feature PR" + assert result["author"] == "contributor" + assert len(result["comments"]) == 1 + + # Verify pr_data + pr_data = result["pr_data"] + assert pr_data is not None + + # Verify diff + assert pr_data["diff"] is not None + assert "new code" in pr_data["diff"] + + # Verify files + assert len(pr_data["files"]) == 2 + assert pr_data["files"][0]["filename"] == "src/main.py" + assert pr_data["files"][0]["additions"] == 50 + + # Verify stats calculated from files + assert pr_data["additions"] == 80 # 50 + 30 + assert pr_data["deletions"] == 10 + assert pr_data["changed_files_count"] == 2 + + # Verify reviews + assert len(pr_data["reviews"]) == 1 + assert pr_data["reviews"][0]["user"] == "lead" + assert pr_data["reviews"][0]["state"] == "approved" + + # Verify review comments + assert len(pr_data["review_comments"]) == 1 + assert pr_data["review_comments"][0]["user"] == "reviewer1" + assert pr_data["review_comments"][0]["path"] == "src/main.py" + + # Verify diff_summary is truncated version of full diff + assert result["diff_summary"] == pr_data["diff"][:5000] diff --git a/tests/memory/workers/tasks/test_github_tasks.py b/tests/memory/workers/tasks/test_github_tasks.py index 571cf1d..1dee26d 100644 --- a/tests/memory/workers/tasks/test_github_tasks.py +++ b/tests/memory/workers/tasks/test_github_tasks.py @@ -1123,3 +1123,521 @@ def test_tag_merging(repo_tags, issue_labels, expected_tags, github_account, db_ item = db_session.query(GithubItem).filter_by(number=60).first() assert item.tags == expected_tags + + +# ============================================================================= +# Tests for PR data handling +# ============================================================================= + + +@pytest.fixture +def mock_pr_data_with_extended() -> GithubIssueData: + """Mock PR data with full pr_data dict.""" + from memory.parsers.github import GithubPRDataDict + + pr_data: GithubPRDataDict = { + "diff": "diff --git a/file.py\n+new line\n-old line", + "files": [ + { + "filename": "src/main.py", + "status": "modified", + "additions": 50, + "deletions": 10, + 
"patch": "+new\n-old", + }, + { + "filename": "tests/test_main.py", + "status": "added", + "additions": 30, + "deletions": 0, + "patch": "+test code", + }, + ], + "additions": 80, + "deletions": 10, + "changed_files_count": 2, + "reviews": [ + { + "id": 1001, + "user": "lead_reviewer", + "state": "approved", + "body": "LGTM!", + "submitted_at": "2024-01-02T10:00:00Z", + } + ], + "review_comments": [ + { + "id": 2001, + "user": "reviewer1", + "body": "Please add a docstring here", + "path": "src/main.py", + "line": 42, + "side": "RIGHT", + "diff_hunk": "@@ -40,3 +40,5 @@", + "created_at": "2024-01-01T15:00:00Z", + } + ], + } + + return GithubIssueData( + kind="pr", + number=200, + title="Feature: Add new capability", + body="This PR adds a new capability to the system.", + state="open", + author="contributor", + labels=["enhancement", "needs-review"], + assignees=["reviewer1"], + milestone="v2.0", + created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime(2024, 1, 2, 10, 0, 0, tzinfo=timezone.utc), + comment_count=1, + comments=[ + { + "id": 5001, + "author": "maintainer", + "body": "Thanks for the PR!", + "created_at": "2024-01-01T14:00:00Z", + "updated_at": "2024-01-01T14:00:00Z", + } + ], + diff_summary="+new\n-old", + project_fields=None, + content_hash="pr_extended_hash", + pr_data=pr_data, + ) + + +def test_build_content_with_review_comments(mock_pr_data_with_extended): + """Test _build_content includes review comments for PRs.""" + content = _build_content(mock_pr_data_with_extended) + + # Basic content + assert "# Feature: Add new capability" in content + assert "This PR adds a new capability" in content + + # Regular comment + assert "**maintainer**: Thanks for the PR!" in content + + # Review comments section + assert "## Code Review Comments" in content + assert "**reviewer1**" in content + assert "Please add a docstring here" in content + assert "`src/main.py`" in content + + +def test_build_content_pr_without_review_comments(): + """Test _build_content for PR with no review comments.""" + data = GithubIssueData( + kind="pr", + number=201, + title="Simple PR", + body="Body", + state="open", + author="user", + labels=[], + assignees=[], + milestone=None, + created_at=datetime.now(timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime.now(timezone.utc), + comment_count=0, + comments=[], + diff_summary=None, + project_fields=None, + content_hash="hash", + pr_data={ + "diff": None, + "files": [], + "additions": 0, + "deletions": 0, + "changed_files_count": 0, + "reviews": [], + "review_comments": [], # Empty + }, + ) + + content = _build_content(data) + + assert "# Simple PR" in content + assert "## Code Review Comments" not in content + + +def test_build_content_issue_no_pr_data(): + """Test _build_content for issue (no pr_data).""" + data = GithubIssueData( + kind="issue", + number=100, + title="Bug Report", + body="There's a bug", + state="open", + author="reporter", + labels=["bug"], + assignees=[], + milestone=None, + created_at=datetime.now(timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime.now(timezone.utc), + comment_count=0, + comments=[], + diff_summary=None, + project_fields=None, + content_hash="hash", + pr_data=None, # Issues don't have pr_data + ) + + content = _build_content(data) + + assert "# Bug Report" in content + assert "There's a bug" in content + assert "## Code Review Comments" not in content + + +def 
test_create_pr_data_function(mock_pr_data_with_extended): + """Test _create_pr_data creates GithubPRData correctly.""" + from memory.workers.tasks.github import _create_pr_data + + result = _create_pr_data(mock_pr_data_with_extended) + + assert result is not None + assert result.additions == 80 + assert result.deletions == 10 + assert result.changed_files_count == 2 + + # Files are stored as JSONB + assert len(result.files) == 2 + assert result.files[0]["filename"] == "src/main.py" + + # Reviews + assert len(result.reviews) == 1 + assert result.reviews[0]["user"] == "lead_reviewer" + assert result.reviews[0]["state"] == "approved" + + # Review comments + assert len(result.review_comments) == 1 + assert result.review_comments[0]["path"] == "src/main.py" + + # Diff is compressed - test the property getter + assert result.diff is not None + assert "new line" in result.diff + + +def test_create_pr_data_none_for_issue(): + """Test _create_pr_data returns None for issues.""" + from memory.workers.tasks.github import _create_pr_data + + data = GithubIssueData( + kind="issue", + number=100, + title="Issue", + body="Body", + state="open", + author="user", + labels=[], + assignees=[], + milestone=None, + created_at=datetime.now(timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime.now(timezone.utc), + comment_count=0, + comments=[], + diff_summary=None, + project_fields=None, + content_hash="hash", + pr_data=None, + ) + + result = _create_pr_data(data) + assert result is None + + +def test_serialize_deserialize_with_pr_data(mock_pr_data_with_extended): + """Test serialization roundtrip preserves pr_data.""" + serialized = _serialize_issue_data(mock_pr_data_with_extended) + + # Verify pr_data is included in serialized form + assert "pr_data" in serialized + assert serialized["pr_data"]["additions"] == 80 + assert len(serialized["pr_data"]["files"]) == 2 + assert len(serialized["pr_data"]["reviews"]) == 1 + assert len(serialized["pr_data"]["review_comments"]) == 1 + + # Deserialize and verify + deserialized = _deserialize_issue_data(serialized) + + assert deserialized["pr_data"] is not None + assert deserialized["pr_data"]["additions"] == 80 + assert deserialized["pr_data"]["deletions"] == 10 + assert len(deserialized["pr_data"]["files"]) == 2 + assert deserialized["pr_data"]["diff"] == mock_pr_data_with_extended["pr_data"]["diff"] + + +def test_serialize_deserialize_without_pr_data(mock_issue_data): + """Test serialization roundtrip for issue without pr_data.""" + # Add pr_data=None to the mock (issues don't have it) + issue_with_none = dict(mock_issue_data) + issue_with_none["pr_data"] = None + + serialized = _serialize_issue_data(issue_with_none) + assert serialized.get("pr_data") is None + + deserialized = _deserialize_issue_data(serialized) + assert deserialized.get("pr_data") is None + + +def test_sync_github_item_creates_pr_data( + mock_pr_data_with_extended, github_repo, db_session, qdrant +): + """Test that syncing a PR creates associated GithubPRData.""" + serialized = _serialize_issue_data(mock_pr_data_with_extended) + result = github.sync_github_item(github_repo.id, serialized) + + assert result["status"] == "processed" + + # Query the created item + item = ( + db_session.query(GithubItem) + .filter_by(repo_path="testorg/testrepo", number=200, kind="pr") + .first() + ) + assert item is not None + assert item.kind == "pr" + + # Check pr_data relationship + assert item.pr_data is not None + assert item.pr_data.additions == 80 + assert item.pr_data.deletions == 10 + 
assert item.pr_data.changed_files_count == 2 + assert len(item.pr_data.files) == 2 + assert len(item.pr_data.reviews) == 1 + assert len(item.pr_data.review_comments) == 1 + + # Verify diff decompression works + assert item.pr_data.diff is not None + assert "new line" in item.pr_data.diff + + +def test_sync_github_item_pr_without_pr_data(github_repo, db_session, qdrant): + """Test syncing a PR that doesn't have extended pr_data.""" + data = GithubIssueData( + kind="pr", + number=202, + title="Legacy PR", + body="PR without extended data", + state="merged", + author="old_contributor", + labels=[], + assignees=[], + milestone=None, + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + closed_at=datetime(2024, 1, 5, tzinfo=timezone.utc), + merged_at=datetime(2024, 1, 5, tzinfo=timezone.utc), + github_updated_at=datetime(2024, 1, 5, tzinfo=timezone.utc), + comment_count=0, + comments=[], + diff_summary="+10 -5", + project_fields=None, + content_hash="legacy_hash", + pr_data=None, # No extended PR data + ) + + serialized = _serialize_issue_data(data) + result = github.sync_github_item(github_repo.id, serialized) + + assert result["status"] == "processed" + + item = db_session.query(GithubItem).filter_by(number=202).first() + assert item is not None + assert item.kind == "pr" + assert item.pr_data is None # No pr_data created + + +def test_sync_github_item_updates_existing_pr_data(github_repo, db_session, qdrant): + """Test updating an existing PR with new pr_data.""" + from memory.common.db.models import GithubPRData + from memory.workers.tasks.content_processing import create_content_hash + + # Create initial PR with pr_data + initial_content = "# Initial PR\n\nOriginal body" + existing_item = GithubItem( + repo_path="testorg/testrepo", + repo_id=github_repo.id, + number=300, + kind="pr", + title="Initial PR", + content=initial_content, + state="open", + author="user", + labels=[], + assignees=[], + milestone=None, + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + github_updated_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + comment_count=0, + content_hash="initial_hash", + diff_summary="+5 -2", + modality="github", + mime_type="text/markdown", + sha256=create_content_hash(initial_content), + size=len(initial_content), + tags=["github", "test"], + ) + + # Create initial pr_data + initial_pr_data = GithubPRData( + additions=5, + deletions=2, + changed_files_count=1, + files=[{"filename": "old.py", "status": "modified", "additions": 5, "deletions": 2, "patch": None}], + reviews=[], + review_comments=[], + ) + initial_pr_data.diff = "old diff" + existing_item.pr_data = initial_pr_data + + db_session.add(existing_item) + db_session.commit() + + # Now update with new data + updated_data = GithubIssueData( + kind="pr", + number=300, + title="Updated PR", + body="Updated body with more changes", + state="open", + author="user", + labels=["ready-for-review"], + assignees=["reviewer"], + milestone=None, + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime(2024, 1, 5, tzinfo=timezone.utc), # Newer + comment_count=2, + comments=[ + {"id": 1, "author": "reviewer", "body": "LGTM", "created_at": "", "updated_at": ""} + ], + diff_summary="+50 -10", + project_fields=None, + content_hash="updated_hash", # Different hash triggers update + pr_data={ + "diff": "new diff with lots of changes", + "files": [ + {"filename": "new.py", "status": "added", "additions": 50, "deletions": 0, "patch": "+code"}, + {"filename": "old.py", "status": 
"modified", "additions": 0, "deletions": 10, "patch": "-code"}, + ], + "additions": 50, + "deletions": 10, + "changed_files_count": 2, + "reviews": [ + {"id": 1, "user": "reviewer", "state": "approved", "body": "Approved!", "submitted_at": ""} + ], + "review_comments": [ + {"id": 1, "user": "reviewer", "body": "Nice!", "path": "new.py", "line": 10, "side": "RIGHT", "diff_hunk": "", "created_at": ""} + ], + }, + ) + + serialized = _serialize_issue_data(updated_data) + result = github.sync_github_item(github_repo.id, serialized) + + assert result["status"] == "processed" + + # Refresh from DB + db_session.expire_all() + item = db_session.query(GithubItem).filter_by(number=300).first() + + assert item.title == "Updated PR" + assert item.pr_data is not None + assert item.pr_data.additions == 50 + assert item.pr_data.deletions == 10 + assert item.pr_data.changed_files_count == 2 + assert len(item.pr_data.files) == 2 + assert len(item.pr_data.reviews) == 1 + assert len(item.pr_data.review_comments) == 1 + assert "new diff" in item.pr_data.diff + + +def test_sync_github_item_creates_pr_data_for_existing_pr_without( + github_repo, db_session, qdrant +): + """Test updating a PR that didn't have pr_data to add it.""" + from memory.workers.tasks.content_processing import create_content_hash + + # Create existing PR without pr_data (legacy data) + initial_content = "# Legacy PR\n\nOriginal" + existing_item = GithubItem( + repo_path="testorg/testrepo", + repo_id=github_repo.id, + number=301, + kind="pr", + title="Legacy PR", + content=initial_content, + state="open", + author="user", + labels=[], + assignees=[], + milestone=None, + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + github_updated_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + comment_count=0, + content_hash="legacy_hash", + diff_summary=None, + modality="github", + mime_type="text/markdown", + sha256=create_content_hash(initial_content), + size=len(initial_content), + tags=["github"], + pr_data=None, # No pr_data initially + ) + db_session.add(existing_item) + db_session.commit() + + # Update with pr_data + updated_data = GithubIssueData( + kind="pr", + number=301, + title="Legacy PR", + body="Original with new review", + state="open", + author="user", + labels=[], + assignees=[], + milestone=None, + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime(2024, 1, 2, tzinfo=timezone.utc), + comment_count=0, + comments=[], + diff_summary="+10 -0", + project_fields=None, + content_hash="new_hash", # Different + pr_data={ + "diff": "the full diff", + "files": [{"filename": "new.py", "status": "added", "additions": 10, "deletions": 0, "patch": None}], + "additions": 10, + "deletions": 0, + "changed_files_count": 1, + "reviews": [], + "review_comments": [], + }, + ) + + serialized = _serialize_issue_data(updated_data) + result = github.sync_github_item(github_repo.id, serialized) + + assert result["status"] == "processed" + + db_session.expire_all() + item = db_session.query(GithubItem).filter_by(number=301).first() + + # Now should have pr_data + assert item.pr_data is not None + assert item.pr_data.additions == 10 + assert item.pr_data.diff == "the full diff" diff --git a/tests/memory/workers/tasks/test_people_tasks.py b/tests/memory/workers/tasks/test_people_tasks.py new file mode 100644 index 0000000..418f085 --- /dev/null +++ b/tests/memory/workers/tasks/test_people_tasks.py @@ -0,0 +1,404 @@ +"""Tests for people Celery tasks.""" + +import uuid +from contextlib import 
contextmanager +from unittest.mock import patch, MagicMock + +import pytest + +from memory.common.db.models import Person +from memory.common.db.models.source_item import Chunk +from memory.workers.tasks import people +from memory.workers.tasks.content_processing import create_content_hash + + +def _make_mock_chunk(source_id: int) -> Chunk: + """Create a mock chunk for testing with a unique ID.""" + return Chunk( + id=str(uuid.uuid4()), + content="test chunk content", + embedding_model="test-model", + vector=[0.1] * 1024, + item_metadata={"source_id": source_id, "tags": ["test"]}, + collection_name="person", + ) + + +@pytest.fixture +def mock_make_session(db_session): + """Mock make_session and embedding functions for task tests.""" + + @contextmanager + def _mock_session(): + yield db_session + + with patch("memory.workers.tasks.people.make_session", _mock_session): + # Mock embedding to return a fake chunk + with patch( + "memory.common.embedding.embed_source_item", + side_effect=lambda item: [_make_mock_chunk(item.id or 1)], + ): + # Mock push_to_qdrant to do nothing + with patch("memory.workers.tasks.content_processing.push_to_qdrant"): + yield db_session + + +@pytest.fixture +def person_data(): + """Standard person test data.""" + return { + "identifier": "alice_chen", + "display_name": "Alice Chen", + "aliases": ["@alice_c", "alice.chen@work.com"], + "contact_info": {"email": "alice@example.com", "phone": "555-1234"}, + "tags": ["work", "engineering"], + "notes": "Tech lead on Platform team.", + } + + +@pytest.fixture +def minimal_person_data(): + """Minimal person test data.""" + return { + "identifier": "bob_smith", + "display_name": "Bob Smith", + } + + +def test_sync_person_success(person_data, mock_make_session, qdrant): + """Test successful person sync.""" + result = people.sync_person(**person_data) + + # Verify the Person was created in the database + person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first() + assert person is not None + assert person.identifier == "alice_chen" + assert person.display_name == "Alice Chen" + assert person.aliases == ["@alice_c", "alice.chen@work.com"] + assert person.contact_info == {"email": "alice@example.com", "phone": "555-1234"} + assert person.tags == ["work", "engineering"] + assert person.content == "Tech lead on Platform team." 
+ assert person.modality == "person" + + # Verify the result + assert result["status"] == "processed" + assert "person_id" in result + + +def test_sync_person_minimal_data(minimal_person_data, mock_make_session, qdrant): + """Test person sync with minimal required data.""" + result = people.sync_person(**minimal_person_data) + + person = mock_make_session.query(Person).filter_by(identifier="bob_smith").first() + assert person is not None + assert person.identifier == "bob_smith" + assert person.display_name == "Bob Smith" + assert person.aliases == [] + assert person.contact_info == {} + assert person.tags == [] + assert person.content is None + + assert result["status"] == "processed" + + +def test_sync_person_already_exists(person_data, mock_make_session, qdrant): + """Test sync when person already exists.""" + # Create the person first + sha256 = create_content_hash(f"person:{person_data['identifier']}") + existing_person = Person( + identifier=person_data["identifier"], + display_name=person_data["display_name"], + aliases=person_data["aliases"], + contact_info=person_data["contact_info"], + tags=person_data["tags"], + content=person_data["notes"], + modality="person", + mime_type="text/plain", + sha256=sha256, + size=len(person_data["notes"]), + ) + mock_make_session.add(existing_person) + mock_make_session.commit() + + # Try to sync again + result = people.sync_person(**person_data) + + assert result["status"] == "already_exists" + assert result["person_id"] == existing_person.id + + # Verify no duplicate was created + count = mock_make_session.query(Person).filter_by(identifier="alice_chen").count() + assert count == 1 + + +def test_update_person_display_name(person_data, mock_make_session, qdrant): + """Test updating display name.""" + # Create person first + people.sync_person(**person_data) + + # Update display name + result = people.update_person( + identifier="alice_chen", + display_name="Alice M. Chen", + ) + + assert result["status"] == "processed" + + person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first() + assert person.display_name == "Alice M. 
Chen" + # Other fields unchanged + assert person.aliases == ["@alice_c", "alice.chen@work.com"] + + +def test_update_person_merge_aliases(person_data, mock_make_session, qdrant): + """Test that aliases are merged, not replaced.""" + # Create person first + people.sync_person(**person_data) + + # Update with new aliases + result = people.update_person( + identifier="alice_chen", + aliases=["@alice_chen", "alice@company.com"], + ) + + assert result["status"] == "processed" + + person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first() + # Should be union of old and new + assert set(person.aliases) == { + "@alice_c", + "alice.chen@work.com", + "@alice_chen", + "alice@company.com", + } + + +def test_update_person_merge_contact_info(person_data, mock_make_session, qdrant): + """Test that contact_info is deep merged.""" + # Create person first + people.sync_person(**person_data) + + # Update with new contact info + result = people.update_person( + identifier="alice_chen", + contact_info={"twitter": "@alice_c", "phone": "555-5678"}, # Update existing + ) + + assert result["status"] == "processed" + + person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first() + # Should have all keys + assert person.contact_info["email"] == "alice@example.com" # Original + assert person.contact_info["phone"] == "555-5678" # Updated + assert person.contact_info["twitter"] == "@alice_c" # New + + +def test_update_person_merge_tags(person_data, mock_make_session, qdrant): + """Test that tags are merged, not replaced.""" + # Create person first + people.sync_person(**person_data) + + # Update with new tags + result = people.update_person( + identifier="alice_chen", + tags=["climbing", "london"], + ) + + assert result["status"] == "processed" + + person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first() + # Should be union of old and new + assert set(person.tags) == {"work", "engineering", "climbing", "london"} + + +def test_update_person_append_notes(person_data, mock_make_session, qdrant): + """Test that notes are appended by default.""" + # Create person first + people.sync_person(**person_data) + + # Update with new notes + result = people.update_person( + identifier="alice_chen", + notes="Also enjoys rock climbing.", + ) + + assert result["status"] == "processed" + + person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first() + # Should be appended with separator + assert "Tech lead on Platform team." in person.content + assert "Also enjoys rock climbing." in person.content + assert "---" in person.content + + +def test_update_person_replace_notes(person_data, mock_make_session, qdrant): + """Test replacing notes instead of appending.""" + # Create person first + people.sync_person(**person_data) + + # Replace notes + result = people.update_person( + identifier="alice_chen", + notes="Completely new notes.", + replace_notes=True, + ) + + assert result["status"] == "processed" + + person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first() + assert person.content == "Completely new notes." 
+ assert "Tech lead" not in person.content + + +def test_update_person_not_found(mock_make_session, qdrant): + """Test updating a person that doesn't exist.""" + result = people.update_person( + identifier="nonexistent_person", + display_name="New Name", + ) + + assert result["status"] == "not_found" + assert result["identifier"] == "nonexistent_person" + + +def test_update_person_no_changes(person_data, mock_make_session, qdrant): + """Test update with no actual changes.""" + # Create person first + people.sync_person(**person_data) + + # Update with nothing + result = people.update_person(identifier="alice_chen") + + assert result["status"] == "processed" + + person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first() + # Should be unchanged + assert person.display_name == "Alice Chen" + + +@pytest.mark.parametrize( + "identifier,display_name,tags", + [ + ("john_doe", "John Doe", []), + ("jane_smith", "Jane Smith", ["friend"]), + ("bob_jones", "Bob Jones", ["work", "climbing", "london"]), + ], +) +def test_sync_person_various_configurations(identifier, display_name, tags, mock_make_session, qdrant): + """Test sync_person with various configurations.""" + result = people.sync_person( + identifier=identifier, + display_name=display_name, + tags=tags, + ) + + assert result["status"] == "processed" + + person = mock_make_session.query(Person).filter_by(identifier=identifier).first() + assert person is not None + assert person.display_name == display_name + assert person.tags == tags + + +def test_deep_merge_helper(): + """Test the _deep_merge helper function.""" + base = { + "a": 1, + "b": {"c": 2, "d": 3}, + "e": 4, + } + updates = { + "b": {"c": 5, "f": 6}, + "g": 7, + } + + result = people._deep_merge(base, updates) + + assert result == { + "a": 1, + "b": {"c": 5, "d": 3, "f": 6}, + "e": 4, + "g": 7, + } + + +def test_deep_merge_nested(): + """Test deep merge with deeply nested structures.""" + base = { + "level1": { + "level2": { + "level3": {"a": 1}, + }, + }, + } + updates = { + "level1": { + "level2": { + "level3": {"b": 2}, + }, + }, + } + + result = people._deep_merge(base, updates) + + assert result == { + "level1": { + "level2": { + "level3": {"a": 1, "b": 2}, + }, + }, + } + + +def test_sync_person_unicode(mock_make_session, qdrant): + """Test sync_person with unicode content.""" + result = people.sync_person( + identifier="unicode_person", + display_name="日本語 名前", + notes="Привет мир 🌍", + tags=["日本語", "emoji"], + ) + + assert result["status"] == "processed" + + person = mock_make_session.query(Person).filter_by(identifier="unicode_person").first() + assert person is not None + assert person.display_name == "日本語 名前" + assert person.content == "Привет мир 🌍" + + +def test_sync_person_empty_notes(mock_make_session, qdrant): + """Test sync_person with empty notes.""" + result = people.sync_person( + identifier="empty_notes_person", + display_name="Empty Notes Person", + notes="", + ) + + assert result["status"] == "processed" + + person = mock_make_session.query(Person).filter_by(identifier="empty_notes_person").first() + assert person is not None + assert person.content == "" + + +def test_update_person_first_notes(mock_make_session, qdrant): + """Test adding notes to a person who had no notes.""" + # Create person without notes + people.sync_person( + identifier="no_notes_person", + display_name="No Notes Person", + ) + + # Add notes + result = people.update_person( + identifier="no_notes_person", + notes="First notes!", + ) + + assert result["status"] == 
"processed" + + person = mock_make_session.query(Person).filter_by(identifier="no_notes_person").first() + assert person.content == "First notes!" + # Should not have separator when there were no previous notes + assert "---" not in person.content