mirror of https://github.com/mruwnik/memory.git
synced 2026-01-02 09:12:58 +01:00

commit 47629fc5fb (parent efbc469ee3)

    add PRs and People
.gitignore (vendored, 4 lines changed)
@@ -1,5 +1,9 @@
 Books
+books.md
+clean_books
+scripts
+

 CLAUDE.md
 memory_files
 venv
db/migrations/versions/20251224_120000_add_people.py (new file, 56 lines)
@@ -0,0 +1,56 @@
"""Add people tracking

Revision ID: c9d0e1f2a3b4
Revises: b7c8d9e0f1a2
Create Date: 2025-12-24 12:00:00.000000

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision: str = "c9d0e1f2a3b4"
down_revision: Union[str, None] = "b7c8d9e0f1a2"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        "people",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("identifier", sa.Text(), nullable=False),
        sa.Column("display_name", sa.Text(), nullable=False),
        sa.Column(
            "aliases",
            postgresql.ARRAY(sa.Text()),
            server_default="{}",
            nullable=False,
        ),
        sa.Column(
            "contact_info",
            postgresql.JSONB(astext_type=sa.Text()),
            server_default="{}",
            nullable=False,
        ),
        sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("identifier"),
    )
    op.create_index("person_identifier_idx", "people", ["identifier"], unique=False)
    op.create_index("person_display_name_idx", "people", ["display_name"], unique=False)
    op.create_index(
        "person_aliases_idx", "people", ["aliases"], unique=False, postgresql_using="gin"
    )


def downgrade() -> None:
    op.drop_index("person_aliases_idx", table_name="people")
    op.drop_index("person_display_name_idx", table_name="people")
    op.drop_index("person_identifier_idx", table_name="people")
    op.drop_table("people")
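The GIN index on aliases is what makes alias lookups cheap. A minimal sketch of the query shape it serves, using the same array-overlap operator the MCP tools in this commit use (session and alias value are illustrative):

    from sqlalchemy import Text, cast
    from sqlalchemy.dialects.postgresql import ARRAY

    # && (array overlap) is GIN-indexable, unlike a LIKE scan over joined text
    matches = session.query(Person).filter(
        Person.aliases.op("&&")(cast(["@alice_c"], ARRAY(Text)))
    ).all()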
db/migrations/versions/20251224_150000_add_github_pr_data.py (new file, 57 lines)
@@ -0,0 +1,57 @@
"""Add github_pr_data table for PR-specific data

Revision ID: d0e1f2a3b4c5
Revises: c9d0e1f2a3b4
Create Date: 2025-12-24 15:00:00.000000

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision: str = "d0e1f2a3b4c5"
down_revision: Union[str, None] = "c9d0e1f2a3b4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        "github_pr_data",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("github_item_id", sa.BigInteger(), nullable=False),
        # Diff stored compressed with zlib
        sa.Column("diff_compressed", sa.LargeBinary(), nullable=True),
        # File changes as structured data
        # [{filename, status, additions, deletions, patch?}]
        sa.Column("files", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        # Stats
        sa.Column("additions", sa.Integer(), nullable=True),
        sa.Column("deletions", sa.Integer(), nullable=True),
        sa.Column("changed_files_count", sa.Integer(), nullable=True),
        # Reviews - [{user, state, body, submitted_at}]
        sa.Column("reviews", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        # Review comments (line-by-line code comments)
        # [{user, body, path, line, diff_hunk, created_at}]
        sa.Column(
            "review_comments", postgresql.JSONB(astext_type=sa.Text()), nullable=True
        ),
        sa.ForeignKeyConstraint(
            ["github_item_id"], ["github_item.id"], ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("github_item_id"),
    )
    op.create_index(
        "github_pr_data_item_idx", "github_pr_data", ["github_item_id"], unique=True
    )


def downgrade() -> None:
    op.drop_index("github_pr_data_item_idx", table_name="github_pr_data")
    op.drop_table("github_pr_data")
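Reviews and review comments land in JSONB columns rather than separate tables. An illustrative containment query that layout supports (not part of the commit; JSONB @> could additionally be served by a GIN index if one were added later):

    # Find PR data rows whose reviews array contains an approval by a given user
    approved = session.query(GithubPRData).filter(
        GithubPRData.reviews.contains([{"user": "alice", "state": "approved"}])
    ).all()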
docker-compose config (worker queues):
@@ -206,7 +206,7 @@ services:
     <<: *worker-base
     environment:
       <<: *worker-env
-      QUEUES: "backup,blogs,comic,discord,ebooks,email,forums,github,photo_embed,maintenance,notes,scheduler"
+      QUEUES: "backup,blogs,comic,discord,ebooks,email,forums,github,people,photo_embed,maintenance,notes,scheduler"

   ingest-hub:
     <<: *worker-base
worker Dockerfile:
@@ -44,7 +44,7 @@ RUN git config --global user.email "${GIT_USER_EMAIL}" && \
     git config --global user.name "${GIT_USER_NAME}"

 # Default queues to process
-ENV QUEUES="backup,blogs,comic,discord,ebooks,email,forums,github,photo_embed,maintenance"
+ENV QUEUES="backup,blogs,comic,discord,ebooks,email,forums,github,people,photo_embed,maintenance"
 ENV PYTHONPATH="/app"

 ENTRYPOINT ["./entry.sh"]
src/memory/api/MCP/__init__.py:
@@ -4,3 +4,5 @@ import memory.api.MCP.metadata
 import memory.api.MCP.schedules
 import memory.api.MCP.books
 import memory.api.MCP.manifest
+import memory.api.MCP.github
+import memory.api.MCP.people
src/memory/api/MCP/github.py (new file, 518 lines)
@@ -0,0 +1,518 @@
"""MCP tools for GitHub issue tracking and management."""

import logging
from datetime import datetime, timezone
from typing import Any

from sqlalchemy import Text, case, desc, func, asc
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY

from memory.api.MCP.base import mcp
from memory.api.search.search import search
from memory.api.search.types import SearchConfig, SearchFilters
from memory.common import extract
from memory.common.db.connection import make_session
from memory.common.db.models import GithubItem

logger = logging.getLogger(__name__)


def _build_github_url(repo_path: str, number: int | None, kind: str) -> str:
    """Build GitHub URL from repo path and issue/PR number."""
    if number is None:
        return f"https://github.com/{repo_path}"
    url_type = "pull" if kind == "pr" else "issues"
    return f"https://github.com/{repo_path}/{url_type}/{number}"


def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[str, Any]:
    """Serialize a GithubItem to a dict for API response."""
    result = {
        "id": item.id,
        "number": item.number,
        "kind": item.kind,
        "repo_path": item.repo_path,
        "title": item.title,
        "state": item.state,
        "author": item.author,
        "assignees": item.assignees or [],
        "labels": item.labels or [],
        "milestone": item.milestone,
        "project_status": item.project_status,
        "project_priority": item.project_priority,
        "project_fields": item.project_fields,
        "comment_count": item.comment_count,
        "created_at": item.created_at.isoformat() if item.created_at else None,
        "closed_at": item.closed_at.isoformat() if item.closed_at else None,
        "merged_at": item.merged_at.isoformat() if item.merged_at else None,
        "github_updated_at": (
            item.github_updated_at.isoformat() if item.github_updated_at else None
        ),
        "url": _build_github_url(item.repo_path, item.number, item.kind),
    }
    if include_content:
        result["content"] = item.content
        # Include PR-specific data if available
        if item.kind == "pr" and item.pr_data:
            result["pr_data"] = {
                "additions": item.pr_data.additions,
                "deletions": item.pr_data.deletions,
                "changed_files_count": item.pr_data.changed_files_count,
                "files": item.pr_data.files,
                "reviews": item.pr_data.reviews,
                "review_comments": item.pr_data.review_comments,
                "diff": item.pr_data.diff,  # Decompressed via property
            }
    return result


@mcp.tool()
async def list_github_issues(
    repo: str | None = None,
    assignee: str | None = None,
    author: str | None = None,
    state: str | None = None,
    kind: str | None = None,
    labels: list[str] | None = None,
    project_status: str | None = None,
    project_field: dict[str, str] | None = None,
    updated_since: str | None = None,
    updated_before: str | None = None,
    limit: int = 50,
    order_by: str = "updated",
) -> list[dict]:
    """
    List GitHub issues and PRs with flexible filtering.
    Use for daily triage, finding assigned issues, tracking stale issues, etc.

    Args:
        repo: Filter by repository path (e.g., "owner/name")
        assignee: Filter by assignee username
        author: Filter by author username
        state: Filter by state: "open", "closed", "merged" (default: all)
        kind: Filter by type: "issue" or "pr" (default: both)
        labels: Filter by GitHub labels (matches ANY label in list)
        project_status: Filter by project status (e.g., "In Progress", "Backlog")
        project_field: Filter by project field values (e.g., {"EquiStamp.Client": "Redwood"})
        updated_since: ISO date - only issues updated after this time
        updated_before: ISO date - only issues updated before this (for finding stale issues)
        limit: Maximum results (default 50, max 200)
        order_by: Sort order: "updated", "created", or "number" (default: "updated" descending)

    Returns: List of issues with id, number, title, state, assignees, labels, project_fields, timestamps, url
    """
    logger.info(f"list_github_issues called: repo={repo}, assignee={assignee}, state={state}")

    limit = min(limit, 200)

    with make_session() as session:
        query = session.query(GithubItem)

        # Apply filters
        if repo:
            query = query.filter(GithubItem.repo_path == repo)

        if assignee:
            query = query.filter(GithubItem.assignees.any(assignee))

        if author:
            query = query.filter(GithubItem.author == author)

        if state:
            query = query.filter(GithubItem.state == state)

        if kind:
            query = query.filter(GithubItem.kind == kind)
        else:
            # Exclude comments by default, only show issues and PRs
            query = query.filter(GithubItem.kind.in_(["issue", "pr"]))

        if labels:
            # Match any label in the list using PostgreSQL array overlap
            query = query.filter(
                GithubItem.labels.op("&&")(sql_cast(labels, ARRAY(Text)))
            )

        if project_status:
            query = query.filter(GithubItem.project_status == project_status)

        if project_field:
            for key, value in project_field.items():
                query = query.filter(
                    GithubItem.project_fields[key].astext == value
                )

        if updated_since:
            since_dt = datetime.fromisoformat(updated_since.replace("Z", "+00:00"))
            query = query.filter(GithubItem.github_updated_at >= since_dt)

        if updated_before:
            before_dt = datetime.fromisoformat(updated_before.replace("Z", "+00:00"))
            query = query.filter(GithubItem.github_updated_at <= before_dt)

        # Apply ordering
        if order_by == "created":
            query = query.order_by(desc(GithubItem.created_at))
        elif order_by == "number":
            query = query.order_by(desc(GithubItem.number))
        else:  # default: updated
            query = query.order_by(desc(GithubItem.github_updated_at))

        query = query.limit(limit)

        items = query.all()

        return [_serialize_issue(item) for item in items]


@mcp.tool()
async def search_github_issues(
    query: str,
    repo: str | None = None,
    state: str | None = None,
    kind: str | None = None,
    limit: int = 20,
) -> list[dict]:
    """
    Search GitHub issues using natural language.
    Searches across issue titles, bodies, and comments.

    Args:
        query: Natural language search query (e.g., "authentication bug", "database migration")
        repo: Optional filter by repository path
        state: Optional filter: "open", "closed", "merged"
        kind: Optional filter: "issue" or "pr"
        limit: Maximum results (default 20, max 100)

    Returns: List of matching issues with search score
    """
    logger.info(f"search_github_issues called: query={query}, repo={repo}")

    limit = min(limit, 100)

    # Pre-filter source_ids if repo/state/kind filters are specified
    source_ids = None
    if repo or state or kind:
        with make_session() as session:
            q = session.query(GithubItem.id)
            if repo:
                q = q.filter(GithubItem.repo_path == repo)
            if state:
                q = q.filter(GithubItem.state == state)
            if kind:
                q = q.filter(GithubItem.kind == kind)
            else:
                q = q.filter(GithubItem.kind.in_(["issue", "pr"]))
            source_ids = [item.id for item in q.all()]

    # Use the existing search infrastructure
    data = extract.extract_text(query, skip_summary=True)
    config = SearchConfig(limit=limit, previews=True)
    filters = SearchFilters()
    if source_ids is not None:
        filters["source_ids"] = source_ids

    results = await search(
        data,
        modalities={"github"},
        filters=filters,
        config=config,
    )

    # Fetch full issue details for the results
    output = []
    with make_session() as session:
        for result in results:
            item = session.get(GithubItem, result.id)
            if item:
                serialized = _serialize_issue(item)
                serialized["search_score"] = result.score
                output.append(serialized)

    return output


@mcp.tool()
async def github_issue_details(
    repo: str,
    number: int,
) -> dict:
    """
    Get full details of a specific GitHub issue or PR including all comments.

    Args:
        repo: Repository path (e.g., "owner/name")
        number: Issue or PR number

    Returns: Full issue details including content (body + comments), project fields, timestamps.
    For PRs, also includes pr_data with: diff (full), files changed, reviews, review comments.
    """
    logger.info(f"github_issue_details called: repo={repo}, number={number}")

    with make_session() as session:
        item = (
            session.query(GithubItem)
            .filter(
                GithubItem.repo_path == repo,
                GithubItem.number == number,
                GithubItem.kind.in_(["issue", "pr"]),
            )
            .first()
        )

        if not item:
            raise ValueError(f"Issue #{number} not found in {repo}")

        return _serialize_issue(item, include_content=True)


@mcp.tool()
async def github_work_summary(
    since: str,
    until: str | None = None,
    group_by: str = "client",
    repo: str | None = None,
) -> dict:
    """
    Summarize GitHub work activity for billing and time tracking.
    Groups issues by client, author, status, or repository.

    Args:
        since: ISO date - start of period (e.g., "2025-12-16")
        until: ISO date - end of period (default: now)
        group_by: How to group results: "client", "status", "author", "repo", "task_type"
        repo: Optional filter by repository path

    Returns: Summary with grouped counts and sample issues for each group
    """
    logger.info(f"github_work_summary called: since={since}, group_by={group_by}")

    since_dt = datetime.fromisoformat(since.replace("Z", "+00:00"))
    if until:
        until_dt = datetime.fromisoformat(until.replace("Z", "+00:00"))
    else:
        until_dt = datetime.now(timezone.utc)

    # Map group_by to SQL expression
    group_mappings = {
        "client": GithubItem.project_fields["EquiStamp.Client"].astext,
        "status": GithubItem.project_status,
        "author": GithubItem.author,
        "repo": GithubItem.repo_path,
        "task_type": GithubItem.project_fields["EquiStamp.Task Type"].astext,
    }

    if group_by not in group_mappings:
        raise ValueError(
            f"Invalid group_by: {group_by}. Must be one of: {list(group_mappings.keys())}"
        )

    group_col = group_mappings[group_by]

    with make_session() as session:
        # Build base query for the period
        base_query = session.query(GithubItem).filter(
            GithubItem.github_updated_at >= since_dt,
            GithubItem.github_updated_at <= until_dt,
            GithubItem.kind.in_(["issue", "pr"]),
        )

        if repo:
            base_query = base_query.filter(GithubItem.repo_path == repo)

        # Get aggregated counts by group
        agg_query = (
            session.query(
                group_col.label("group_name"),
                func.count(GithubItem.id).label("total"),
                func.count(case((GithubItem.kind == "issue", 1))).label("issue_count"),
                func.count(case((GithubItem.kind == "pr", 1))).label("pr_count"),
                func.count(
                    case((GithubItem.state.in_(["closed", "merged"]), 1))
                ).label("closed_count"),
            )
            .filter(
                GithubItem.github_updated_at >= since_dt,
                GithubItem.github_updated_at <= until_dt,
                GithubItem.kind.in_(["issue", "pr"]),
            )
            .group_by(group_col)
            .order_by(desc("total"))
        )

        if repo:
            agg_query = agg_query.filter(GithubItem.repo_path == repo)

        groups = agg_query.all()

        # Build summary with sample issues for each group
        summary = []
        total_issues = 0
        total_prs = 0

        for group_name, total, issue_count, pr_count, closed_count in groups:
            if group_name is None:
                group_name = "(unset)"

            total_issues += issue_count
            total_prs += pr_count

            # Get sample issues for this group
            sample_query = base_query.filter(group_col == group_name).limit(5)
            samples = [
                {
                    "number": item.number,
                    "title": item.title,
                    "repo_path": item.repo_path,
                    "state": item.state,
                    "url": _build_github_url(item.repo_path, item.number, item.kind),
                }
                for item in sample_query.all()
            ]

            summary.append(
                {
                    "group": group_name,
                    "total": total,
                    "issue_count": issue_count,
                    "pr_count": pr_count,
                    "closed_count": closed_count,
                    "issues": samples,
                }
            )

        return {
            "period": {
                "since": since_dt.isoformat(),
                "until": until_dt.isoformat(),
            },
            "group_by": group_by,
            "summary": summary,
            "total_issues": total_issues,
            "total_prs": total_prs,
        }


@mcp.tool()
async def github_repo_overview(
    repo: str,
) -> dict:
    """
    Get an overview of a GitHub repository's issues and PRs.
    Shows counts, status breakdown, top assignees, and labels.

    Args:
        repo: Repository path (e.g., "EquiStamp/equistamp" or "owner/name")

    Returns: Repository statistics including counts, status breakdown, top assignees, labels
    """
    logger.info(f"github_repo_overview called: repo={repo}")

    with make_session() as session:
        # Base query for this repo
        base_query = session.query(GithubItem).filter(
            GithubItem.repo_path == repo,
            GithubItem.kind.in_(["issue", "pr"]),
        )

        # Get total counts
        counts_query = session.query(
            func.count(GithubItem.id).label("total"),
            func.count(case((GithubItem.kind == "issue", 1))).label("total_issues"),
            func.count(
                case(((GithubItem.kind == "issue") & (GithubItem.state == "open"), 1))
            ).label("open_issues"),
            func.count(
                case(((GithubItem.kind == "issue") & (GithubItem.state == "closed"), 1))
            ).label("closed_issues"),
            func.count(case((GithubItem.kind == "pr", 1))).label("total_prs"),
            func.count(
                case(((GithubItem.kind == "pr") & (GithubItem.state == "open"), 1))
            ).label("open_prs"),
            func.count(
                case(((GithubItem.kind == "pr") & (GithubItem.merged_at.isnot(None)), 1))
            ).label("merged_prs"),
            func.max(GithubItem.github_updated_at).label("last_updated"),
        ).filter(
            GithubItem.repo_path == repo,
            GithubItem.kind.in_(["issue", "pr"]),
        )

        counts = counts_query.first()

        # Status breakdown (for project_status)
        status_query = (
            session.query(
                GithubItem.project_status.label("status"),
                func.count(GithubItem.id).label("count"),
            )
            .filter(
                GithubItem.repo_path == repo,
                GithubItem.kind.in_(["issue", "pr"]),
                GithubItem.project_status.isnot(None),
            )
            .group_by(GithubItem.project_status)
            .order_by(desc("count"))
        )

        status_breakdown = {row.status: row.count for row in status_query.all()}

        # Top assignees (open issues only)
        assignee_query = (
            session.query(
                func.unnest(GithubItem.assignees).label("assignee"),
                func.count(GithubItem.id).label("count"),
            )
            .filter(
                GithubItem.repo_path == repo,
                GithubItem.kind.in_(["issue", "pr"]),
                GithubItem.state == "open",
            )
            .group_by("assignee")
            .order_by(desc("count"))
            .limit(10)
        )

        top_assignees = [
            {"username": row.assignee, "open_count": row.count}
            for row in assignee_query.all()
        ]

        # Label counts
        label_query = (
            session.query(
                func.unnest(GithubItem.labels).label("label"),
                func.count(GithubItem.id).label("count"),
            )
            .filter(
                GithubItem.repo_path == repo,
                GithubItem.kind.in_(["issue", "pr"]),
            )
            .group_by("label")
            .order_by(desc("count"))
            .limit(20)
        )

        labels = {row.label: row.count for row in label_query.all()}

        return {
            "repo_path": repo,
            "counts": {
                "total": counts.total if counts else 0,
                "total_issues": counts.total_issues if counts else 0,
                "open_issues": counts.open_issues if counts else 0,
                "closed_issues": counts.closed_issues if counts else 0,
                "total_prs": counts.total_prs if counts else 0,
                "open_prs": counts.open_prs if counts else 0,
                "merged_prs": counts.merged_prs if counts else 0,
            },
            "status_breakdown": status_breakdown,
            "top_assignees": top_assignees,
            "labels": labels,
            "last_updated": (
                counts.last_updated.isoformat()
                if counts and counts.last_updated
                else None
            ),
        }
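A sketch of how these tools might be exercised end to end, assuming a populated database and that the @mcp.tool() decorator leaves the coroutines directly awaitable (FastMCP versions differ on this); the repo and label values are placeholders:

    import asyncio

    async def daily_triage():
        # Open bugs in one repo, most recently updated first
        issues = await list_github_issues(
            repo="owner/name", state="open", labels=["bug"], limit=10
        )
        for issue in issues:
            print(issue["number"], issue["title"], issue["url"])

    asyncio.run(daily_triage())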
src/memory/api/MCP/people.py (new file, 257 lines)
@@ -0,0 +1,257 @@
"""
MCP tools for tracking people.
"""

import logging
from typing import Any

from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY

from memory.api.MCP.base import mcp
from memory.common.db.connection import make_session
from memory.common.db.models import Person
from memory.common.celery_app import SYNC_PERSON, UPDATE_PERSON
from memory.common.celery_app import app as celery_app
from memory.common import settings

logger = logging.getLogger(__name__)


def _person_to_dict(person: Person) -> dict[str, Any]:
    """Convert a Person model to a dictionary for API responses."""
    return {
        "identifier": person.identifier,
        "display_name": person.display_name,
        "aliases": list(person.aliases or []),
        "contact_info": dict(person.contact_info or {}),
        "tags": list(person.tags or []),
        "notes": person.content,
        "created_at": person.inserted_at.isoformat() if person.inserted_at else None,
    }


@mcp.tool()
async def add_person(
    identifier: str,
    display_name: str,
    aliases: list[str] | None = None,
    contact_info: dict | None = None,
    tags: list[str] | None = None,
    notes: str | None = None,
) -> dict:
    """
    Add a new person to track.

    Args:
        identifier: Unique slug for the person (e.g., "alice_chen")
        display_name: Human-readable name (e.g., "Alice Chen")
        aliases: Alternative names/handles (e.g., ["@alice_c", "alice.chen@work.com"])
        contact_info: Contact information as a dict (e.g., {"email": "...", "phone": "..."})
        tags: Categorization tags (e.g., ["work", "friend", "climbing"])
        notes: Free-form notes about the person

    Returns:
        Task status with task_id

    Example:
        add_person(
            identifier="alice_chen",
            display_name="Alice Chen",
            aliases=["@alice_c"],
            contact_info={"email": "alice@example.com"},
            tags=["work", "engineering"],
            notes="Tech lead on Platform team"
        )
    """
    logger.info(f"MCP: Adding person: {identifier}")

    # Check if person already exists
    with make_session() as session:
        existing = session.query(Person).filter(Person.identifier == identifier).first()
        if existing:
            raise ValueError(f"Person with identifier '{identifier}' already exists")

    task = celery_app.send_task(
        SYNC_PERSON,
        queue=f"{settings.CELERY_QUEUE_PREFIX}-people",
        kwargs={
            "identifier": identifier,
            "display_name": display_name,
            "aliases": aliases,
            "contact_info": contact_info,
            "tags": tags,
            "notes": notes,
        },
    )

    return {
        "task_id": task.id,
        "status": "queued",
        "identifier": identifier,
    }


@mcp.tool()
async def update_person_info(
    identifier: str,
    display_name: str | None = None,
    aliases: list[str] | None = None,
    contact_info: dict | None = None,
    tags: list[str] | None = None,
    notes: str | None = None,
    replace_notes: bool = False,
) -> dict:
    """
    Update information about a person with merge semantics.

    This tool MERGES new information with existing data rather than replacing it:
    - display_name: Replaces existing value
    - aliases: Adds new aliases (union with existing)
    - contact_info: Deep merges (adds new keys, updates existing keys, never deletes)
    - tags: Adds new tags (union with existing)
    - notes: Appends to existing notes (or replaces if replace_notes=True)

    Args:
        identifier: The person's unique identifier
        display_name: New display name (replaces existing)
        aliases: Additional aliases to add
        contact_info: Additional contact info to merge
        tags: Additional tags to add
        notes: Notes to append (or replace if replace_notes=True)
        replace_notes: If True, replace notes instead of appending

    Returns:
        Task status with task_id

    Example:
        # Add new contact info without losing existing data
        update_person_info(
            identifier="alice_chen",
            contact_info={"phone": "555-1234"},  # Added to existing
            notes="Enjoys rock climbing"  # Appended to existing notes
        )
    """
    logger.info(f"MCP: Updating person: {identifier}")

    # Verify person exists
    with make_session() as session:
        person = session.query(Person).filter(Person.identifier == identifier).first()
        if not person:
            raise ValueError(f"Person with identifier '{identifier}' not found")

    task = celery_app.send_task(
        UPDATE_PERSON,
        queue=f"{settings.CELERY_QUEUE_PREFIX}-people",
        kwargs={
            "identifier": identifier,
            "display_name": display_name,
            "aliases": aliases,
            "contact_info": contact_info,
            "tags": tags,
            "notes": notes,
            "replace_notes": replace_notes,
        },
    )

    return {
        "task_id": task.id,
        "status": "queued",
        "identifier": identifier,
    }


@mcp.tool()
async def get_person(identifier: str) -> dict | None:
    """
    Get a person by their identifier.

    Args:
        identifier: The person's unique identifier

    Returns:
        The person record, or None if not found
    """
    logger.info(f"MCP: Getting person: {identifier}")

    with make_session() as session:
        person = session.query(Person).filter(Person.identifier == identifier).first()
        if not person:
            return None
        return _person_to_dict(person)


@mcp.tool()
async def list_people(
    tags: list[str] | None = None,
    search: str | None = None,
    limit: int = 50,
) -> list[dict]:
    """
    List all tracked people, optionally filtered by tags or search term.

    Args:
        tags: Filter to people with at least one of these tags
        search: Search term to match against name, aliases, or notes
        limit: Maximum number of results (default 50, max 200)

    Returns:
        List of person records matching the filters
    """
    logger.info(f"MCP: Listing people (tags={tags}, search={search})")

    limit = min(limit, 200)

    with make_session() as session:
        query = session.query(Person)

        if tags:
            query = query.filter(
                Person.tags.op("&&")(sql_cast(tags, ARRAY(Text)))
            )

        if search:
            search_term = f"%{search.lower()}%"
            query = query.filter(
                (Person.display_name.ilike(search_term))
                | (Person.content.ilike(search_term))
                | (Person.identifier.ilike(search_term))
            )

        query = query.order_by(Person.display_name).limit(limit)
        people = query.all()

        return [_person_to_dict(p) for p in people]


@mcp.tool()
async def delete_person(identifier: str) -> dict:
    """
    Delete a person by their identifier.

    This permanently removes the person and all associated data.
    Observations about this person (with subject "person:<identifier>") will remain.

    Args:
        identifier: The person's unique identifier

    Returns:
        Confirmation of deletion
    """
    logger.info(f"MCP: Deleting person: {identifier}")

    with make_session() as session:
        person = session.query(Person).filter(Person.identifier == identifier).first()
        if not person:
            raise ValueError(f"Person with identifier '{identifier}' not found")

        display_name = person.display_name
        session.delete(person)
        session.commit()

        return {
            "deleted": True,
            "identifier": identifier,
            "display_name": display_name,
        }
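The write tools here only enqueue Celery work; reads go straight to the database. An illustrative flow (identifiers are placeholders, and reads will not reflect a write until the people worker has processed it):

    await add_person(identifier="alice_chen", display_name="Alice Chen", tags=["work"])
    await update_person_info(identifier="alice_chen", contact_info={"phone": "555-1234"})
    person = await get_person("alice_chen")  # None until sync_person has run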
src/memory/common/celery_app.py:
@@ -16,6 +16,7 @@ SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
 DISCORD_ROOT = "memory.workers.tasks.discord"
 BACKUP_ROOT = "memory.workers.tasks.backup"
 GITHUB_ROOT = "memory.workers.tasks.github"
+PEOPLE_ROOT = "memory.workers.tasks.people"
 ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
 EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
 PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
@@ -67,6 +68,10 @@ SYNC_GITHUB_REPO = f"{GITHUB_ROOT}.sync_github_repo"
 SYNC_ALL_GITHUB_REPOS = f"{GITHUB_ROOT}.sync_all_github_repos"
 SYNC_GITHUB_ITEM = f"{GITHUB_ROOT}.sync_github_item"
+
+# People tasks
+SYNC_PERSON = f"{PEOPLE_ROOT}.sync_person"
+UPDATE_PERSON = f"{PEOPLE_ROOT}.update_person"


 def get_broker_url() -> str:
     protocol = settings.CELERY_BROKER_TYPE
@@ -123,6 +128,7 @@ app.conf.update(
         },
         f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"},
         f"{GITHUB_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-github"},
+        f"{PEOPLE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-people"},
     },
     beat_schedule={
         "sync-github-repos-hourly": {
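Routing resolves purely by task-name prefix. As a sketch, with CELERY_QUEUE_PREFIX hypothetically set to "memory":

    PEOPLE_ROOT = "memory.workers.tasks.people"
    route = {f"{PEOPLE_ROOT}.*": {"queue": "memory-people"}}
    # memory.workers.tasks.people.sync_person -> queue "memory-people",
    # matching the "people" entry added to QUEUES in the compose file and Dockerfile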
src/memory/common/db/models/__init__.py:
@@ -17,6 +17,7 @@ from memory.common.db.models.source_items import (
     BookSection,
     ForumPost,
     GithubItem,
+    GithubPRData,
     GitCommit,
     Photo,
     MiscDoc,
@@ -46,6 +47,10 @@ from memory.common.db.models.observations import (
     BeliefCluster,
     ConversationMetrics,
 )
+from memory.common.db.models.people import (
+    Person,
+    PersonPayload,
+)
 from memory.common.db.models.sources import (
     Book,
     ArticleFeed,
@@ -77,6 +82,7 @@ Payload = (
     | ForumPostPayload
     | EmailAttachmentPayload
     | MailMessagePayload
+    | PersonPayload
 )

 __all__ = [
@@ -95,6 +101,7 @@ __all__ = [
     "BookSection",
     "ForumPost",
     "GithubItem",
+    "GithubPRData",
     "GitCommit",
     "Photo",
     "MiscDoc",
@@ -105,6 +112,9 @@ __all__ = [
     "ObservationPattern",
     "BeliefCluster",
     "ConversationMetrics",
+    # People
+    "Person",
+    "PersonPayload",
     # Sources
     "Book",
     "ArticleFeed",
src/memory/common/db/models/people.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""
Database models for tracking people.
"""

from typing import Annotated, Sequence, cast

from sqlalchemy import (
    ARRAY,
    BigInteger,
    Column,
    ForeignKey,
    Index,
    Text,
)
from sqlalchemy.dialects.postgresql import JSONB

import memory.common.extract as extract

from memory.common.db.models.source_item import (
    SourceItem,
    SourceItemPayload,
)


class PersonPayload(SourceItemPayload):
    identifier: Annotated[str, "Unique identifier/slug for the person"]
    display_name: Annotated[str, "Display name of the person"]
    aliases: Annotated[list[str], "Alternative names/handles for the person"]
    contact_info: Annotated[dict, "Contact information (email, phone, etc.)"]


class Person(SourceItem):
    """A person you know or want to track."""

    __tablename__ = "people"

    id = Column(
        BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
    )
    identifier = Column(Text, unique=True, nullable=False, index=True)
    display_name = Column(Text, nullable=False)
    aliases = Column(ARRAY(Text), server_default="{}", nullable=False)
    contact_info = Column(JSONB, server_default="{}", nullable=False)

    __mapper_args__ = {
        "polymorphic_identity": "person",
    }

    __table_args__ = (
        Index("person_identifier_idx", "identifier"),
        Index("person_display_name_idx", "display_name"),
        Index("person_aliases_idx", "aliases", postgresql_using="gin"),
    )

    def as_payload(self) -> PersonPayload:
        return PersonPayload(
            **super().as_payload(),
            identifier=cast(str, self.identifier),
            display_name=cast(str, self.display_name),
            aliases=cast(list[str], self.aliases) or [],
            contact_info=cast(dict, self.contact_info) or {},
        )

    @property
    def display_contents(self) -> dict:
        return {
            "identifier": self.identifier,
            "display_name": self.display_name,
            "aliases": self.aliases,
            "contact_info": self.contact_info,
            "notes": self.content,
            "tags": self.tags,
        }

    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
        """Create searchable chunks from person data."""
        parts = [f"# {self.display_name}"]

        if self.aliases:
            aliases_str = ", ".join(cast(list[str], self.aliases))
            parts.append(f"Also known as: {aliases_str}")

        if self.tags:
            tags_str = ", ".join(cast(list[str], self.tags))
            parts.append(f"Tags: {tags_str}")

        if self.content:
            parts.append(f"\n{self.content}")

        text = "\n".join(parts)
        return extract.extract_text(text, modality="person")

    @classmethod
    def get_collections(cls) -> list[str]:
        return ["person"]
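For a concrete person, _chunk_contents assembles a small markdown-ish document before handing it to the extractor; roughly (values illustrative):

    parts = [
        "# Alice Chen",
        "Also known as: @alice_c, alice.chen@work.com",
        "Tags: work, climbing",
        "\nTech lead on Platform team",
    ]
    text = "\n".join(parts)
    # -> extract.extract_text(text, modality="person") chunks and embeds this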
src/memory/common/db/models/source_items.py:
@@ -9,15 +9,19 @@ from collections.abc import Collection
 from typing import Any, Annotated, Sequence, cast

 from PIL import Image
+import zlib
+
 from sqlalchemy import (
     ARRAY,
     BigInteger,
+    Boolean,
     CheckConstraint,
     Column,
     DateTime,
     ForeignKey,
     Index,
     Integer,
+    LargeBinary,
     Numeric,
     Text,
     func,
@@ -30,6 +34,7 @@ import memory.common.extract as extract
 import memory.common.summarizer as summarizer
 import memory.common.formatters.observation as observation

+from memory.common.db.models.base import Base
 from memory.common.db.models.source_item import (
     SourceItem,
     SourceItemPayload,
@@ -858,6 +863,14 @@ class GithubItem(SourceItem):
     milestone = Column(Text, nullable=True)
     comment_count = Column(Integer, nullable=True)
+
+    # Relationship to PR-specific data
+    pr_data = relationship(
+        "GithubPRData",
+        back_populates="github_item",
+        uselist=False,
+        cascade="all, delete-orphan",
+    )

     __mapper_args__ = {
         "polymorphic_identity": "github_item",
     }
@@ -902,6 +915,59 @@ class GithubItem(SourceItem):
         return []


+class GithubPRData(Base):
+    """PR-specific data linked to GithubItem. Not a SourceItem - not indexed separately."""
+
+    __tablename__ = "github_pr_data"
+
+    id = Column(BigInteger, primary_key=True)
+    github_item_id = Column(
+        BigInteger,
+        ForeignKey("github_item.id", ondelete="CASCADE"),
+        unique=True,
+        nullable=False,
+        index=True,
+    )
+
+    # Diff (compressed with zlib)
+    diff_compressed = Column(LargeBinary, nullable=True)
+
+    # File changes as structured data
+    # [{filename, status, additions, deletions, patch?}]
+    files = Column(JSONB, nullable=True)
+
+    # Stats
+    additions = Column(Integer, nullable=True)
+    deletions = Column(Integer, nullable=True)
+    changed_files_count = Column(Integer, nullable=True)
+
+    # Reviews (structured)
+    # [{user, state, body, submitted_at}]
+    reviews = Column(JSONB, nullable=True)
+
+    # Review comments (line-by-line code comments)
+    # [{user, body, path, line, diff_hunk, created_at}]
+    review_comments = Column(JSONB, nullable=True)
+
+    # Relationship back to GithubItem
+    github_item = relationship("GithubItem", back_populates="pr_data")
+
+    @property
+    def diff(self) -> str | None:
+        """Decompress and return the full diff text."""
+        if self.diff_compressed:
+            return zlib.decompress(self.diff_compressed).decode("utf-8")
+        return None
+
+    @diff.setter
+    def diff(self, value: str | None) -> None:
+        """Compress and store the diff text."""
+        if value:
+            self.diff_compressed = zlib.compress(value.encode("utf-8"))
+        else:
+            self.diff_compressed = None
+
+
 class NotePayload(SourceItemPayload):
     note_type: Annotated[str | None, "Category of the note"]
     subject: Annotated[str | None, "What the note is about"]
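The diff property hides a plain zlib round trip: the column stores bytes, callers see text. A self-contained sketch of the same transformation:

    import zlib

    raw = "diff --git a/app.py b/app.py\n+print('hi')\n"
    blob = zlib.compress(raw.encode("utf-8"))      # what diff_compressed stores
    assert zlib.decompress(blob).decode("utf-8") == raw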
src/memory/parsers/github.py:
@@ -42,6 +42,51 @@ class GithubComment(TypedDict):
     updated_at: str


+class GithubReviewComment(TypedDict):
+    """A line-by-line code review comment on a PR."""
+
+    id: int
+    user: str
+    body: str
+    path: str
+    line: int | None
+    side: str  # "LEFT" or "RIGHT"
+    diff_hunk: str
+    created_at: str
+
+
+class GithubReview(TypedDict):
+    """A PR review (approval, request changes, etc.)."""
+
+    id: int
+    user: str
+    state: str  # "approved", "changes_requested", "commented", "dismissed"
+    body: str | None
+    submitted_at: str
+
+
+class GithubFileChange(TypedDict):
+    """A file changed in a PR."""
+
+    filename: str
+    status: str  # "added", "modified", "removed", "renamed"
+    additions: int
+    deletions: int
+    patch: str | None  # Diff patch for this file
+
+
+class GithubPRDataDict(TypedDict):
+    """PR-specific data for storage in GithubPRData model."""
+
+    diff: str | None  # Full diff text
+    files: list[GithubFileChange]
+    additions: int
+    deletions: int
+    changed_files_count: int
+    reviews: list[GithubReview]
+    review_comments: list[GithubReviewComment]
+
+
 class GithubIssueData(TypedDict):
     """Parsed issue/PR data ready for storage."""

@@ -60,9 +105,11 @@ class GithubIssueData(TypedDict):
     github_updated_at: datetime
     comment_count: int
     comments: list[GithubComment]
-    diff_summary: str | None  # PRs only
+    diff_summary: str | None  # PRs only (truncated, for backward compat)
     project_fields: dict[str, Any] | None
     content_hash: str
+    # PR-specific extended data (None for issues)
+    pr_data: GithubPRDataDict | None


 def parse_github_date(date_str: str | None) -> datetime | None:
@@ -267,6 +314,145 @@ class GithubClient:

         return comments

+    def fetch_review_comments(
+        self,
+        owner: str,
+        repo: str,
+        pr_number: int,
+    ) -> list[GithubReviewComment]:
+        """Fetch all line-by-line review comments for a PR."""
+        comments: list[GithubReviewComment] = []
+        page = 1
+
+        while True:
+            response = self.session.get(
+                f"{GITHUB_API_URL}/repos/{owner}/{repo}/pulls/{pr_number}/comments",
+                params={"page": page, "per_page": 100},
+                timeout=30,
+            )
+            response.raise_for_status()
+            self._handle_rate_limit(response)
+
+            page_comments = response.json()
+            if not page_comments:
+                break
+
+            comments.extend(
+                [
+                    GithubReviewComment(
+                        id=c["id"],
+                        user=c["user"]["login"] if c.get("user") else "ghost",
+                        body=c.get("body", ""),
+                        path=c.get("path", ""),
+                        line=c.get("line"),
+                        side=c.get("side", "RIGHT"),
+                        diff_hunk=c.get("diff_hunk", ""),
+                        created_at=c["created_at"],
+                    )
+                    for c in page_comments
+                ]
+            )
+            page += 1
+
+        return comments
+
+    def fetch_reviews(
+        self,
+        owner: str,
+        repo: str,
+        pr_number: int,
+    ) -> list[GithubReview]:
+        """Fetch all reviews (approvals, change requests) for a PR."""
+        reviews: list[GithubReview] = []
+        page = 1
+
+        while True:
+            response = self.session.get(
+                f"{GITHUB_API_URL}/repos/{owner}/{repo}/pulls/{pr_number}/reviews",
+                params={"page": page, "per_page": 100},
+                timeout=30,
+            )
+            response.raise_for_status()
+            self._handle_rate_limit(response)
+
+            page_reviews = response.json()
+            if not page_reviews:
+                break
+
+            reviews.extend(
+                [
+                    GithubReview(
+                        id=r["id"],
+                        user=r["user"]["login"] if r.get("user") else "ghost",
+                        state=r.get("state", "COMMENTED").lower(),
+                        body=r.get("body"),
+                        submitted_at=r.get("submitted_at", ""),
+                    )
+                    for r in page_reviews
+                ]
+            )
+            page += 1
+
+        return reviews
+
+    def fetch_pr_files(
+        self,
+        owner: str,
+        repo: str,
+        pr_number: int,
+    ) -> list[GithubFileChange]:
+        """Fetch list of files changed in a PR with patches."""
+        files: list[GithubFileChange] = []
+        page = 1
+
+        while True:
+            response = self.session.get(
+                f"{GITHUB_API_URL}/repos/{owner}/{repo}/pulls/{pr_number}/files",
+                params={"page": page, "per_page": 100},
+                timeout=30,
+            )
+            response.raise_for_status()
+            self._handle_rate_limit(response)
+
+            page_files = response.json()
+            if not page_files:
+                break
+
+            files.extend(
+                [
+                    GithubFileChange(
+                        filename=f["filename"],
+                        status=f.get("status", "modified"),
+                        additions=f.get("additions", 0),
+                        deletions=f.get("deletions", 0),
+                        patch=f.get("patch"),  # May be None for binary files
+                    )
+                    for f in page_files
+                ]
+            )
+            page += 1
+
+        return files
+
+    def fetch_pr_diff(
+        self,
+        owner: str,
+        repo: str,
+        pr_number: int,
+    ) -> str | None:
+        """Fetch the full diff for a PR (not truncated)."""
+        try:
+            response = self.session.get(
+                f"{GITHUB_API_URL}/repos/{owner}/{repo}/pulls/{pr_number}",
+                headers={"Accept": "application/vnd.github.diff"},
+                timeout=60,  # Longer timeout for large diffs
+            )
+            if response.ok:
+                return response.text
+        except Exception as e:
+            logger.warning(f"Failed to fetch PR diff: {e}")
+        return None
+
     def fetch_project_fields(
         self,
         owner: str,
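All four fetchers above repeat the same page/per_page loop; one way they could be factored, shown as a sketch rather than as part of the commit:

    def _paginate(self, url: str) -> list[dict]:
        """Collect every item from a paginated GitHub list endpoint."""
        items: list[dict] = []
        page = 1
        while True:
            response = self.session.get(
                url, params={"page": page, "per_page": 100}, timeout=30
            )
            response.raise_for_status()
            self._handle_rate_limit(response)
            batch = response.json()
            if not batch:
                break
            items.extend(batch)
            page += 1
        return items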
@@ -490,28 +676,44 @@
             diff_summary=None,
             project_fields=None,  # Fetched separately if enabled
             content_hash=compute_content_hash(body, comments),
+            pr_data=None,  # Issues don't have PR data
         )

     def _parse_pr(
         self, owner: str, repo: str, pr: dict[str, Any]
     ) -> GithubIssueData:
         """Parse raw PR data into structured format."""
-        comments = self.fetch_comments(owner, repo, pr["number"])
+        pr_number = pr["number"]
+        comments = self.fetch_comments(owner, repo, pr_number)
         body = pr.get("body") or ""

-        # Get diff summary (truncated)
-        diff_summary = None
-        if diff_url := pr.get("diff_url"):
-            try:
-                diff_response = self.session.get(diff_url, timeout=30)
-                if diff_response.ok:
-                    diff_summary = diff_response.text[:5000]  # Truncate large diffs
-            except Exception as e:
-                logger.warning(f"Failed to fetch diff: {e}")
+        # Fetch PR-specific data
+        review_comments = self.fetch_review_comments(owner, repo, pr_number)
+        reviews = self.fetch_reviews(owner, repo, pr_number)
+        files = self.fetch_pr_files(owner, repo, pr_number)
+        full_diff = self.fetch_pr_diff(owner, repo, pr_number)
+
+        # Calculate stats from files
+        additions = sum(f["additions"] for f in files)
+        deletions = sum(f["deletions"] for f in files)
+
+        # Get diff summary (truncated, for backward compatibility)
+        diff_summary = full_diff[:5000] if full_diff else None
+
+        # Build PR data dict
+        pr_data = GithubPRDataDict(
+            diff=full_diff,
+            files=files,
+            additions=additions,
+            deletions=deletions,
+            changed_files_count=len(files),
+            reviews=reviews,
+            review_comments=review_comments,
+        )

         return GithubIssueData(
             kind="pr",
-            number=pr["number"],
+            number=pr_number,
             title=pr["title"],
             body=body,
             state=pr["state"],
@@ -528,4 +730,5 @@ class GithubClient:
             diff_summary=diff_summary,
             project_fields=None,  # Fetched separately if enabled
             content_hash=compute_content_hash(body, comments),
+            pr_data=pr_data,
         )
@@ -12,12 +12,13 @@ from memory.common.celery_app import (
     SYNC_GITHUB_ITEM,
 )
 from memory.common.db.connection import make_session
-from memory.common.db.models import GithubItem
+from memory.common.db.models import GithubItem, GithubPRData
 from memory.common.db.models.sources import GithubAccount, GithubRepo
 from memory.parsers.github import (
     GithubClient,
     GithubCredentials,
     GithubIssueData,
+    GithubPRDataDict,
 )
 from memory.workers.tasks.content_processing import (
     create_content_hash,
@@ -32,11 +33,42 @@ logger = logging.getLogger(__name__)
 def _build_content(issue_data: GithubIssueData) -> str:
     """Build searchable content from issue/PR data."""
     content_parts = [f"# {issue_data['title']}", issue_data["body"]]
+
+    # Add regular comments
     for comment in issue_data["comments"]:
         content_parts.append(f"\n---\n**{comment['author']}**: {comment['body']}")
+
+    # Add review comments for PRs (makes them searchable)
+    pr_data = issue_data.get("pr_data")
+    if pr_data and pr_data.get("review_comments"):
+        content_parts.append("\n---\n## Code Review Comments\n")
+        for rc in pr_data["review_comments"]:
+            content_parts.append(
+                f"**{rc['user']}** on `{rc['path']}`: {rc['body']}"
+            )
+
     return "\n\n".join(content_parts)


+def _create_pr_data(issue_data: GithubIssueData) -> GithubPRData | None:
+    """Create GithubPRData from PR-specific data if available."""
+    pr_data_dict = issue_data.get("pr_data")
+    if not pr_data_dict:
+        return None
+
+    pr_data = GithubPRData(
+        additions=pr_data_dict.get("additions"),
+        deletions=pr_data_dict.get("deletions"),
+        changed_files_count=pr_data_dict.get("changed_files_count"),
+        files=pr_data_dict.get("files"),
+        reviews=pr_data_dict.get("reviews"),
+        review_comments=pr_data_dict.get("review_comments"),
+    )
+    # Use the setter to compress the diff
+    pr_data.diff = pr_data_dict.get("diff")
+    return pr_data
+
+
 def _create_github_item(
     repo: GithubRepo,
     issue_data: GithubIssueData,
@@ -57,7 +89,7 @@ def _create_github_item(

     repo_tags = cast(list[str], repo.tags) or []

-    return GithubItem(
+    github_item = GithubItem(
         modality="github",
         sha256=create_content_hash(content),
         content=content,
@@ -86,6 +118,12 @@ def _create_github_item(
         mime_type="text/markdown",
     )

+    # Create PR data if this is a PR
+    if issue_data["kind"] == "pr":
+        github_item.pr_data = _create_pr_data(issue_data)
+
+    return github_item
+

 def _needs_reindex(existing: GithubItem, new_data: GithubIssueData) -> bool:
     """Check if an existing item needs reindexing based on content changes."""
@@ -165,6 +203,23 @@ def _update_existing_item(
     repo_tags = cast(list[str], repo.tags) or []
     existing.tags = repo_tags + issue_data["labels"]  # type: ignore

+    # Update PR data if this is a PR
+    if issue_data["kind"] == "pr":
+        pr_data_dict = issue_data.get("pr_data")
+        if pr_data_dict:
+            if existing.pr_data:
+                # Update existing pr_data
+                existing.pr_data.additions = pr_data_dict.get("additions")
+                existing.pr_data.deletions = pr_data_dict.get("deletions")
+                existing.pr_data.changed_files_count = pr_data_dict.get("changed_files_count")
+                existing.pr_data.files = pr_data_dict.get("files")
+                existing.pr_data.reviews = pr_data_dict.get("reviews")
+                existing.pr_data.review_comments = pr_data_dict.get("review_comments")
+                existing.pr_data.diff = pr_data_dict.get("diff")
+            else:
+                # Create new pr_data
+                existing.pr_data = _create_pr_data(issue_data)
+
     session.flush()

     # Re-embed and push to Qdrant
@@ -193,6 +248,8 @@ def _serialize_issue_data(data: GithubIssueData) -> dict[str, Any]:
             }
             for c in data["comments"]
         ],
+        # pr_data is already JSON-serializable (TypedDict)
+        "pr_data": data.get("pr_data"),
     }


@@ -200,6 +257,11 @@ def _deserialize_issue_data(data: dict[str, Any]) -> GithubIssueData:
     """Deserialize issue data from Celery task."""
     from memory.parsers.github import parse_github_date

+    # Reconstruct pr_data if present
+    pr_data: GithubPRDataDict | None = None
+    if data.get("pr_data"):
+        pr_data = cast(GithubPRDataDict, data["pr_data"])
+
     return GithubIssueData(
         kind=data["kind"],
         number=data["number"],
@@ -219,6 +281,7 @@ def _deserialize_issue_data(data: dict[str, Any]) -> GithubIssueData:
         diff_summary=data.get("diff_summary"),
         project_fields=data.get("project_fields"),
         content_hash=data["content_hash"],
+        pr_data=pr_data,
     )
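One detail worth calling out in `_serialize_issue_data` above: `pr_data` can be passed through untouched because `GithubPRDataDict` is a TypedDict, i.e. a plain dict at runtime, so it survives Celery's JSON round trip with no custom encoder. A tiny self-contained sketch with made-up numbers:

import json

# Made-up payload standing in for a serialized GithubIssueData.
payload = {"pr_data": {"additions": 10, "deletions": 2, "changed_files_count": 1}}

restored = json.loads(json.dumps(payload))  # Celery's JSON round trip, in miniature
assert restored["pr_data"]["additions"] == 10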
145 src/memory/workers/tasks/people.py Normal file
@@ -0,0 +1,145 @@
+"""
+Celery tasks for tracking people.
+"""
+
+import logging
+
+from memory.common.db.connection import make_session
+from memory.common.db.models import Person
+from memory.common.celery_app import app, SYNC_PERSON, UPDATE_PERSON
+from memory.workers.tasks.content_processing import (
+    check_content_exists,
+    create_content_hash,
+    create_task_result,
+    process_content_item,
+    safe_task_execution,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _deep_merge(base: dict, updates: dict) -> dict:
+    """Deep merge two dictionaries, with updates taking precedence."""
+    result = dict(base)
+    for key, value in updates.items():
+        if key in result and isinstance(result[key], dict) and isinstance(value, dict):
+            result[key] = _deep_merge(result[key], value)
+        else:
+            result[key] = value
+    return result
+
+
+@app.task(name=SYNC_PERSON)
+@safe_task_execution
+def sync_person(
+    identifier: str,
+    display_name: str,
+    aliases: list[str] | None = None,
+    contact_info: dict | None = None,
+    tags: list[str] | None = None,
+    notes: str | None = None,
+):
+    """
+    Create or update a person in the knowledge base.
+
+    Args:
+        identifier: Unique slug for the person
+        display_name: Human-readable name
+        aliases: Alternative names/handles
+        contact_info: Contact information dict
+        tags: Categorization tags
+        notes: Free-form notes about the person
+    """
+    logger.info(f"Syncing person: {identifier}")
+
+    # Create hash from identifier for deduplication
+    sha256 = create_content_hash(f"person:{identifier}")
+
+    with make_session() as session:
+        # Check if person already exists by identifier
+        existing = session.query(Person).filter(Person.identifier == identifier).first()
+
+        if existing:
+            logger.info(f"Person already exists: {identifier}")
+            return create_task_result(existing, "already_exists")
+
+        # Also check by sha256 (defensive)
+        existing_by_hash = check_content_exists(session, Person, sha256=sha256)
+        if existing_by_hash:
+            logger.info(f"Person already exists (by hash): {identifier}")
+            return create_task_result(existing_by_hash, "already_exists")
+
+        person = Person(
+            identifier=identifier,
+            display_name=display_name,
+            aliases=aliases or [],
+            contact_info=contact_info or {},
+            tags=tags or [],
+            content=notes,
+            modality="person",
+            mime_type="text/plain",
+            sha256=sha256,
+            size=len(notes or ""),
+        )
+
+        return process_content_item(person, session)
+
+
+@app.task(name=UPDATE_PERSON)
+@safe_task_execution
+def update_person(
+    identifier: str,
+    display_name: str | None = None,
+    aliases: list[str] | None = None,
+    contact_info: dict | None = None,
+    tags: list[str] | None = None,
+    notes: str | None = None,
+    replace_notes: bool = False,
+):
+    """
+    Update a person with merge semantics.
+
+    Merge behavior:
+    - display_name: Replaces if provided
+    - aliases: Union with existing
+    - contact_info: Deep merge with existing
+    - tags: Union with existing
+    - notes: Append to existing (or replace if replace_notes=True)
+    """
+    logger.info(f"Updating person: {identifier}")
+
+    with make_session() as session:
+        person = session.query(Person).filter(Person.identifier == identifier).first()
+        if not person:
+            logger.warning(f"Person not found: {identifier}")
+            return {"status": "not_found", "identifier": identifier}
+
+        if display_name is not None:
+            person.display_name = display_name
+
+        if aliases is not None:
+            existing_aliases = set(person.aliases or [])
+            new_aliases = existing_aliases | set(aliases)
+            person.aliases = list(new_aliases)
+
+        if contact_info is not None:
+            existing_contact = dict(person.contact_info or {})
+            person.contact_info = _deep_merge(existing_contact, contact_info)
+
+        if tags is not None:
+            existing_tags = set(person.tags or [])
+            new_tags = existing_tags | set(tags)
+            person.tags = list(new_tags)
+
+        if notes is not None:
+            if replace_notes or not person.content:
+                person.content = notes
+            else:
+                person.content = f"{person.content}\n\n---\n\n{notes}"
+
+        # Update hash based on new content
+        person.sha256 = create_content_hash(f"person:{identifier}")
+        person.size = len(person.content or "")
+        person.embed_status = "RAW"  # Re-embed with updated content
+
+        return process_content_item(person, session)
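To make `update_person`'s merge semantics concrete, here is a self-contained sketch of the `contact_info` deep merge; `_deep_merge` is copied verbatim from the task module above, and the sample dicts are invented:

def _deep_merge(base: dict, updates: dict) -> dict:
    """Deep merge two dictionaries, with updates taking precedence."""
    result = dict(base)
    for key, value in updates.items():
        if key in result and isinstance(result[key], dict) and isinstance(value, dict):
            result[key] = _deep_merge(result[key], value)
        else:
            result[key] = value
    return result

# Invented sample data: nested dicts merge recursively, scalar updates win.
existing = {"email": "a@example.com", "social": {"twitter": "@old"}}
updates = {"social": {"mastodon": "@new"}, "email": "b@example.com"}

assert _deep_merge(existing, updates) == {
    "email": "b@example.com",
    "social": {"twitter": "@old", "mastodon": "@new"},
}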
0 tests/memory/api/MCP/__init__.py Normal file
1064 tests/memory/api/MCP/test_github.py Normal file (diff suppressed because it is too large)
482 tests/memory/api/MCP/test_people.py Normal file
@@ -0,0 +1,482 @@
+"""Tests for People MCP tools."""
+
+import sys
+import pytest
+from unittest.mock import AsyncMock, Mock, MagicMock, patch
+
+# Mock the mcp module and all its submodules before importing anything that uses it
+_mock_mcp = MagicMock()
+_mock_mcp.tool = lambda: lambda f: f  # Make @mcp.tool() a no-op decorator
+sys.modules["mcp"] = _mock_mcp
+sys.modules["mcp.server"] = MagicMock()
+sys.modules["mcp.server.auth"] = MagicMock()
+sys.modules["mcp.server.auth.handlers"] = MagicMock()
+sys.modules["mcp.server.auth.handlers.authorize"] = MagicMock()
+sys.modules["mcp.server.auth.handlers.token"] = MagicMock()
+sys.modules["mcp.server.auth.provider"] = MagicMock()
+sys.modules["mcp.server.fastmcp"] = MagicMock()
+sys.modules["mcp.server.fastmcp.server"] = MagicMock()
+
+# Also mock the memory.api.MCP.base module to avoid MCP imports
+_mock_base = MagicMock()
+_mock_base.mcp = MagicMock()
+_mock_base.mcp.tool = lambda: lambda f: f  # Make @mcp.tool() a no-op decorator
+sys.modules["memory.api.MCP.base"] = _mock_base
+
+from memory.common.db.models import Person
+from memory.common.db import connection as db_connection
+from memory.workers.tasks.content_processing import create_content_hash
+
+
+@pytest.fixture(autouse=True)
+def reset_db_cache():
+    """Reset the cached database engine between tests."""
+    db_connection._engine = None
+    db_connection._session_factory = None
+    db_connection._scoped_session = None
+    yield
+    db_connection._engine = None
+    db_connection._session_factory = None
+    db_connection._scoped_session = None
+
+
+@pytest.fixture
+def sample_people(db_session):
+    """Create sample people for testing."""
+    people = [
+        Person(
+            identifier="alice_chen",
+            display_name="Alice Chen",
+            aliases=["@alice_c", "alice.chen@work.com"],
+            contact_info={"email": "alice@example.com", "phone": "555-1234"},
+            tags=["work", "engineering"],
+            content="Tech lead on Platform team. Very thorough in code reviews.",
+            modality="person",
+            sha256=create_content_hash("person:alice_chen"),
+            size=100,
+        ),
+        Person(
+            identifier="bob_smith",
+            display_name="Bob Smith",
+            aliases=["@bobsmith"],
+            contact_info={"email": "bob@example.com"},
+            tags=["work", "design"],
+            content="UX designer. Prefers visual communication.",
+            modality="person",
+            sha256=create_content_hash("person:bob_smith"),
+            size=50,
+        ),
+        Person(
+            identifier="charlie_jones",
+            display_name="Charlie Jones",
+            aliases=[],
+            contact_info={"twitter": "@charlie_j"},
+            tags=["friend", "climbing"],
+            content="Met at climbing gym. Very reliable.",
+            modality="person",
+            sha256=create_content_hash("person:charlie_jones"),
+            size=30,
+        ),
+    ]
+
+    for person in people:
+        db_session.add(person)
+    db_session.commit()
+
+    for person in people:
+        db_session.refresh(person)
+
+    return people
+
+
+# =============================================================================
+# Tests for add_person
+# =============================================================================
+
+
+@pytest.mark.asyncio
+async def test_add_person_success(db_session):
+    """Test adding a new person."""
+    from memory.api.MCP.people import add_person
+
+    mock_task = Mock()
+    mock_task.id = "task-123"
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        with patch("memory.api.MCP.people.celery_app") as mock_celery:
+            mock_celery.send_task.return_value = mock_task
+            result = await add_person(
+                identifier="new_person",
+                display_name="New Person",
+                aliases=["@newperson"],
+                contact_info={"email": "new@example.com"},
+                tags=["friend"],
+                notes="A new friend.",
+            )
+
+    assert result["status"] == "queued"
+    assert result["task_id"] == "task-123"
+    assert result["identifier"] == "new_person"
+
+    # Verify Celery task was called
+    mock_celery.send_task.assert_called_once()
+    call_kwargs = mock_celery.send_task.call_args[1]
+    assert call_kwargs["kwargs"]["identifier"] == "new_person"
+    assert call_kwargs["kwargs"]["display_name"] == "New Person"
+
+
+@pytest.mark.asyncio
+async def test_add_person_already_exists(db_session, sample_people):
+    """Test adding a person that already exists."""
+    from memory.api.MCP.people import add_person
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        with pytest.raises(ValueError, match="already exists"):
+            await add_person(
+                identifier="alice_chen",  # Already exists
+                display_name="Alice Chen Duplicate",
+            )
+
+
+@pytest.mark.asyncio
+async def test_add_person_minimal(db_session):
+    """Test adding a person with minimal data."""
+    from memory.api.MCP.people import add_person
+
+    mock_task = Mock()
+    mock_task.id = "task-456"
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        with patch("memory.api.MCP.people.celery_app") as mock_celery:
+            mock_celery.send_task.return_value = mock_task
+            result = await add_person(
+                identifier="minimal_person",
+                display_name="Minimal Person",
+            )
+
+    assert result["status"] == "queued"
+    assert result["identifier"] == "minimal_person"
+
+
+# =============================================================================
+# Tests for update_person_info
+# =============================================================================
+
+
+@pytest.mark.asyncio
+async def test_update_person_info_success(db_session, sample_people):
+    """Test updating a person's info."""
+    from memory.api.MCP.people import update_person_info
+
+    mock_task = Mock()
+    mock_task.id = "task-789"
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        with patch("memory.api.MCP.people.celery_app") as mock_celery:
+            mock_celery.send_task.return_value = mock_task
+            result = await update_person_info(
+                identifier="alice_chen",
+                display_name="Alice M. Chen",
+                notes="Added middle initial",
+            )
+
+    assert result["status"] == "queued"
+    assert result["task_id"] == "task-789"
+    assert result["identifier"] == "alice_chen"
+
+
+@pytest.mark.asyncio
+async def test_update_person_info_not_found(db_session, sample_people):
+    """Test updating a person that doesn't exist."""
+    from memory.api.MCP.people import update_person_info
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        with pytest.raises(ValueError, match="not found"):
+            await update_person_info(
+                identifier="nonexistent_person",
+                display_name="New Name",
+            )
+
+
+@pytest.mark.asyncio
+async def test_update_person_info_with_merge_params(db_session, sample_people):
+    """Test that update passes all merge parameters."""
+    from memory.api.MCP.people import update_person_info
+
+    mock_task = Mock()
+    mock_task.id = "task-merge"
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        with patch("memory.api.MCP.people.celery_app") as mock_celery:
+            mock_celery.send_task.return_value = mock_task
+            result = await update_person_info(
+                identifier="alice_chen",
+                aliases=["@alice_new"],
+                contact_info={"slack": "@alice"},
+                tags=["leadership"],
+                notes="Promoted to senior",
+                replace_notes=False,
+            )
+
+    call_kwargs = mock_celery.send_task.call_args[1]["kwargs"]
+    assert call_kwargs["aliases"] == ["@alice_new"]
+    assert call_kwargs["contact_info"] == {"slack": "@alice"}
+    assert call_kwargs["tags"] == ["leadership"]
+    assert call_kwargs["notes"] == "Promoted to senior"
+    assert call_kwargs["replace_notes"] is False
+
+
+# =============================================================================
+# Tests for get_person
+# =============================================================================
+
+
+@pytest.mark.asyncio
+async def test_get_person_found(db_session, sample_people):
+    """Test getting a person that exists."""
+    from memory.api.MCP.people import get_person
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        result = await get_person(identifier="alice_chen")
+
+    assert result is not None
+    assert result["identifier"] == "alice_chen"
+    assert result["display_name"] == "Alice Chen"
+    assert result["aliases"] == ["@alice_c", "alice.chen@work.com"]
+    assert result["contact_info"] == {"email": "alice@example.com", "phone": "555-1234"}
+    assert result["tags"] == ["work", "engineering"]
+    assert "Tech lead" in result["notes"]
+
+
+@pytest.mark.asyncio
+async def test_get_person_not_found(db_session, sample_people):
+    """Test getting a person that doesn't exist."""
+    from memory.api.MCP.people import get_person
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        result = await get_person(identifier="nonexistent_person")
+
+    assert result is None
+
+
+# =============================================================================
+# Tests for list_people
+# =============================================================================
+
+
+@pytest.mark.asyncio
+async def test_list_people_no_filters(db_session, sample_people):
+    """Test listing all people without filters."""
+    from memory.api.MCP.people import list_people
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        results = await list_people()
+
+    assert len(results) == 3
+    # Should be ordered by display_name
+    assert results[0]["display_name"] == "Alice Chen"
+    assert results[1]["display_name"] == "Bob Smith"
+    assert results[2]["display_name"] == "Charlie Jones"
+
+
+@pytest.mark.asyncio
+async def test_list_people_filter_by_tags(db_session, sample_people):
+    """Test filtering by tags."""
+    from memory.api.MCP.people import list_people
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        results = await list_people(tags=["work"])
+
+    assert len(results) == 2
+    assert all("work" in r["tags"] for r in results)
+
+
+@pytest.mark.asyncio
+async def test_list_people_filter_by_search(db_session, sample_people):
+    """Test filtering by search term."""
+    from memory.api.MCP.people import list_people
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        results = await list_people(search="alice")
+
+    assert len(results) == 1
+    assert results[0]["identifier"] == "alice_chen"
+
+
+@pytest.mark.asyncio
+async def test_list_people_search_in_notes(db_session, sample_people):
+    """Test that search works on notes content."""
+    from memory.api.MCP.people import list_people
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        results = await list_people(search="climbing")
+
+    assert len(results) == 1
+    assert results[0]["identifier"] == "charlie_jones"
+
+
+@pytest.mark.asyncio
+async def test_list_people_limit(db_session, sample_people):
+    """Test limiting results."""
+    from memory.api.MCP.people import list_people
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        results = await list_people(limit=1)
+
+    assert len(results) == 1
+
+
+@pytest.mark.asyncio
+async def test_list_people_limit_max_enforced(db_session, sample_people):
+    """Test that limit is capped at 200."""
+    from memory.api.MCP.people import list_people
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        # Request 500 but should be capped at 200
+        results = await list_people(limit=500)
+
+    # We only have 3 people, but the limit logic should cap at 200
+    assert len(results) <= 200
+
+
+@pytest.mark.asyncio
+async def test_list_people_combined_filters(db_session, sample_people):
+    """Test combining tag and search filters."""
+    from memory.api.MCP.people import list_people
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        results = await list_people(tags=["work"], search="chen")
+
+    assert len(results) == 1
+    assert results[0]["identifier"] == "alice_chen"
+
+
+# =============================================================================
+# Tests for delete_person
+# =============================================================================
+
+
+@pytest.mark.asyncio
+async def test_delete_person_success(db_session, sample_people):
+    """Test deleting a person."""
+    from memory.api.MCP.people import delete_person
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        result = await delete_person(identifier="bob_smith")
+
+    assert result["deleted"] is True
+    assert result["identifier"] == "bob_smith"
+    assert result["display_name"] == "Bob Smith"
+
+    # Verify person was deleted
+    remaining = db_session.query(Person).filter_by(identifier="bob_smith").first()
+    assert remaining is None
+
+
+@pytest.mark.asyncio
+async def test_delete_person_not_found(db_session, sample_people):
+    """Test deleting a person that doesn't exist."""
+    from memory.api.MCP.people import delete_person
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        with pytest.raises(ValueError, match="not found"):
+            await delete_person(identifier="nonexistent_person")
+
+
+# =============================================================================
+# Tests for _person_to_dict helper
+# =============================================================================
+
+
+def test_person_to_dict(sample_people):
+    """Test the _person_to_dict helper function."""
+    from memory.api.MCP.people import _person_to_dict
+
+    person = sample_people[0]
+    result = _person_to_dict(person)
+
+    assert result["identifier"] == "alice_chen"
+    assert result["display_name"] == "Alice Chen"
+    assert result["aliases"] == ["@alice_c", "alice.chen@work.com"]
+    assert result["contact_info"] == {"email": "alice@example.com", "phone": "555-1234"}
+    assert result["tags"] == ["work", "engineering"]
+    assert result["notes"] == "Tech lead on Platform team. Very thorough in code reviews."
+    assert result["created_at"] is not None
+
+
+def test_person_to_dict_empty_fields(db_session):
+    """Test _person_to_dict with empty optional fields."""
+    from memory.api.MCP.people import _person_to_dict
+
+    person = Person(
+        identifier="empty_person",
+        display_name="Empty Person",
+        aliases=[],
+        contact_info={},
+        tags=[],
+        content=None,
+        modality="person",
+        sha256=create_content_hash("person:empty_person"),
+        size=0,
+    )
+
+    result = _person_to_dict(person)
+
+    assert result["identifier"] == "empty_person"
+    assert result["aliases"] == []
+    assert result["contact_info"] == {}
+    assert result["tags"] == []
+    assert result["notes"] is None
+
+
+# =============================================================================
+# Parametrized tests
+# =============================================================================
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "tag,expected_count",
+    [
+        ("work", 2),
+        ("engineering", 1),
+        ("design", 1),
+        ("friend", 1),
+        ("climbing", 1),
+        ("nonexistent", 0),
+    ],
+)
+async def test_list_people_various_tags(db_session, sample_people, tag, expected_count):
+    """Test filtering by various tags."""
+    from memory.api.MCP.people import list_people
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        results = await list_people(tags=[tag])
+
+    assert len(results) == expected_count
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "search_term,expected_identifiers",
+    [
+        ("alice", ["alice_chen"]),
+        ("bob", ["bob_smith"]),
+        ("smith", ["bob_smith"]),
+        ("chen", ["alice_chen"]),
+        ("jones", ["charlie_jones"]),
+        ("example.com", []),  # Not searching in contact_info
+        ("UX", ["bob_smith"]),  # Case insensitive search in notes
+    ],
+)
+async def test_list_people_various_searches(
+    db_session, sample_people, search_term, expected_identifiers
+):
+    """Test search with various terms."""
+    from memory.api.MCP.people import list_people
+
+    with patch("memory.api.MCP.people.make_session", return_value=db_session):
+        results = await list_people(search=search_term)
+
+    result_identifiers = [r["identifier"] for r in results]
+    assert result_identifiers == expected_identifiers
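The `sys.modules` stubbing at the top of this test file is worth a standalone illustration: installing a stub before the first import of a heavy dependency makes every later `import` resolve to the stub. A minimal sketch under an invented module name (`some_heavy_dep` does not exist; that is the point):

import sys
from unittest.mock import MagicMock

stub = MagicMock()
stub.tool = lambda: lambda f: f  # the @tool() decorator becomes a no-op
sys.modules["some_heavy_dep"] = stub  # must happen before the real import

import some_heavy_dep  # resolves to the stub, no package needed

assert some_heavy_dep.tool()(print) is print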
0 tests/memory/api/__init__.py Normal file
249 tests/memory/common/db/models/test_people.py Normal file
@@ -0,0 +1,249 @@
+"""Tests for the Person model."""
+
+import pytest
+
+from memory.common.db.models import Person
+from memory.workers.tasks.content_processing import create_content_hash
+
+
+@pytest.fixture
+def person_data():
+    """Standard person test data."""
+    return {
+        "identifier": "alice_chen",
+        "display_name": "Alice Chen",
+        "aliases": ["@alice_c", "alice.chen@work.com"],
+        "contact_info": {"email": "alice@example.com", "phone": "555-1234"},
+        "tags": ["work", "engineering"],
+        "content": "Tech lead on Platform team. Very thorough in code reviews.",
+        "modality": "person",
+        "mime_type": "text/plain",
+    }
+
+
+@pytest.fixture
+def minimal_person_data():
+    """Minimal person test data."""
+    return {
+        "identifier": "bob_smith",
+        "display_name": "Bob Smith",
+        "modality": "person",
+    }
+
+
+def test_person_creation(person_data):
+    """Test creating a Person with all fields."""
+    sha256 = create_content_hash(f"person:{person_data['identifier']}")
+    person = Person(**person_data, sha256=sha256, size=100)
+
+    assert person.identifier == "alice_chen"
+    assert person.display_name == "Alice Chen"
+    assert person.aliases == ["@alice_c", "alice.chen@work.com"]
+    assert person.contact_info == {"email": "alice@example.com", "phone": "555-1234"}
+    assert person.tags == ["work", "engineering"]
+    assert person.content == "Tech lead on Platform team. Very thorough in code reviews."
+    assert person.modality == "person"
+
+
+def test_person_creation_minimal(minimal_person_data):
+    """Test creating a Person with minimal fields."""
+    sha256 = create_content_hash(f"person:{minimal_person_data['identifier']}")
+    person = Person(**minimal_person_data, sha256=sha256, size=0)
+
+    assert person.identifier == "bob_smith"
+    assert person.display_name == "Bob Smith"
+    assert person.aliases == [] or person.aliases is None
+    assert person.contact_info == {} or person.contact_info is None
+    assert person.tags == [] or person.tags is None
+    assert person.content is None
+
+
+def test_person_as_payload(person_data):
+    """Test the as_payload method."""
+    sha256 = create_content_hash(f"person:{person_data['identifier']}")
+    person = Person(**person_data, sha256=sha256, size=100)
+
+    payload = person.as_payload()
+
+    assert payload["identifier"] == "alice_chen"
+    assert payload["display_name"] == "Alice Chen"
+    assert payload["aliases"] == ["@alice_c", "alice.chen@work.com"]
+    assert payload["contact_info"] == {"email": "alice@example.com", "phone": "555-1234"}
+
+
+def test_person_display_contents(person_data):
+    """Test the display_contents property."""
+    sha256 = create_content_hash(f"person:{person_data['identifier']}")
+    person = Person(**person_data, sha256=sha256, size=100)
+
+    contents = person.display_contents
+
+    assert contents["identifier"] == "alice_chen"
+    assert contents["display_name"] == "Alice Chen"
+    assert contents["aliases"] == ["@alice_c", "alice.chen@work.com"]
+    assert contents["contact_info"] == {"email": "alice@example.com", "phone": "555-1234"}
+    assert contents["notes"] == "Tech lead on Platform team. Very thorough in code reviews."
+    assert contents["tags"] == ["work", "engineering"]
+
+
+def test_person_chunk_contents(person_data):
+    """Test the _chunk_contents method generates searchable chunks."""
+    sha256 = create_content_hash(f"person:{person_data['identifier']}")
+    person = Person(**person_data, sha256=sha256, size=100)
+
+    chunks = person._chunk_contents()
+
+    assert len(chunks) > 0
+    chunk_text = chunks[0].data[0]
+
+    # Should include display name
+    assert "Alice Chen" in chunk_text
+    # Should include aliases
+    assert "@alice_c" in chunk_text
+    # Should include tags
+    assert "work" in chunk_text
+    # Should include notes/content
+    assert "Tech lead" in chunk_text
+
+
+def test_person_chunk_contents_minimal(minimal_person_data):
+    """Test _chunk_contents with minimal data."""
+    sha256 = create_content_hash(f"person:{minimal_person_data['identifier']}")
+    person = Person(**minimal_person_data, sha256=sha256, size=0)
+
+    chunks = person._chunk_contents()
+
+    assert len(chunks) > 0
+    chunk_text = chunks[0].data[0]
+    assert "Bob Smith" in chunk_text
+
+
+def test_person_get_collections():
+    """Test that Person returns correct collections."""
+    collections = Person.get_collections()
+
+    assert collections == ["person"]
+
+
+def test_person_polymorphic_identity():
+    """Test that Person has correct polymorphic identity."""
+    assert Person.__mapper_args__["polymorphic_identity"] == "person"
+
+
+@pytest.mark.parametrize(
+    "identifier,display_name,aliases,tags",
+    [
+        ("john_doe", "John Doe", [], []),
+        ("jane_smith", "Jane Smith", ["@jane"], ["friend"]),
+        ("bob_jones", "Bob Jones", ["@bob", "bobby"], ["work", "climbing", "london"]),
+        (
+            "alice_wong",
+            "Alice Wong",
+            ["@alice", "alice@work.com", "Alice W."],
+            ["family", "close"],
+        ),
+    ],
+)
+def test_person_various_configurations(identifier, display_name, aliases, tags):
+    """Test Person creation with various configurations."""
+    sha256 = create_content_hash(f"person:{identifier}")
+    person = Person(
+        identifier=identifier,
+        display_name=display_name,
+        aliases=aliases,
+        tags=tags,
+        modality="person",
+        sha256=sha256,
+        size=0,
+    )
+
+    assert person.identifier == identifier
+    assert person.display_name == display_name
+    assert person.aliases == aliases
+    assert person.tags == tags
+
+
+def test_person_contact_info_flexible():
+    """Test that contact_info can hold various structures."""
+    contact_info = {
+        "email": "test@example.com",
+        "phone": "+1-555-1234",
+        "twitter": "@testuser",
+        "linkedin": "linkedin.com/in/testuser",
+        "address": {
+            "street": "123 Main St",
+            "city": "San Francisco",
+            "country": "USA",
+        },
+    }
+
+    sha256 = create_content_hash("person:test_user")
+    person = Person(
+        identifier="test_user",
+        display_name="Test User",
+        contact_info=contact_info,
+        modality="person",
+        sha256=sha256,
+        size=0,
+    )
+
+    assert person.contact_info == contact_info
+    assert person.contact_info["address"]["city"] == "San Francisco"
+
+
+def test_person_in_db(db_session, qdrant):
+    """Test Person persistence in database."""
+    sha256 = create_content_hash("person:db_test_user")
+    person = Person(
+        identifier="db_test_user",
+        display_name="DB Test User",
+        aliases=["@dbtest"],
+        contact_info={"email": "dbtest@example.com"},
+        tags=["test"],
+        content="Test notes",
+        modality="person",
+        mime_type="text/plain",
+        sha256=sha256,
+        size=10,
+    )
+
+    db_session.add(person)
+    db_session.commit()
+
+    # Query it back
+    retrieved = db_session.query(Person).filter_by(identifier="db_test_user").first()
+
+    assert retrieved is not None
+    assert retrieved.display_name == "DB Test User"
+    assert retrieved.aliases == ["@dbtest"]
+    assert retrieved.contact_info == {"email": "dbtest@example.com"}
+    assert retrieved.tags == ["test"]
+    assert retrieved.content == "Test notes"
+
+
+def test_person_unique_identifier(db_session, qdrant):
+    """Test that identifier must be unique."""
+    sha256 = create_content_hash("person:unique_test")
+
+    person1 = Person(
+        identifier="unique_test",
+        display_name="Person 1",
+        modality="person",
+        sha256=sha256,
+        size=0,
+    )
+    db_session.add(person1)
+    db_session.commit()
+
+    # Try to add another with same identifier
+    person2 = Person(
+        identifier="unique_test",
+        display_name="Person 2",
+        modality="person",
+        sha256=create_content_hash("person:unique_test_2"),
+        size=0,
+    )
+    db_session.add(person2)
+
+    with pytest.raises(Exception):  # Should raise IntegrityError
+        db_session.commit()
@@ -333,7 +333,13 @@ def test_fetch_prs_basic():

         page = kwargs.get("params", {}).get("page", 1)

-        if "/pulls" in url and "/comments" not in url:
+        # PR list endpoint
+        if "/pulls" in url and "/comments" not in url and "/reviews" not in url and "/files" not in url:
+            # Check if this is the diff request
+            if kwargs.get("headers", {}).get("Accept") == "application/vnd.github.diff":
+                response.ok = True
+                response.text = "+100 lines added\n-50 lines removed"
+                return response
             if page == 1:
                 response.json.return_value = [
                     {
@@ -355,10 +361,22 @@ def test_fetch_prs_basic():
                 ]
             else:
                 response.json.return_value = []
-        elif ".diff" in url:
-            response.ok = True
-            response.text = "+100 lines added\n-50 lines removed"
-        elif "/comments" in url:
+        elif "/pulls/" in url and "/comments" in url:
+            # Review comments endpoint
+            response.json.return_value = []
+        elif "/pulls/" in url and "/reviews" in url:
+            # Reviews endpoint
+            response.json.return_value = []
+        elif "/pulls/" in url and "/files" in url:
+            # Files endpoint
+            if page == 1:
+                response.json.return_value = [
+                    {"filename": "test.py", "status": "added", "additions": 100, "deletions": 50, "patch": "+code"}
+                ]
+            else:
+                response.json.return_value = []
+        elif "/issues/" in url and "/comments" in url:
+            # Regular comments endpoint
             response.json.return_value = []
         else:
             response.json.return_value = []
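The reshaped mocks in these hunks illustrate why the tests moved from a fixed `side_effect` list to URL routing: with an ordered list, every new API call the client makes shifts which canned response each request receives, while a router keyed on the URL stays correct as endpoints are added. A minimal sketch of the pattern (the URLs are illustrative):

from unittest.mock import Mock

def router(url, **kwargs):
    response = Mock()
    response.json.return_value = [{"number": 1}] if "/pulls" in url else []
    return response

session_get = Mock(side_effect=router)
assert session_get("https://api.github.com/repos/o/r/pulls").json() == [{"number": 1}]
assert session_get("https://api.github.com/repos/o/r/issues/1/comments").json() == []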
@@ -376,43 +394,66 @@ def test_fetch_prs_basic():
     assert pr["kind"] == "pr"
     assert pr["diff_summary"] is not None
     assert "100 lines added" in pr["diff_summary"]
+    # Verify pr_data is populated
+    assert pr["pr_data"] is not None
+    assert pr["pr_data"]["additions"] == 100
+    assert pr["pr_data"]["deletions"] == 50


 def test_fetch_prs_merged():
     """Test fetching merged PR."""
     credentials = GithubCredentials(auth_type="pat", access_token="token")

-    mock_response = Mock()
-    mock_response.json.return_value = [
-        {
-            "number": 20,
-            "title": "Merged PR",
-            "body": "Body",
-            "state": "closed",
-            "user": {"login": "user"},
-            "labels": [],
-            "assignees": [],
-            "milestone": None,
-            "created_at": "2024-01-01T00:00:00Z",
-            "updated_at": "2024-01-10T00:00:00Z",
-            "closed_at": "2024-01-10T00:00:00Z",
-            "merged_at": "2024-01-10T00:00:00Z",
-            "additions": 10,
-            "deletions": 5,
-            "comments": 0,
-        }
-    ]
-    mock_response.headers = {"X-RateLimit-Remaining": "4999"}
-    mock_response.raise_for_status = Mock()
-
-    mock_empty = Mock()
-    mock_empty.json.return_value = []
-    mock_empty.headers = {"X-RateLimit-Remaining": "4998"}
-    mock_empty.raise_for_status = Mock()
-
-    with patch.object(requests.Session, "get") as mock_get:
-        mock_get.side_effect = [mock_response, mock_empty, mock_empty]
+    def mock_get(url, **kwargs):
+        """Route mock responses based on URL."""
+        response = Mock()
+        response.headers = {"X-RateLimit-Remaining": "4999"}
+        response.raise_for_status = Mock()
+
+        page = kwargs.get("params", {}).get("page", 1)
+
+        # PR list endpoint
+        if "/pulls" in url and "/comments" not in url and "/reviews" not in url and "/files" not in url:
+            if kwargs.get("headers", {}).get("Accept") == "application/vnd.github.diff":
+                response.ok = True
+                response.text = ""
+                return response
+            if page == 1:
+                response.json.return_value = [
+                    {
+                        "number": 20,
+                        "title": "Merged PR",
+                        "body": "Body",
+                        "state": "closed",
+                        "user": {"login": "user"},
+                        "labels": [],
+                        "assignees": [],
+                        "milestone": None,
+                        "created_at": "2024-01-01T00:00:00Z",
+                        "updated_at": "2024-01-10T00:00:00Z",
+                        "closed_at": "2024-01-10T00:00:00Z",
+                        "merged_at": "2024-01-10T00:00:00Z",
+                        "additions": 10,
+                        "deletions": 5,
+                        "comments": 0,
+                    }
+                ]
+            else:
+                response.json.return_value = []
+        elif "/pulls/" in url and "/comments" in url:
+            response.json.return_value = []
+        elif "/pulls/" in url and "/reviews" in url:
+            response.json.return_value = []
+        elif "/pulls/" in url and "/files" in url:
+            response.json.return_value = []
+        elif "/issues/" in url and "/comments" in url:
+            response.json.return_value = []
+        else:
+            response.json.return_value = []
+
+        return response
+
+    with patch.object(requests.Session, "get", side_effect=mock_get):
         client = GithubClient(credentials)
         prs = list(client.fetch_prs("owner", "repo"))
@@ -425,54 +466,73 @@ def test_fetch_prs_stops_at_since():
     """Test that PR fetching stops when reaching older items."""
     credentials = GithubCredentials(auth_type="pat", access_token="token")

-    mock_response = Mock()
-    mock_response.json.return_value = [
-        {
-            "number": 30,
-            "title": "Recent PR",
-            "body": "Body",
-            "state": "open",
-            "user": {"login": "user"},
-            "labels": [],
-            "assignees": [],
-            "milestone": None,
-            "created_at": "2024-01-20T00:00:00Z",
-            "updated_at": "2024-01-20T00:00:00Z",
-            "closed_at": None,
-            "merged_at": None,
-            "additions": 1,
-            "deletions": 1,
-            "comments": 0,
-        },
-        {
-            "number": 29,
-            "title": "Old PR",
-            "body": "Body",
-            "state": "open",
-            "user": {"login": "user"},
-            "labels": [],
-            "assignees": [],
-            "milestone": None,
-            "created_at": "2024-01-01T00:00:00Z",
-            "updated_at": "2024-01-01T00:00:00Z",  # Older than since
-            "closed_at": None,
-            "merged_at": None,
-            "additions": 1,
-            "deletions": 1,
-            "comments": 0,
-        },
-    ]
-    mock_response.headers = {"X-RateLimit-Remaining": "4999"}
-    mock_response.raise_for_status = Mock()
-
-    mock_empty = Mock()
-    mock_empty.json.return_value = []
-    mock_empty.headers = {"X-RateLimit-Remaining": "4998"}
-    mock_empty.raise_for_status = Mock()
-
-    with patch.object(requests.Session, "get") as mock_get:
-        mock_get.side_effect = [mock_response, mock_empty]
+    def mock_get(url, **kwargs):
+        """Route mock responses based on URL."""
+        response = Mock()
+        response.headers = {"X-RateLimit-Remaining": "4999"}
+        response.raise_for_status = Mock()
+
+        page = kwargs.get("params", {}).get("page", 1)
+
+        # PR list endpoint
+        if "/pulls" in url and "/comments" not in url and "/reviews" not in url and "/files" not in url:
+            if kwargs.get("headers", {}).get("Accept") == "application/vnd.github.diff":
+                response.ok = True
+                response.text = ""
+                return response
+            if page == 1:
+                response.json.return_value = [
+                    {
+                        "number": 30,
+                        "title": "Recent PR",
+                        "body": "Body",
+                        "state": "open",
+                        "user": {"login": "user"},
+                        "labels": [],
+                        "assignees": [],
+                        "milestone": None,
+                        "created_at": "2024-01-20T00:00:00Z",
+                        "updated_at": "2024-01-20T00:00:00Z",
+                        "closed_at": None,
+                        "merged_at": None,
+                        "additions": 1,
+                        "deletions": 1,
+                        "comments": 0,
+                    },
+                    {
+                        "number": 29,
+                        "title": "Old PR",
+                        "body": "Body",
+                        "state": "open",
+                        "user": {"login": "user"},
+                        "labels": [],
+                        "assignees": [],
+                        "milestone": None,
+                        "created_at": "2024-01-01T00:00:00Z",
+                        "updated_at": "2024-01-01T00:00:00Z",  # Older than since
+                        "closed_at": None,
+                        "merged_at": None,
+                        "additions": 1,
+                        "deletions": 1,
+                        "comments": 0,
+                    },
+                ]
+            else:
+                response.json.return_value = []
+        elif "/pulls/" in url and "/comments" in url:
+            response.json.return_value = []
+        elif "/pulls/" in url and "/reviews" in url:
+            response.json.return_value = []
+        elif "/pulls/" in url and "/files" in url:
+            response.json.return_value = []
+        elif "/issues/" in url and "/comments" in url:
+            response.json.return_value = []
+        else:
+            response.json.return_value = []
+
+        return response
+
+    with patch.object(requests.Session, "get", side_effect=mock_get):
         client = GithubClient(credentials)
         since = datetime(2024, 1, 15, tzinfo=timezone.utc)
         prs = list(client.fetch_prs("owner", "repo", since=since))
@ -703,3 +763,552 @@ def test_fetch_issues_handles_api_error():
|
|||||||
|
|
||||||
with pytest.raises(requests.HTTPError):
|
with pytest.raises(requests.HTTPError):
|
||||||
list(client.fetch_issues("owner", "nonexistent"))
|
list(client.fetch_issues("owner", "nonexistent"))
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
# Tests for fetch_review_comments
# =============================================================================


def test_fetch_review_comments_basic():
    """Test fetching PR review comments."""
    credentials = GithubCredentials(auth_type="pat", access_token="token")

    mock_response = Mock()
    mock_response.json.return_value = [
        {
            "id": 1001,
            "user": {"login": "reviewer1"},
            "body": "This needs a test",
            "path": "src/main.py",
            "line": 42,
            "side": "RIGHT",
            "diff_hunk": "@@ -40,3 +40,5 @@",
            "created_at": "2024-01-01T12:00:00Z",
        },
        {
            "id": 1002,
            "user": {"login": "reviewer2"},
            "body": "Good refactoring",
            "path": "src/utils.py",
            "line": 10,
            "side": "LEFT",
            "diff_hunk": "@@ -8,5 +8,5 @@",
            "created_at": "2024-01-02T10:00:00Z",
        },
    ]
    mock_response.headers = {"X-RateLimit-Remaining": "4999"}
    mock_response.raise_for_status = Mock()

    mock_empty = Mock()
    mock_empty.json.return_value = []
    mock_empty.headers = {"X-RateLimit-Remaining": "4998"}
    mock_empty.raise_for_status = Mock()

    with patch.object(requests.Session, "get") as mock_get:
        mock_get.side_effect = [mock_response, mock_empty]

        client = GithubClient(credentials)
        comments = client.fetch_review_comments("owner", "repo", 10)

    assert len(comments) == 2
    assert comments[0]["user"] == "reviewer1"
    assert comments[0]["body"] == "This needs a test"
    assert comments[0]["path"] == "src/main.py"
    assert comments[0]["line"] == 42
    assert comments[0]["side"] == "RIGHT"
    assert comments[0]["diff_hunk"] == "@@ -40,3 +40,5 @@"
    assert comments[1]["user"] == "reviewer2"


def test_fetch_review_comments_ghost_user():
    """Test review comments with deleted user."""
    credentials = GithubCredentials(auth_type="pat", access_token="token")

    mock_response = Mock()
    mock_response.json.return_value = [
        {
            "id": 1001,
            "user": None,  # Deleted user
            "body": "Legacy comment",
            "path": "file.py",
            "line": None,  # Line might be None for outdated comments
            "side": "RIGHT",
            "diff_hunk": "",
            "created_at": "2024-01-01T00:00:00Z",
        }
    ]
    mock_response.headers = {"X-RateLimit-Remaining": "4999"}
    mock_response.raise_for_status = Mock()

    mock_empty = Mock()
    mock_empty.json.return_value = []
    mock_empty.headers = {"X-RateLimit-Remaining": "4998"}
    mock_empty.raise_for_status = Mock()

    with patch.object(requests.Session, "get") as mock_get:
        mock_get.side_effect = [mock_response, mock_empty]

        client = GithubClient(credentials)
        comments = client.fetch_review_comments("owner", "repo", 10)

    assert len(comments) == 1
    assert comments[0]["user"] == "ghost"
    assert comments[0]["line"] is None


def test_fetch_review_comments_pagination():
    """Test review comment fetching with pagination."""
    credentials = GithubCredentials(auth_type="pat", access_token="token")

    mock_page1 = Mock()
    mock_page1.json.return_value = [
        {
            "id": i,
            "user": {"login": f"user{i}"},
            "body": f"Comment {i}",
            "path": "file.py",
            "line": i,
            "side": "RIGHT",
            "diff_hunk": "",
            "created_at": "2024-01-01T00:00:00Z",
        }
        for i in range(100)
    ]
    mock_page1.headers = {"X-RateLimit-Remaining": "4999"}
    mock_page1.raise_for_status = Mock()

    mock_page2 = Mock()
    mock_page2.json.return_value = [
        {
            "id": 100,
            "user": {"login": "user100"},
            "body": "Final comment",
            "path": "file.py",
            "line": 100,
            "side": "RIGHT",
            "diff_hunk": "",
            "created_at": "2024-01-01T00:00:00Z",
        }
    ]
    mock_page2.headers = {"X-RateLimit-Remaining": "4998"}
    mock_page2.raise_for_status = Mock()

    mock_empty = Mock()
    mock_empty.json.return_value = []
    mock_empty.headers = {"X-RateLimit-Remaining": "4997"}
    mock_empty.raise_for_status = Mock()

    with patch.object(requests.Session, "get") as mock_get:
        mock_get.side_effect = [mock_page1, mock_page2, mock_empty]

        client = GithubClient(credentials)
        comments = client.fetch_review_comments("owner", "repo", 10)

    assert len(comments) == 101


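# The pagination mocks above assume the client keeps requesting successive pages
# until an empty page comes back. A minimal sketch of that loop, under that
# assumption (_fetch_paginated is a hypothetical helper, not the actual
# GithubClient method):
def _fetch_paginated(session, url):
    results = []
    page = 1
    while True:
        response = session.get(url, params={"page": page, "per_page": 100})
        response.raise_for_status()
        batch = response.json()
        if not batch:  # an empty page terminates the loop
            break
        results.extend(batch)
        page += 1
    return results

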
# =============================================================================
# Tests for fetch_reviews
# =============================================================================


def test_fetch_reviews_basic():
    """Test fetching PR reviews."""
    credentials = GithubCredentials(auth_type="pat", access_token="token")

    mock_response = Mock()
    mock_response.json.return_value = [
        {
            "id": 2001,
            "user": {"login": "lead_dev"},
            "state": "APPROVED",
            "body": "LGTM!",
            "submitted_at": "2024-01-05T15:00:00Z",
        },
        {
            "id": 2002,
            "user": {"login": "qa_engineer"},
            "state": "CHANGES_REQUESTED",
            "body": "Please add tests",
            "submitted_at": "2024-01-04T10:00:00Z",
        },
        {
            "id": 2003,
            "user": {"login": "observer"},
            "state": "COMMENTED",
            "body": None,  # Some reviews have no body
            "submitted_at": "2024-01-03T08:00:00Z",
        },
    ]
    mock_response.headers = {"X-RateLimit-Remaining": "4999"}
    mock_response.raise_for_status = Mock()

    mock_empty = Mock()
    mock_empty.json.return_value = []
    mock_empty.headers = {"X-RateLimit-Remaining": "4998"}
    mock_empty.raise_for_status = Mock()

    with patch.object(requests.Session, "get") as mock_get:
        mock_get.side_effect = [mock_response, mock_empty]

        client = GithubClient(credentials)
        reviews = client.fetch_reviews("owner", "repo", 10)

    assert len(reviews) == 3
    assert reviews[0]["user"] == "lead_dev"
    assert reviews[0]["state"] == "approved"  # Lowercased
    assert reviews[0]["body"] == "LGTM!"
    assert reviews[1]["state"] == "changes_requested"
    assert reviews[2]["body"] is None


def test_fetch_reviews_ghost_user():
    """Test reviews with deleted user."""
    credentials = GithubCredentials(auth_type="pat", access_token="token")

    mock_response = Mock()
    mock_response.json.return_value = [
        {
            "id": 2001,
            "user": None,
            "state": "APPROVED",
            "body": "Approved by former employee",
            "submitted_at": "2024-01-01T00:00:00Z",
        }
    ]
    mock_response.headers = {"X-RateLimit-Remaining": "4999"}
    mock_response.raise_for_status = Mock()

    mock_empty = Mock()
    mock_empty.json.return_value = []
    mock_empty.headers = {"X-RateLimit-Remaining": "4998"}
    mock_empty.raise_for_status = Mock()

    with patch.object(requests.Session, "get") as mock_get:
        mock_get.side_effect = [mock_response, mock_empty]

        client = GithubClient(credentials)
        reviews = client.fetch_reviews("owner", "repo", 10)

    assert len(reviews) == 1
    assert reviews[0]["user"] == "ghost"


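# The asserts above imply two normalizations inside the client: a deleted
# account (user == None) maps to the literal "ghost" login, and review states
# are lowercased. A hedged sketch of that mapping (_normalize_review is assumed
# from the asserted behavior, not copied from the implementation):
def _normalize_review(raw):
    user = raw.get("user") or {}
    return {
        "id": raw["id"],
        "user": user.get("login") or "ghost",  # deleted users become "ghost"
        "state": raw["state"].lower(),  # "APPROVED" -> "approved"
        "body": raw.get("body"),
        "submitted_at": raw.get("submitted_at"),
    }

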
# =============================================================================
# Tests for fetch_pr_files
# =============================================================================


def test_fetch_pr_files_basic():
    """Test fetching PR file changes."""
    credentials = GithubCredentials(auth_type="pat", access_token="token")

    mock_response = Mock()
    mock_response.json.return_value = [
        {
            "filename": "src/main.py",
            "status": "modified",
            "additions": 10,
            "deletions": 5,
            "patch": "@@ -1,5 +1,10 @@\n+new code\n-old code",
        },
        {
            "filename": "src/new_feature.py",
            "status": "added",
            "additions": 100,
            "deletions": 0,
            "patch": "@@ -0,0 +1,100 @@\n+entire new file",
        },
        {
            "filename": "old_file.py",
            "status": "removed",
            "additions": 0,
            "deletions": 50,
            "patch": "@@ -1,50 +0,0 @@\n-entire old file",
        },
        {
            "filename": "image.png",
            "status": "added",
            "additions": 0,
            "deletions": 0,
            # No patch for binary files
        },
    ]
    mock_response.headers = {"X-RateLimit-Remaining": "4999"}
    mock_response.raise_for_status = Mock()

    mock_empty = Mock()
    mock_empty.json.return_value = []
    mock_empty.headers = {"X-RateLimit-Remaining": "4998"}
    mock_empty.raise_for_status = Mock()

    with patch.object(requests.Session, "get") as mock_get:
        mock_get.side_effect = [mock_response, mock_empty]

        client = GithubClient(credentials)
        files = client.fetch_pr_files("owner", "repo", 10)

    assert len(files) == 4
    assert files[0]["filename"] == "src/main.py"
    assert files[0]["status"] == "modified"
    assert files[0]["additions"] == 10
    assert files[0]["deletions"] == 5
    assert files[0]["patch"] is not None

    assert files[1]["status"] == "added"
    assert files[2]["status"] == "removed"
    assert files[3]["patch"] is None  # Binary file


def test_fetch_pr_files_renamed():
    """Test PR with renamed files."""
    credentials = GithubCredentials(auth_type="pat", access_token="token")

    mock_response = Mock()
    mock_response.json.return_value = [
        {
            "filename": "new_name.py",
            "status": "renamed",
            "additions": 0,
            "deletions": 0,
            "patch": None,
        }
    ]
    mock_response.headers = {"X-RateLimit-Remaining": "4999"}
    mock_response.raise_for_status = Mock()

    mock_empty = Mock()
    mock_empty.json.return_value = []
    mock_empty.headers = {"X-RateLimit-Remaining": "4998"}
    mock_empty.raise_for_status = Mock()

    with patch.object(requests.Session, "get") as mock_get:
        mock_get.side_effect = [mock_response, mock_empty]

        client = GithubClient(credentials)
        files = client.fetch_pr_files("owner", "repo", 10)

    assert len(files) == 1
    assert files[0]["status"] == "renamed"


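# The binary-file case above omits the "patch" key entirely, yet the test still
# expects files[3]["patch"] to be None, which suggests the client normalizes
# each entry with dict.get rather than passing the raw payload through. A
# hedged sketch of that mapping (field names taken from the asserts above;
# _normalize_file is a hypothetical helper):
def _normalize_file(raw):
    return {
        "filename": raw["filename"],
        "status": raw["status"],
        "additions": raw["additions"],
        "deletions": raw["deletions"],
        "patch": raw.get("patch"),  # None when GitHub sends no patch (binaries)
    }

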
# =============================================================================
# Tests for fetch_pr_diff
# =============================================================================


def test_fetch_pr_diff_success():
    """Test fetching full PR diff."""
    credentials = GithubCredentials(auth_type="pat", access_token="token")

    diff_text = """diff --git a/file.py b/file.py
index abc123..def456 100644
--- a/file.py
+++ b/file.py
@@ -1,5 +1,10 @@
+import os
+
 def main():
-    print("old")
+    print("new")
"""

    mock_response = Mock()
    mock_response.ok = True
    mock_response.text = diff_text

    with patch.object(requests.Session, "get") as mock_get:
        mock_get.return_value = mock_response

        client = GithubClient(credentials)
        diff = client.fetch_pr_diff("owner", "repo", 10)

    assert diff == diff_text
    # Verify Accept header was set for diff format
    call_kwargs = mock_get.call_args.kwargs
    assert call_kwargs["headers"]["Accept"] == "application/vnd.github.diff"


def test_fetch_pr_diff_failure():
    """Test handling diff fetch failure gracefully."""
    credentials = GithubCredentials(auth_type="pat", access_token="token")

    mock_response = Mock()
    mock_response.ok = False

    with patch.object(requests.Session, "get") as mock_get:
        mock_get.return_value = mock_response

        client = GithubClient(credentials)
        diff = client.fetch_pr_diff("owner", "repo", 10)

    assert diff is None


def test_fetch_pr_diff_exception():
    """Test handling exceptions during diff fetch."""
    credentials = GithubCredentials(auth_type="pat", access_token="token")

    with patch.object(requests.Session, "get") as mock_get:
        mock_get.side_effect = requests.RequestException("Network error")

        client = GithubClient(credentials)
        diff = client.fetch_pr_diff("owner", "repo", 10)

    assert diff is None


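# Together these three tests pin down the fetch_pr_diff contract: request the
# PR with the "application/vnd.github.diff" Accept header, return the raw text
# on success, and swallow both non-OK responses and network errors into None.
# A minimal sketch under those assumptions (URL shape assumed, not verified):
def _fetch_pr_diff_sketch(session, owner, repo, number):
    url = f"https://api.github.com/repos/{owner}/{repo}/pulls/{number}"
    try:
        response = session.get(
            url, headers={"Accept": "application/vnd.github.diff"}
        )
    except requests.RequestException:
        return None  # network failures degrade to "no diff"
    return response.text if response.ok else None

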
# =============================================================================
# Tests for _parse_pr with pr_data
# =============================================================================


def test_parse_pr_fetches_all_pr_data():
    """Test that _parse_pr fetches and includes all PR-specific data."""
    credentials = GithubCredentials(auth_type="pat", access_token="token")

    pr_raw = {
        "number": 42,
        "title": "Feature PR",
        "body": "PR description",
        "state": "open",
        "user": {"login": "contributor"},
        "labels": [{"name": "enhancement"}],
        "assignees": [{"login": "reviewer"}],
        "milestone": {"title": "v2.0"},
        "created_at": "2024-01-01T00:00:00Z",
        "updated_at": "2024-01-02T00:00:00Z",
        "closed_at": None,
        "merged_at": None,
    }

    # Mock responses for all the fetch methods
    def mock_get(url, **kwargs):
        response = Mock()
        response.headers = {"X-RateLimit-Remaining": "4999"}
        response.raise_for_status = Mock()

        if "/issues/42/comments" in url:
            # Regular comments
            page = kwargs.get("params", {}).get("page", 1)
            if page == 1:
                response.json.return_value = [
                    {
                        "id": 1,
                        "user": {"login": "user1"},
                        "body": "Regular comment",
                        "created_at": "2024-01-01T10:00:00Z",
                        "updated_at": "2024-01-01T10:00:00Z",
                    }
                ]
            else:
                response.json.return_value = []
        elif "/pulls/42/comments" in url:
            # Review comments
            page = kwargs.get("params", {}).get("page", 1)
            if page == 1:
                response.json.return_value = [
                    {
                        "id": 101,
                        "user": {"login": "reviewer1"},
                        "body": "Review comment",
                        "path": "src/main.py",
                        "line": 10,
                        "side": "RIGHT",
                        "diff_hunk": "@@ -1,5 +1,10 @@",
                        "created_at": "2024-01-01T12:00:00Z",
                    }
                ]
            else:
                response.json.return_value = []
        elif "/pulls/42/reviews" in url:
            # Reviews
            page = kwargs.get("params", {}).get("page", 1)
            if page == 1:
                response.json.return_value = [
                    {
                        "id": 201,
                        "user": {"login": "lead"},
                        "state": "APPROVED",
                        "body": "LGTM",
                        "submitted_at": "2024-01-02T08:00:00Z",
                    }
                ]
            else:
                response.json.return_value = []
        elif "/pulls/42/files" in url:
            # Files
            page = kwargs.get("params", {}).get("page", 1)
            if page == 1:
                response.json.return_value = [
                    {
                        "filename": "src/main.py",
                        "status": "modified",
                        "additions": 50,
                        "deletions": 10,
                        "patch": "+new\n-old",
                    },
                    {
                        "filename": "tests/test_main.py",
                        "status": "added",
                        "additions": 30,
                        "deletions": 0,
                        "patch": "+tests",
                    },
                ]
            else:
                response.json.return_value = []
        elif "/pulls/42" in url and "diff" in kwargs.get("headers", {}).get(
            "Accept", ""
        ):
            # Full diff
            response.ok = True
            response.text = "diff --git a/src/main.py\n+new code\n-old code"
            return response
        else:
            response.json.return_value = []

        return response

    with patch.object(requests.Session, "get", side_effect=mock_get):
        client = GithubClient(credentials)
        result = client._parse_pr("owner", "repo", pr_raw)

    # Verify basic fields
    assert result["kind"] == "pr"
    assert result["number"] == 42
    assert result["title"] == "Feature PR"
    assert result["author"] == "contributor"
    assert len(result["comments"]) == 1

    # Verify pr_data
    pr_data = result["pr_data"]
    assert pr_data is not None

    # Verify diff
    assert pr_data["diff"] is not None
    assert "new code" in pr_data["diff"]

    # Verify files
    assert len(pr_data["files"]) == 2
    assert pr_data["files"][0]["filename"] == "src/main.py"
    assert pr_data["files"][0]["additions"] == 50

    # Verify stats calculated from files
    assert pr_data["additions"] == 80  # 50 + 30
    assert pr_data["deletions"] == 10
    assert pr_data["changed_files_count"] == 2

    # Verify reviews
    assert len(pr_data["reviews"]) == 1
    assert pr_data["reviews"][0]["user"] == "lead"
    assert pr_data["reviews"][0]["state"] == "approved"

    # Verify review comments
    assert len(pr_data["review_comments"]) == 1
    assert pr_data["review_comments"][0]["user"] == "reviewer1"
    assert pr_data["review_comments"][0]["path"] == "src/main.py"

    # Verify diff_summary is truncated version of full diff
    assert result["diff_summary"] == pr_data["diff"][:5000]

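# The stats asserts above (additions == 80, deletions == 10) imply the totals
# are summed from the per-file entries rather than taken from the PR payload,
# and that diff_summary is the first 5000 characters of the full diff. A sketch
# of that aggregation, assumed from the asserts rather than copied from the
# implementation:
def _aggregate_pr_stats(files, diff):
    return {
        "additions": sum(f["additions"] for f in files),  # 50 + 30 == 80
        "deletions": sum(f["deletions"] for f in files),
        "changed_files_count": len(files),
        "diff_summary": diff[:5000] if diff else None,
    }
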
@@ -1123,3 +1123,521 @@ def test_tag_merging(repo_tags, issue_labels, expected_tags, github_account, db_

    item = db_session.query(GithubItem).filter_by(number=60).first()
    assert item.tags == expected_tags


# =============================================================================
# Tests for PR data handling
# =============================================================================


@pytest.fixture
def mock_pr_data_with_extended() -> GithubIssueData:
    """Mock PR data with full pr_data dict."""
    from memory.parsers.github import GithubPRDataDict

    pr_data: GithubPRDataDict = {
        "diff": "diff --git a/file.py\n+new line\n-old line",
        "files": [
            {
                "filename": "src/main.py",
                "status": "modified",
                "additions": 50,
                "deletions": 10,
                "patch": "+new\n-old",
            },
            {
                "filename": "tests/test_main.py",
                "status": "added",
                "additions": 30,
                "deletions": 0,
                "patch": "+test code",
            },
        ],
        "additions": 80,
        "deletions": 10,
        "changed_files_count": 2,
        "reviews": [
            {
                "id": 1001,
                "user": "lead_reviewer",
                "state": "approved",
                "body": "LGTM!",
                "submitted_at": "2024-01-02T10:00:00Z",
            }
        ],
        "review_comments": [
            {
                "id": 2001,
                "user": "reviewer1",
                "body": "Please add a docstring here",
                "path": "src/main.py",
                "line": 42,
                "side": "RIGHT",
                "diff_hunk": "@@ -40,3 +40,5 @@",
                "created_at": "2024-01-01T15:00:00Z",
            }
        ],
    }

    return GithubIssueData(
        kind="pr",
        number=200,
        title="Feature: Add new capability",
        body="This PR adds a new capability to the system.",
        state="open",
        author="contributor",
        labels=["enhancement", "needs-review"],
        assignees=["reviewer1"],
        milestone="v2.0",
        created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
        closed_at=None,
        merged_at=None,
        github_updated_at=datetime(2024, 1, 2, 10, 0, 0, tzinfo=timezone.utc),
        comment_count=1,
        comments=[
            {
                "id": 5001,
                "author": "maintainer",
                "body": "Thanks for the PR!",
                "created_at": "2024-01-01T14:00:00Z",
                "updated_at": "2024-01-01T14:00:00Z",
            }
        ],
        diff_summary="+new\n-old",
        project_fields=None,
        content_hash="pr_extended_hash",
        pr_data=pr_data,
    )


def test_build_content_with_review_comments(mock_pr_data_with_extended):
    """Test _build_content includes review comments for PRs."""
    content = _build_content(mock_pr_data_with_extended)

    # Basic content
    assert "# Feature: Add new capability" in content
    assert "This PR adds a new capability" in content

    # Regular comment
    assert "**maintainer**: Thanks for the PR!" in content

    # Review comments section
    assert "## Code Review Comments" in content
    assert "**reviewer1**" in content
    assert "Please add a docstring here" in content
    assert "`src/main.py`" in content


def test_build_content_pr_without_review_comments():
    """Test _build_content for PR with no review comments."""
    data = GithubIssueData(
        kind="pr",
        number=201,
        title="Simple PR",
        body="Body",
        state="open",
        author="user",
        labels=[],
        assignees=[],
        milestone=None,
        created_at=datetime.now(timezone.utc),
        closed_at=None,
        merged_at=None,
        github_updated_at=datetime.now(timezone.utc),
        comment_count=0,
        comments=[],
        diff_summary=None,
        project_fields=None,
        content_hash="hash",
        pr_data={
            "diff": None,
            "files": [],
            "additions": 0,
            "deletions": 0,
            "changed_files_count": 0,
            "reviews": [],
            "review_comments": [],  # Empty
        },
    )

    content = _build_content(data)

    assert "# Simple PR" in content
    assert "## Code Review Comments" not in content


def test_build_content_issue_no_pr_data():
    """Test _build_content for issue (no pr_data)."""
    data = GithubIssueData(
        kind="issue",
        number=100,
        title="Bug Report",
        body="There's a bug",
        state="open",
        author="reporter",
        labels=["bug"],
        assignees=[],
        milestone=None,
        created_at=datetime.now(timezone.utc),
        closed_at=None,
        merged_at=None,
        github_updated_at=datetime.now(timezone.utc),
        comment_count=0,
        comments=[],
        diff_summary=None,
        project_fields=None,
        content_hash="hash",
        pr_data=None,  # Issues don't have pr_data
    )

    content = _build_content(data)

    assert "# Bug Report" in content
    assert "There's a bug" in content
    assert "## Code Review Comments" not in content


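# The three tests above fix the observable shape of the review-comment section:
# a "## Code Review Comments" heading that only appears when review comments
# exist, bolded reviewer logins, and backticked file paths. A hedged sketch of
# how such a section could be rendered (the markdown layout is inferred from
# the asserts, not copied from _build_content):
def _render_review_comments_sketch(review_comments):
    if not review_comments:
        return ""  # no heading at all when there is nothing to show
    lines = ["## Code Review Comments"]
    for comment in review_comments:
        lines.append(
            f"**{comment['user']}** on `{comment['path']}`: {comment['body']}"
        )
    return "\n".join(lines)

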
def test_create_pr_data_function(mock_pr_data_with_extended):
    """Test _create_pr_data creates GithubPRData correctly."""
    from memory.workers.tasks.github import _create_pr_data

    result = _create_pr_data(mock_pr_data_with_extended)

    assert result is not None
    assert result.additions == 80
    assert result.deletions == 10
    assert result.changed_files_count == 2

    # Files are stored as JSONB
    assert len(result.files) == 2
    assert result.files[0]["filename"] == "src/main.py"

    # Reviews
    assert len(result.reviews) == 1
    assert result.reviews[0]["user"] == "lead_reviewer"
    assert result.reviews[0]["state"] == "approved"

    # Review comments
    assert len(result.review_comments) == 1
    assert result.review_comments[0]["path"] == "src/main.py"

    # Diff is compressed - test the property getter
    assert result.diff is not None
    assert "new line" in result.diff


def test_create_pr_data_none_for_issue():
    """Test _create_pr_data returns None for issues."""
    from memory.workers.tasks.github import _create_pr_data

    data = GithubIssueData(
        kind="issue",
        number=100,
        title="Issue",
        body="Body",
        state="open",
        author="user",
        labels=[],
        assignees=[],
        milestone=None,
        created_at=datetime.now(timezone.utc),
        closed_at=None,
        merged_at=None,
        github_updated_at=datetime.now(timezone.utc),
        comment_count=0,
        comments=[],
        diff_summary=None,
        project_fields=None,
        content_hash="hash",
        pr_data=None,
    )

    result = _create_pr_data(data)
    assert result is None


def test_serialize_deserialize_with_pr_data(mock_pr_data_with_extended):
    """Test serialization roundtrip preserves pr_data."""
    serialized = _serialize_issue_data(mock_pr_data_with_extended)

    # Verify pr_data is included in serialized form
    assert "pr_data" in serialized
    assert serialized["pr_data"]["additions"] == 80
    assert len(serialized["pr_data"]["files"]) == 2
    assert len(serialized["pr_data"]["reviews"]) == 1
    assert len(serialized["pr_data"]["review_comments"]) == 1

    # Deserialize and verify
    deserialized = _deserialize_issue_data(serialized)

    assert deserialized["pr_data"] is not None
    assert deserialized["pr_data"]["additions"] == 80
    assert deserialized["pr_data"]["deletions"] == 10
    assert len(deserialized["pr_data"]["files"]) == 2
    assert deserialized["pr_data"]["diff"] == mock_pr_data_with_extended["pr_data"]["diff"]


def test_serialize_deserialize_without_pr_data(mock_issue_data):
    """Test serialization roundtrip for issue without pr_data."""
    # Add pr_data=None to the mock (issues don't have it)
    issue_with_none = dict(mock_issue_data)
    issue_with_none["pr_data"] = None

    serialized = _serialize_issue_data(issue_with_none)
    assert serialized.get("pr_data") is None

    deserialized = _deserialize_issue_data(serialized)
    assert deserialized.get("pr_data") is None


def test_sync_github_item_creates_pr_data(
    mock_pr_data_with_extended, github_repo, db_session, qdrant
):
    """Test that syncing a PR creates associated GithubPRData."""
    serialized = _serialize_issue_data(mock_pr_data_with_extended)
    result = github.sync_github_item(github_repo.id, serialized)

    assert result["status"] == "processed"

    # Query the created item
    item = (
        db_session.query(GithubItem)
        .filter_by(repo_path="testorg/testrepo", number=200, kind="pr")
        .first()
    )
    assert item is not None
    assert item.kind == "pr"

    # Check pr_data relationship
    assert item.pr_data is not None
    assert item.pr_data.additions == 80
    assert item.pr_data.deletions == 10
    assert item.pr_data.changed_files_count == 2
    assert len(item.pr_data.files) == 2
    assert len(item.pr_data.reviews) == 1
    assert len(item.pr_data.review_comments) == 1

    # Verify diff decompression works
    assert item.pr_data.diff is not None
    assert "new line" in item.pr_data.diff


def test_sync_github_item_pr_without_pr_data(github_repo, db_session, qdrant):
    """Test syncing a PR that doesn't have extended pr_data."""
    data = GithubIssueData(
        kind="pr",
        number=202,
        title="Legacy PR",
        body="PR without extended data",
        state="merged",
        author="old_contributor",
        labels=[],
        assignees=[],
        milestone=None,
        created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
        closed_at=datetime(2024, 1, 5, tzinfo=timezone.utc),
        merged_at=datetime(2024, 1, 5, tzinfo=timezone.utc),
        github_updated_at=datetime(2024, 1, 5, tzinfo=timezone.utc),
        comment_count=0,
        comments=[],
        diff_summary="+10 -5",
        project_fields=None,
        content_hash="legacy_hash",
        pr_data=None,  # No extended PR data
    )

    serialized = _serialize_issue_data(data)
    result = github.sync_github_item(github_repo.id, serialized)

    assert result["status"] == "processed"

    item = db_session.query(GithubItem).filter_by(number=202).first()
    assert item is not None
    assert item.kind == "pr"
    assert item.pr_data is None  # No pr_data created


def test_sync_github_item_updates_existing_pr_data(github_repo, db_session, qdrant):
    """Test updating an existing PR with new pr_data."""
    from memory.common.db.models import GithubPRData
    from memory.workers.tasks.content_processing import create_content_hash

    # Create initial PR with pr_data
    initial_content = "# Initial PR\n\nOriginal body"
    existing_item = GithubItem(
        repo_path="testorg/testrepo",
        repo_id=github_repo.id,
        number=300,
        kind="pr",
        title="Initial PR",
        content=initial_content,
        state="open",
        author="user",
        labels=[],
        assignees=[],
        milestone=None,
        created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
        github_updated_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
        comment_count=0,
        content_hash="initial_hash",
        diff_summary="+5 -2",
        modality="github",
        mime_type="text/markdown",
        sha256=create_content_hash(initial_content),
        size=len(initial_content),
        tags=["github", "test"],
    )

    # Create initial pr_data
    initial_pr_data = GithubPRData(
        additions=5,
        deletions=2,
        changed_files_count=1,
        files=[{"filename": "old.py", "status": "modified", "additions": 5, "deletions": 2, "patch": None}],
        reviews=[],
        review_comments=[],
    )
    initial_pr_data.diff = "old diff"
    existing_item.pr_data = initial_pr_data

    db_session.add(existing_item)
    db_session.commit()

    # Now update with new data
    updated_data = GithubIssueData(
        kind="pr",
        number=300,
        title="Updated PR",
        body="Updated body with more changes",
        state="open",
        author="user",
        labels=["ready-for-review"],
        assignees=["reviewer"],
        milestone=None,
        created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
        closed_at=None,
        merged_at=None,
        github_updated_at=datetime(2024, 1, 5, tzinfo=timezone.utc),  # Newer
        comment_count=2,
        comments=[
            {"id": 1, "author": "reviewer", "body": "LGTM", "created_at": "", "updated_at": ""}
        ],
        diff_summary="+50 -10",
        project_fields=None,
        content_hash="updated_hash",  # Different hash triggers update
        pr_data={
            "diff": "new diff with lots of changes",
            "files": [
                {"filename": "new.py", "status": "added", "additions": 50, "deletions": 0, "patch": "+code"},
                {"filename": "old.py", "status": "modified", "additions": 0, "deletions": 10, "patch": "-code"},
            ],
            "additions": 50,
            "deletions": 10,
            "changed_files_count": 2,
            "reviews": [
                {"id": 1, "user": "reviewer", "state": "approved", "body": "Approved!", "submitted_at": ""}
            ],
            "review_comments": [
                {"id": 1, "user": "reviewer", "body": "Nice!", "path": "new.py", "line": 10, "side": "RIGHT", "diff_hunk": "", "created_at": ""}
            ],
        },
    )

    serialized = _serialize_issue_data(updated_data)
    result = github.sync_github_item(github_repo.id, serialized)

    assert result["status"] == "processed"

    # Refresh from DB
    db_session.expire_all()
    item = db_session.query(GithubItem).filter_by(number=300).first()

    assert item.title == "Updated PR"
    assert item.pr_data is not None
    assert item.pr_data.additions == 50
    assert item.pr_data.deletions == 10
    assert item.pr_data.changed_files_count == 2
    assert len(item.pr_data.files) == 2
    assert len(item.pr_data.reviews) == 1
    assert len(item.pr_data.review_comments) == 1
    assert "new diff" in item.pr_data.diff


def test_sync_github_item_creates_pr_data_for_existing_pr_without(
    github_repo, db_session, qdrant
):
    """Test updating a PR that didn't have pr_data to add it."""
    from memory.workers.tasks.content_processing import create_content_hash

    # Create existing PR without pr_data (legacy data)
    initial_content = "# Legacy PR\n\nOriginal"
    existing_item = GithubItem(
        repo_path="testorg/testrepo",
        repo_id=github_repo.id,
        number=301,
        kind="pr",
        title="Legacy PR",
        content=initial_content,
        state="open",
        author="user",
        labels=[],
        assignees=[],
        milestone=None,
        created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
        github_updated_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
        comment_count=0,
        content_hash="legacy_hash",
        diff_summary=None,
        modality="github",
        mime_type="text/markdown",
        sha256=create_content_hash(initial_content),
        size=len(initial_content),
        tags=["github"],
        pr_data=None,  # No pr_data initially
    )
    db_session.add(existing_item)
    db_session.commit()

    # Update with pr_data
    updated_data = GithubIssueData(
        kind="pr",
        number=301,
        title="Legacy PR",
        body="Original with new review",
        state="open",
        author="user",
        labels=[],
        assignees=[],
        milestone=None,
        created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
        closed_at=None,
        merged_at=None,
        github_updated_at=datetime(2024, 1, 2, tzinfo=timezone.utc),
        comment_count=0,
        comments=[],
        diff_summary="+10 -0",
        project_fields=None,
        content_hash="new_hash",  # Different
        pr_data={
            "diff": "the full diff",
            "files": [{"filename": "new.py", "status": "added", "additions": 10, "deletions": 0, "patch": None}],
            "additions": 10,
            "deletions": 0,
            "changed_files_count": 1,
            "reviews": [],
            "review_comments": [],
        },
    )

    serialized = _serialize_issue_data(updated_data)
    result = github.sync_github_item(github_repo.id, serialized)

    assert result["status"] == "processed"

    db_session.expire_all()
    item = db_session.query(GithubItem).filter_by(number=301).first()

    # Now should have pr_data
    assert item.pr_data is not None
    assert item.pr_data.additions == 10
    assert item.pr_data.diff == "the full diff"
404
tests/memory/workers/tasks/test_people_tasks.py
Normal file
@@ -0,0 +1,404 @@
"""Tests for people Celery tasks."""

import uuid
from contextlib import contextmanager
from unittest.mock import patch, MagicMock

import pytest

from memory.common.db.models import Person
from memory.common.db.models.source_item import Chunk
from memory.workers.tasks import people
from memory.workers.tasks.content_processing import create_content_hash


def _make_mock_chunk(source_id: int) -> Chunk:
    """Create a mock chunk for testing with a unique ID."""
    return Chunk(
        id=str(uuid.uuid4()),
        content="test chunk content",
        embedding_model="test-model",
        vector=[0.1] * 1024,
        item_metadata={"source_id": source_id, "tags": ["test"]},
        collection_name="person",
    )


@pytest.fixture
def mock_make_session(db_session):
    """Mock make_session and embedding functions for task tests."""

    @contextmanager
    def _mock_session():
        yield db_session

    with patch("memory.workers.tasks.people.make_session", _mock_session):
        # Mock embedding to return a fake chunk
        with patch(
            "memory.common.embedding.embed_source_item",
            side_effect=lambda item: [_make_mock_chunk(item.id or 1)],
        ):
            # Mock push_to_qdrant to do nothing
            with patch("memory.workers.tasks.content_processing.push_to_qdrant"):
                yield db_session


@pytest.fixture
def person_data():
    """Standard person test data."""
    return {
        "identifier": "alice_chen",
        "display_name": "Alice Chen",
        "aliases": ["@alice_c", "alice.chen@work.com"],
        "contact_info": {"email": "alice@example.com", "phone": "555-1234"},
        "tags": ["work", "engineering"],
        "notes": "Tech lead on Platform team.",
    }


@pytest.fixture
def minimal_person_data():
    """Minimal person test data."""
    return {
        "identifier": "bob_smith",
        "display_name": "Bob Smith",
    }


def test_sync_person_success(person_data, mock_make_session, qdrant):
    """Test successful person sync."""
    result = people.sync_person(**person_data)

    # Verify the Person was created in the database
    person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first()
    assert person is not None
    assert person.identifier == "alice_chen"
    assert person.display_name == "Alice Chen"
    assert person.aliases == ["@alice_c", "alice.chen@work.com"]
    assert person.contact_info == {"email": "alice@example.com", "phone": "555-1234"}
    assert person.tags == ["work", "engineering"]
    assert person.content == "Tech lead on Platform team."
    assert person.modality == "person"

    # Verify the result
    assert result["status"] == "processed"
    assert "person_id" in result


def test_sync_person_minimal_data(minimal_person_data, mock_make_session, qdrant):
    """Test person sync with minimal required data."""
    result = people.sync_person(**minimal_person_data)

    person = mock_make_session.query(Person).filter_by(identifier="bob_smith").first()
    assert person is not None
    assert person.identifier == "bob_smith"
    assert person.display_name == "Bob Smith"
    assert person.aliases == []
    assert person.contact_info == {}
    assert person.tags == []
    assert person.content is None

    assert result["status"] == "processed"


def test_sync_person_already_exists(person_data, mock_make_session, qdrant):
    """Test sync when person already exists."""
    # Create the person first
    sha256 = create_content_hash(f"person:{person_data['identifier']}")
    existing_person = Person(
        identifier=person_data["identifier"],
        display_name=person_data["display_name"],
        aliases=person_data["aliases"],
        contact_info=person_data["contact_info"],
        tags=person_data["tags"],
        content=person_data["notes"],
        modality="person",
        mime_type="text/plain",
        sha256=sha256,
        size=len(person_data["notes"]),
    )
    mock_make_session.add(existing_person)
    mock_make_session.commit()

    # Try to sync again
    result = people.sync_person(**person_data)

    assert result["status"] == "already_exists"
    assert result["person_id"] == existing_person.id

    # Verify no duplicate was created
    count = mock_make_session.query(Person).filter_by(identifier="alice_chen").count()
    assert count == 1


def test_update_person_display_name(person_data, mock_make_session, qdrant):
    """Test updating display name."""
    # Create person first
    people.sync_person(**person_data)

    # Update display name
    result = people.update_person(
        identifier="alice_chen",
        display_name="Alice M. Chen",
    )

    assert result["status"] == "processed"

    person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first()
    assert person.display_name == "Alice M. Chen"
    # Other fields unchanged
    assert person.aliases == ["@alice_c", "alice.chen@work.com"]


def test_update_person_merge_aliases(person_data, mock_make_session, qdrant):
    """Test that aliases are merged, not replaced."""
    # Create person first
    people.sync_person(**person_data)

    # Update with new aliases
    result = people.update_person(
        identifier="alice_chen",
        aliases=["@alice_chen", "alice@company.com"],
    )

    assert result["status"] == "processed"

    person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first()
    # Should be union of old and new
    assert set(person.aliases) == {
        "@alice_c",
        "alice.chen@work.com",
        "@alice_chen",
        "alice@company.com",
    }


def test_update_person_merge_contact_info(person_data, mock_make_session, qdrant):
    """Test that contact_info is deep merged."""
    # Create person first
    people.sync_person(**person_data)

    # Update with new contact info
    result = people.update_person(
        identifier="alice_chen",
        contact_info={"twitter": "@alice_c", "phone": "555-5678"},  # Update existing
    )

    assert result["status"] == "processed"

    person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first()
    # Should have all keys
    assert person.contact_info["email"] == "alice@example.com"  # Original
    assert person.contact_info["phone"] == "555-5678"  # Updated
    assert person.contact_info["twitter"] == "@alice_c"  # New


def test_update_person_merge_tags(person_data, mock_make_session, qdrant):
    """Test that tags are merged, not replaced."""
    # Create person first
    people.sync_person(**person_data)

    # Update with new tags
    result = people.update_person(
        identifier="alice_chen",
        tags=["climbing", "london"],
    )

    assert result["status"] == "processed"

    person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first()
    # Should be union of old and new
    assert set(person.tags) == {"work", "engineering", "climbing", "london"}


def test_update_person_append_notes(person_data, mock_make_session, qdrant):
    """Test that notes are appended by default."""
    # Create person first
    people.sync_person(**person_data)

    # Update with new notes
    result = people.update_person(
        identifier="alice_chen",
        notes="Also enjoys rock climbing.",
    )

    assert result["status"] == "processed"

    person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first()
    # Should be appended with separator
    assert "Tech lead on Platform team." in person.content
    assert "Also enjoys rock climbing." in person.content
    assert "---" in person.content


def test_update_person_replace_notes(person_data, mock_make_session, qdrant):
    """Test replacing notes instead of appending."""
    # Create person first
    people.sync_person(**person_data)

    # Replace notes
    result = people.update_person(
        identifier="alice_chen",
        notes="Completely new notes.",
        replace_notes=True,
    )

    assert result["status"] == "processed"

    person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first()
    assert person.content == "Completely new notes."
    assert "Tech lead" not in person.content


def test_update_person_not_found(mock_make_session, qdrant):
    """Test updating a person that doesn't exist."""
    result = people.update_person(
        identifier="nonexistent_person",
        display_name="New Name",
    )

    assert result["status"] == "not_found"
    assert result["identifier"] == "nonexistent_person"


def test_update_person_no_changes(person_data, mock_make_session, qdrant):
    """Test update with no actual changes."""
    # Create person first
    people.sync_person(**person_data)

    # Update with nothing
    result = people.update_person(identifier="alice_chen")

    assert result["status"] == "processed"

    person = mock_make_session.query(Person).filter_by(identifier="alice_chen").first()
    # Should be unchanged
    assert person.display_name == "Alice Chen"


@pytest.mark.parametrize(
    "identifier,display_name,tags",
    [
        ("john_doe", "John Doe", []),
        ("jane_smith", "Jane Smith", ["friend"]),
        ("bob_jones", "Bob Jones", ["work", "climbing", "london"]),
    ],
)
def test_sync_person_various_configurations(
    identifier, display_name, tags, mock_make_session, qdrant
):
    """Test sync_person with various configurations."""
    result = people.sync_person(
        identifier=identifier,
        display_name=display_name,
        tags=tags,
    )

    assert result["status"] == "processed"

    person = mock_make_session.query(Person).filter_by(identifier=identifier).first()
    assert person is not None
    assert person.display_name == display_name
    assert person.tags == tags


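# The notes tests above pin down the append semantics: new notes are joined to
# the existing content with a "---" separator, the separator is skipped when
# there were no previous notes, and replace_notes=True overwrites outright. A
# sketch under those assumptions (_merge_notes is illustrative, not the actual
# helper in memory.workers.tasks.people):
def _merge_notes(existing, notes, replace=False):
    if replace or not existing:
        return notes  # no separator when nothing to append to
    return f"{existing}\n\n---\n\n{notes}"

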
def test_deep_merge_helper():
    """Test the _deep_merge helper function."""
    base = {
        "a": 1,
        "b": {"c": 2, "d": 3},
        "e": 4,
    }
    updates = {
        "b": {"c": 5, "f": 6},
        "g": 7,
    }

    result = people._deep_merge(base, updates)

    assert result == {
        "a": 1,
        "b": {"c": 5, "d": 3, "f": 6},
        "e": 4,
        "g": 7,
    }


def test_deep_merge_nested():
    """Test deep merge with deeply nested structures."""
    base = {
        "level1": {
            "level2": {
                "level3": {"a": 1},
            },
        },
    }
    updates = {
        "level1": {
            "level2": {
                "level3": {"b": 2},
            },
        },
    }

    result = people._deep_merge(base, updates)

    assert result == {
        "level1": {
            "level2": {
                "level3": {"a": 1, "b": 2},
            },
        },
    }


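# A recursive merge consistent with both _deep_merge tests above: nested dicts
# are merged key by key, and non-dict values from the update win. This is a
# sketch matching the asserted behavior, not necessarily the actual
# implementation in memory.workers.tasks.people:
def _deep_merge_sketch(base, updates):
    result = dict(base)
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(result.get(key), dict):
            result[key] = _deep_merge_sketch(result[key], value)  # recurse
        else:
            result[key] = value  # update value replaces the base value
    return result

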
def test_sync_person_unicode(mock_make_session, qdrant):
    """Test sync_person with unicode content."""
    result = people.sync_person(
        identifier="unicode_person",
        display_name="日本語 名前",
        notes="Привет мир 🌍",
        tags=["日本語", "emoji"],
    )

    assert result["status"] == "processed"

    person = mock_make_session.query(Person).filter_by(identifier="unicode_person").first()
    assert person is not None
    assert person.display_name == "日本語 名前"
    assert person.content == "Привет мир 🌍"


def test_sync_person_empty_notes(mock_make_session, qdrant):
    """Test sync_person with empty notes."""
    result = people.sync_person(
        identifier="empty_notes_person",
        display_name="Empty Notes Person",
        notes="",
    )

    assert result["status"] == "processed"

    person = mock_make_session.query(Person).filter_by(identifier="empty_notes_person").first()
    assert person is not None
    assert person.content == ""


def test_update_person_first_notes(mock_make_session, qdrant):
    """Test adding notes to a person who had no notes."""
    # Create person without notes
    people.sync_person(
        identifier="no_notes_person",
        display_name="No Notes Person",
    )

    # Add notes
    result = people.update_person(
        identifier="no_notes_person",
        notes="First notes!",
    )

    assert result["status"] == "processed"

    person = mock_make_session.query(Person).filter_by(identifier="no_notes_person").first()
    assert person.content == "First notes!"
    # Should not have separator when there were no previous notes
    assert "---" not in person.content