From 526bfa5f6bcc4d76e73442b0c8873cc77b8384d6 Mon Sep 17 00:00:00 2001 From: mruwnik Date: Tue, 23 Dec 2025 20:02:10 +0000 Subject: [PATCH] more github ingesting --- docker-compose.yaml | 2 +- docker/workers/Dockerfile | 2 +- src/memory/api/admin.py | 5 +- src/memory/common/celery_app.py | 7 + src/memory/common/collections.py | 6 + src/memory/common/db/models/source_items.py | 40 + src/memory/workers/tasks/__init__.py | 2 + tests/memory/parsers/test_github.py | 705 +++++++++++ .../memory/workers/tasks/test_github_tasks.py | 1125 +++++++++++++++++ tools/run_celery_task.py | 39 +- 10 files changed, 1924 insertions(+), 9 deletions(-) create mode 100644 tests/memory/parsers/test_github.py create mode 100644 tests/memory/workers/tasks/test_github_tasks.py diff --git a/docker-compose.yaml b/docker-compose.yaml index fa7409c..f9ca190 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -206,7 +206,7 @@ services: <<: *worker-base environment: <<: *worker-env - QUEUES: "backup,email,ebooks,discord,comic,blogs,forums,maintenance,notes,scheduler" + QUEUES: "backup,blogs,comic,discord,ebooks,email,forums,github,photo_embed,maintenance,notes,scheduler" ingest-hub: <<: *worker-base diff --git a/docker/workers/Dockerfile b/docker/workers/Dockerfile index b3a0e5c..fa64a97 100644 --- a/docker/workers/Dockerfile +++ b/docker/workers/Dockerfile @@ -44,7 +44,7 @@ RUN git config --global user.email "${GIT_USER_EMAIL}" && \ git config --global user.name "${GIT_USER_NAME}" # Default queues to process -ENV QUEUES="backup,ebooks,email,discord,comic,blogs,forums,photo_embed,maintenance" +ENV QUEUES="backup,blogs,comic,discord,ebooks,email,forums,github,photo_embed,maintenance" ENV PYTHONPATH="/app" ENTRYPOINT ["./entry.sh"] \ No newline at end of file diff --git a/src/memory/api/admin.py b/src/memory/api/admin.py index 0d92861..af15ff1 100644 --- a/src/memory/api/admin.py +++ b/src/memory/api/admin.py @@ -351,9 +351,8 @@ class GithubAccountAdmin(ModelView, model=GithubAccount): "updated_at", ] column_searchable_list = ["name", "id"] - # Hide sensitive columns from display - column_exclude_list = ["access_token", "private_key"] - form_excluded_columns = ["repos"] + # Sensitive columns (access_token, private_key) are already excluded from column_list + form_excluded_columns = ["repos", "access_token", "private_key"] class GithubRepoAdmin(ModelView, model=GithubRepo): diff --git a/src/memory/common/celery_app.py b/src/memory/common/celery_app.py index f2a4dd0..f4c2c1f 100644 --- a/src/memory/common/celery_app.py +++ b/src/memory/common/celery_app.py @@ -1,4 +1,5 @@ from celery import Celery +from celery.schedules import crontab from kombu.utils.url import safequote from memory.common import settings @@ -123,6 +124,12 @@ app.conf.update( f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"}, f"{GITHUB_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-github"}, }, + beat_schedule={ + "sync-github-repos-hourly": { + "task": SYNC_ALL_GITHUB_REPOS, + "schedule": crontab(minute=0), # Every hour at :00 + }, + }, ) diff --git a/src/memory/common/collections.py b/src/memory/common/collections.py index cc7a71a..e916672 100644 --- a/src/memory/common/collections.py +++ b/src/memory/common/collections.py @@ -58,6 +58,12 @@ ALL_COLLECTIONS: dict[str, Collection] = { "text": True, "multimodal": True, }, + "github": { + "dimension": 1024, + "distance": "Cosine", + "text": True, + "multimodal": False, + }, "text": { "dimension": 1024, "distance": "Cosine", diff --git a/src/memory/common/db/models/source_items.py b/src/memory/common/db/models/source_items.py index e87e17f..e3ff57a 100644 --- a/src/memory/common/db/models/source_items.py +++ b/src/memory/common/db/models/source_items.py @@ -802,6 +802,23 @@ class MiscDoc(SourceItem): } +class GithubItemPayload(SourceItemPayload): + kind: Annotated[str, "Type: issue, pr, comment, or project_card"] + repo_path: Annotated[str, "Repository path (owner/name)"] + number: Annotated[int | None, "Issue or PR number"] + state: Annotated[str | None, "State: open, closed, merged"] + title: Annotated[str | None, "Issue or PR title"] + author: Annotated[str | None, "Author username"] + labels: Annotated[list[str] | None, "GitHub labels"] + assignees: Annotated[list[str] | None, "Assigned users"] + milestone: Annotated[str | None, "Milestone name"] + project_status: Annotated[str | None, "GitHub Project status"] + project_priority: Annotated[str | None, "GitHub Project priority"] + created_at: Annotated[datetime | None, "Creation date"] + closed_at: Annotated[datetime | None, "Close date"] + merged_at: Annotated[datetime | None, "Merge date (PRs only)"] + + class GithubItem(SourceItem): __tablename__ = "github_item" @@ -854,6 +871,29 @@ class GithubItem(SourceItem): Index("gh_repo_id_idx", "repo_id"), ) + @classmethod + def get_collections(cls) -> list[str]: + return ["github"] + + def as_payload(self) -> GithubItemPayload: + return GithubItemPayload( + **super().as_payload(), + kind=cast(str, self.kind), + repo_path=cast(str, self.repo_path), + number=cast(int | None, self.number), + state=cast(str | None, self.state), + title=cast(str | None, self.title), + author=cast(str | None, self.author), + labels=cast(list[str] | None, self.labels), + assignees=cast(list[str] | None, self.assignees), + milestone=cast(str | None, self.milestone), + project_status=cast(str | None, self.project_status), + project_priority=cast(str | None, self.project_priority), + created_at=cast(datetime | None, self.created_at), + closed_at=cast(datetime | None, self.closed_at), + merged_at=cast(datetime | None, self.merged_at), + ) + class NotePayload(SourceItemPayload): note_type: Annotated[str | None, "Category of the note"] diff --git a/src/memory/workers/tasks/__init__.py b/src/memory/workers/tasks/__init__.py index 41c142d..33f22aa 100644 --- a/src/memory/workers/tasks/__init__.py +++ b/src/memory/workers/tasks/__init__.py @@ -10,6 +10,7 @@ from memory.workers.tasks import ( ebook, email, forums, + github, maintenance, notes, observations, @@ -24,6 +25,7 @@ __all__ = [ "ebook", "discord", "forums", + "github", "maintenance", "notes", "observations", diff --git a/tests/memory/parsers/test_github.py b/tests/memory/parsers/test_github.py new file mode 100644 index 0000000..bc11eeb --- /dev/null +++ b/tests/memory/parsers/test_github.py @@ -0,0 +1,705 @@ +"""Tests for GitHub API client and parser.""" + +import pytest +from datetime import datetime, timezone +from unittest.mock import Mock, patch, MagicMock +import requests + +from memory.parsers.github import ( + GithubCredentials, + GithubClient, + GithubIssueData, + GithubComment, + parse_github_date, + compute_content_hash, +) + + +# ============================================================================= +# Tests for utility functions +# ============================================================================= + + +@pytest.mark.parametrize( + "date_str,expected", + [ + ("2024-01-15T10:30:00Z", datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc)), + ( + "2024-06-20T14:45:30Z", + datetime(2024, 6, 20, 14, 45, 30, tzinfo=timezone.utc), + ), + (None, None), + ("", None), + ], +) +def test_parse_github_date(date_str, expected): + """Test parsing GitHub date strings.""" + result = parse_github_date(date_str) + assert result == expected + + +def test_compute_content_hash_body_only(): + """Test content hash with body only.""" + hash1 = compute_content_hash("This is the body", []) + hash2 = compute_content_hash("This is the body", []) + hash3 = compute_content_hash("Different body", []) + + assert hash1 == hash2 # Same content = same hash + assert hash1 != hash3 # Different content = different hash + + +def test_compute_content_hash_with_comments(): + """Test content hash includes comments.""" + comments = [ + GithubComment( + id=1, + author="user1", + body="First comment", + created_at="2024-01-01T00:00:00Z", + updated_at="2024-01-01T00:00:00Z", + ), + GithubComment( + id=2, + author="user2", + body="Second comment", + created_at="2024-01-02T00:00:00Z", + updated_at="2024-01-02T00:00:00Z", + ), + ] + + hash_with_comments = compute_content_hash("Body", comments) + hash_without_comments = compute_content_hash("Body", []) + + assert hash_with_comments != hash_without_comments + + +def test_compute_content_hash_empty_body(): + """Test content hash with empty/None body.""" + hash1 = compute_content_hash("", []) + hash2 = compute_content_hash(None, []) # type: ignore + + # Both should produce valid hashes + assert len(hash1) == 64 # SHA256 hex + assert len(hash2) == 64 + + +def test_compute_content_hash_comment_order_matters(): + """Test that comment order affects the hash.""" + comment1 = GithubComment( + id=1, author="a", body="First", created_at="", updated_at="" + ) + comment2 = GithubComment( + id=2, author="b", body="Second", created_at="", updated_at="" + ) + + hash_order1 = compute_content_hash("Body", [comment1, comment2]) + hash_order2 = compute_content_hash("Body", [comment2, comment1]) + + assert hash_order1 != hash_order2 + + +# ============================================================================= +# Tests for GithubClient initialization +# ============================================================================= + + +def test_github_client_pat_auth(): + """Test client initialization with PAT authentication.""" + credentials = GithubCredentials( + auth_type="pat", + access_token="ghp_test_token", + ) + + with patch.object(requests.Session, "get"): + client = GithubClient(credentials) + + assert "Bearer ghp_test_token" in client.session.headers["Authorization"] + assert client.session.headers["Accept"] == "application/vnd.github+json" + assert client.session.headers["X-GitHub-Api-Version"] == "2022-11-28" + + +# ============================================================================= +# Tests for fetch_issues +# ============================================================================= + + +def test_fetch_issues_basic(): + """Test fetching issues from repository.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + def mock_get(url, **kwargs): + """Route mock responses based on URL.""" + response = Mock() + response.headers = {"X-RateLimit-Remaining": "4999"} + response.raise_for_status = Mock() + + page = kwargs.get("params", {}).get("page", 1) + + if "/repos/" in url and "/issues" in url and "/comments" not in url: + # Issues endpoint + if page == 1: + response.json.return_value = [ + { + "number": 1, + "title": "Test Issue", + "body": "Issue body", + "state": "open", + "user": {"login": "testuser"}, + "labels": [{"name": "bug"}], + "assignees": [{"login": "dev1"}], + "milestone": {"title": "v1.0"}, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", + "closed_at": None, + "comments": 2, + # Note: Do NOT include "pull_request" key for real issues + # The API checks `if "pull_request" in issue` to skip PRs + } + ] + else: + response.json.return_value = [] + elif "/comments" in url: + # Comments endpoint + if page == 1: + response.json.return_value = [ + { + "id": 100, + "user": {"login": "commenter"}, + "body": "A comment", + "created_at": "2024-01-01T12:00:00Z", + "updated_at": "2024-01-01T12:00:00Z", + } + ] + else: + response.json.return_value = [] + else: + response.json.return_value = [] + + return response + + with patch.object(requests.Session, "get", side_effect=mock_get): + client = GithubClient(credentials) + issues = list(client.fetch_issues("owner", "repo")) + + assert len(issues) == 1 + issue = issues[0] + assert issue["number"] == 1 + assert issue["title"] == "Test Issue" + assert issue["kind"] == "issue" + assert issue["state"] == "open" + assert issue["author"] == "testuser" + assert issue["labels"] == ["bug"] + assert issue["assignees"] == ["dev1"] + assert issue["milestone"] == "v1.0" + assert len(issue["comments"]) == 1 + + +def test_fetch_issues_skips_prs(): + """Test that PRs in issue list are skipped.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + def mock_get(url, **kwargs): + """Route mock responses based on URL.""" + response = Mock() + response.headers = {"X-RateLimit-Remaining": "4999"} + response.raise_for_status = Mock() + + page = kwargs.get("params", {}).get("page", 1) + + if "/repos/" in url and "/issues" in url and "/comments" not in url: + if page == 1: + response.json.return_value = [ + { + "number": 1, + "title": "Issue", + "body": "Body", + "state": "open", + "user": {"login": "user"}, + "labels": [], + "assignees": [], + "milestone": None, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + "closed_at": None, + "comments": 0, + # Real issues don't have "pull_request" key + }, + { + "number": 2, + "title": "PR posing as issue", + "body": "Body", + "state": "open", + "user": {"login": "user"}, + "labels": [], + "assignees": [], + "milestone": None, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + "closed_at": None, + "comments": 0, + "pull_request": {"url": "https://..."}, # PRs have this key + }, + ] + else: + response.json.return_value = [] + elif "/comments" in url: + response.json.return_value = [] + else: + response.json.return_value = [] + + return response + + with patch.object(requests.Session, "get", side_effect=mock_get): + client = GithubClient(credentials) + issues = list(client.fetch_issues("owner", "repo")) + + assert len(issues) == 1 + assert issues[0]["number"] == 1 + + +def test_fetch_issues_with_since_filter(): + """Test fetching issues with since parameter.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.json.return_value = [] + mock_response.headers = {"X-RateLimit-Remaining": "4999"} + mock_response.raise_for_status = Mock() + + with patch.object(requests.Session, "get") as mock_get: + mock_get.return_value = mock_response + + client = GithubClient(credentials) + since = datetime(2024, 1, 15, tzinfo=timezone.utc) + list(client.fetch_issues("owner", "repo", since=since)) + + # Verify since was passed to API + call_args = mock_get.call_args + assert "since" in call_args.kwargs.get("params", {}) + + +def test_fetch_issues_with_state_filter(): + """Test fetching issues with state filter.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.json.return_value = [] + mock_response.headers = {"X-RateLimit-Remaining": "4999"} + mock_response.raise_for_status = Mock() + + with patch.object(requests.Session, "get") as mock_get: + mock_get.return_value = mock_response + + client = GithubClient(credentials) + list(client.fetch_issues("owner", "repo", state="closed")) + + call_args = mock_get.call_args + assert call_args.kwargs.get("params", {}).get("state") == "closed" + + +def test_fetch_issues_with_labels_filter(): + """Test fetching issues with labels filter.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.json.return_value = [] + mock_response.headers = {"X-RateLimit-Remaining": "4999"} + mock_response.raise_for_status = Mock() + + with patch.object(requests.Session, "get") as mock_get: + mock_get.return_value = mock_response + + client = GithubClient(credentials) + list(client.fetch_issues("owner", "repo", labels=["bug", "critical"])) + + call_args = mock_get.call_args + assert call_args.kwargs.get("params", {}).get("labels") == "bug,critical" + + +# ============================================================================= +# Tests for fetch_prs +# ============================================================================= + + +def test_fetch_prs_basic(): + """Test fetching PRs from repository.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + def mock_get(url, **kwargs): + """Route mock responses based on URL.""" + response = Mock() + response.headers = {"X-RateLimit-Remaining": "4999"} + response.raise_for_status = Mock() + + page = kwargs.get("params", {}).get("page", 1) + + if "/pulls" in url and "/comments" not in url: + if page == 1: + response.json.return_value = [ + { + "number": 10, + "title": "Add feature", + "body": "PR body", + "state": "open", + "user": {"login": "contributor"}, + "labels": [{"name": "enhancement"}], + "assignees": [{"login": "reviewer"}], + "milestone": None, + "created_at": "2024-01-05T00:00:00Z", + "updated_at": "2024-01-06T00:00:00Z", + "closed_at": None, + "merged_at": None, + "diff_url": "https://github.com/owner/repo/pull/10.diff", + "comments": 0, + } + ] + else: + response.json.return_value = [] + elif ".diff" in url: + response.ok = True + response.text = "+100 lines added\n-50 lines removed" + elif "/comments" in url: + response.json.return_value = [] + else: + response.json.return_value = [] + + return response + + with patch.object(requests.Session, "get", side_effect=mock_get): + client = GithubClient(credentials) + prs = list(client.fetch_prs("owner", "repo")) + + assert len(prs) == 1 + pr = prs[0] + assert pr["number"] == 10 + assert pr["title"] == "Add feature" + assert pr["kind"] == "pr" + assert pr["diff_summary"] is not None + assert "100 lines added" in pr["diff_summary"] + + +def test_fetch_prs_merged(): + """Test fetching merged PR.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.json.return_value = [ + { + "number": 20, + "title": "Merged PR", + "body": "Body", + "state": "closed", + "user": {"login": "user"}, + "labels": [], + "assignees": [], + "milestone": None, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-10T00:00:00Z", + "closed_at": "2024-01-10T00:00:00Z", + "merged_at": "2024-01-10T00:00:00Z", + "additions": 10, + "deletions": 5, + "comments": 0, + } + ] + mock_response.headers = {"X-RateLimit-Remaining": "4999"} + mock_response.raise_for_status = Mock() + + mock_empty = Mock() + mock_empty.json.return_value = [] + mock_empty.headers = {"X-RateLimit-Remaining": "4998"} + mock_empty.raise_for_status = Mock() + + with patch.object(requests.Session, "get") as mock_get: + mock_get.side_effect = [mock_response, mock_empty, mock_empty] + + client = GithubClient(credentials) + prs = list(client.fetch_prs("owner", "repo")) + + pr = prs[0] + assert pr["state"] == "closed" + assert pr["merged_at"] == datetime(2024, 1, 10, 0, 0, 0, tzinfo=timezone.utc) + + +def test_fetch_prs_stops_at_since(): + """Test that PR fetching stops when reaching older items.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.json.return_value = [ + { + "number": 30, + "title": "Recent PR", + "body": "Body", + "state": "open", + "user": {"login": "user"}, + "labels": [], + "assignees": [], + "milestone": None, + "created_at": "2024-01-20T00:00:00Z", + "updated_at": "2024-01-20T00:00:00Z", + "closed_at": None, + "merged_at": None, + "additions": 1, + "deletions": 1, + "comments": 0, + }, + { + "number": 29, + "title": "Old PR", + "body": "Body", + "state": "open", + "user": {"login": "user"}, + "labels": [], + "assignees": [], + "milestone": None, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", # Older than since + "closed_at": None, + "merged_at": None, + "additions": 1, + "deletions": 1, + "comments": 0, + }, + ] + mock_response.headers = {"X-RateLimit-Remaining": "4999"} + mock_response.raise_for_status = Mock() + + mock_empty = Mock() + mock_empty.json.return_value = [] + mock_empty.headers = {"X-RateLimit-Remaining": "4998"} + mock_empty.raise_for_status = Mock() + + with patch.object(requests.Session, "get") as mock_get: + mock_get.side_effect = [mock_response, mock_empty] + + client = GithubClient(credentials) + since = datetime(2024, 1, 15, tzinfo=timezone.utc) + prs = list(client.fetch_prs("owner", "repo", since=since)) + + # Should only get the recent PR, stop at the old one + assert len(prs) == 1 + assert prs[0]["number"] == 30 + + +# ============================================================================= +# Tests for fetch_comments +# ============================================================================= + + +def test_fetch_comments_pagination(): + """Test comment fetching with pagination.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + # First page of comments + mock_page1 = Mock() + mock_page1.json.return_value = [ + { + "id": 1, + "user": {"login": "user1"}, + "body": "Comment 1", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + } + ] + mock_page1.headers = {"X-RateLimit-Remaining": "4999"} + mock_page1.raise_for_status = Mock() + + # Second page of comments + mock_page2 = Mock() + mock_page2.json.return_value = [ + { + "id": 2, + "user": {"login": "user2"}, + "body": "Comment 2", + "created_at": "2024-01-02T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", + } + ] + mock_page2.headers = {"X-RateLimit-Remaining": "4998"} + mock_page2.raise_for_status = Mock() + + # Empty page to stop + mock_empty = Mock() + mock_empty.json.return_value = [] + mock_empty.headers = {"X-RateLimit-Remaining": "4997"} + mock_empty.raise_for_status = Mock() + + with patch.object(requests.Session, "get") as mock_get: + mock_get.side_effect = [mock_page1, mock_page2, mock_empty] + + client = GithubClient(credentials) + comments = client.fetch_comments("owner", "repo", 1) + + assert len(comments) == 2 + assert comments[0]["author"] == "user1" + assert comments[1]["author"] == "user2" + + +def test_fetch_comments_handles_ghost_user(): + """Test comment with deleted/ghost user.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.json.return_value = [ + { + "id": 1, + "user": None, # Deleted user + "body": "Comment from ghost", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + } + ] + mock_response.headers = {"X-RateLimit-Remaining": "4999"} + mock_response.raise_for_status = Mock() + + mock_empty = Mock() + mock_empty.json.return_value = [] + mock_empty.headers = {"X-RateLimit-Remaining": "4998"} + mock_empty.raise_for_status = Mock() + + with patch.object(requests.Session, "get") as mock_get: + mock_get.side_effect = [mock_response, mock_empty] + + client = GithubClient(credentials) + comments = client.fetch_comments("owner", "repo", 1) + + assert len(comments) == 1 + assert comments[0]["author"] == "ghost" + + +# ============================================================================= +# Tests for rate limiting +# ============================================================================= + + +def test_rate_limit_handling(): + """Test rate limit detection and backoff.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.json.return_value = [] + mock_response.headers = { + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": str(int(datetime.now(timezone.utc).timestamp()) + 1), + } + mock_response.raise_for_status = Mock() + + with patch.object(requests.Session, "get") as mock_get: + mock_get.return_value = mock_response + with patch("time.sleep") as mock_sleep: + client = GithubClient(credentials) + list(client.fetch_issues("owner", "repo")) + + # Should have waited due to rate limit + mock_sleep.assert_called() + + +# ============================================================================= +# Tests for project fields +# ============================================================================= + + +def test_fetch_project_fields(): + """Test fetching GitHub Projects v2 fields.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.json.return_value = { + "data": { + "repository": { + "issue": { + "projectItems": { + "nodes": [ + { + "project": {"title": "Sprint Board"}, + "fieldValues": { + "nodes": [ + {"field": {"name": "Status"}, "name": "In Progress"}, + {"field": {"name": "Priority"}, "text": "High"}, + ] + }, + } + ] + } + } + } + } + } + mock_response.headers = {"X-RateLimit-Remaining": "4999"} + mock_response.raise_for_status = Mock() + + with patch.object(requests.Session, "post") as mock_post: + mock_post.return_value = mock_response + + client = GithubClient(credentials) + fields = client.fetch_project_fields("owner", "repo", 1) + + assert fields is not None + # Fields are prefixed with project name + assert "Sprint Board.Status" in fields + assert fields["Sprint Board.Status"] == "In Progress" + assert "Sprint Board.Priority" in fields + assert fields["Sprint Board.Priority"] == "High" + + +def test_fetch_project_fields_not_in_project(): + """Test fetching project fields for issue not in any project.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.json.return_value = { + "data": {"repository": {"issue": {"projectItems": {"nodes": []}}}} + } + mock_response.headers = {"X-RateLimit-Remaining": "4999"} + mock_response.raise_for_status = Mock() + + with patch.object(requests.Session, "post") as mock_post: + mock_post.return_value = mock_response + + client = GithubClient(credentials) + fields = client.fetch_project_fields("owner", "repo", 1) + + assert fields is None + + +def test_fetch_project_fields_graphql_error(): + """Test handling GraphQL errors gracefully.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.json.return_value = { + "errors": [{"message": "Something went wrong"}], + "data": None, + } + mock_response.headers = {"X-RateLimit-Remaining": "4999"} + mock_response.raise_for_status = Mock() + + with patch.object(requests.Session, "post") as mock_post: + mock_post.return_value = mock_response + + client = GithubClient(credentials) + fields = client.fetch_project_fields("owner", "repo", 1) + + assert fields is None + + +# ============================================================================= +# Tests for error handling +# ============================================================================= + + +def test_fetch_issues_handles_api_error(): + """Test graceful handling of API errors.""" + credentials = GithubCredentials(auth_type="pat", access_token="token") + + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found") + + with patch.object(requests.Session, "get") as mock_get: + mock_get.return_value = mock_response + + client = GithubClient(credentials) + + with pytest.raises(requests.HTTPError): + list(client.fetch_issues("owner", "nonexistent")) diff --git a/tests/memory/workers/tasks/test_github_tasks.py b/tests/memory/workers/tasks/test_github_tasks.py new file mode 100644 index 0000000..571cf1d --- /dev/null +++ b/tests/memory/workers/tasks/test_github_tasks.py @@ -0,0 +1,1125 @@ +"""Tests for GitHub issue/PR syncing tasks.""" + +import pytest +from datetime import datetime, timedelta, timezone +from unittest.mock import Mock, patch + +from memory.common.db.models import GithubItem +from memory.common.db.models.sources import GithubAccount, GithubRepo +from memory.workers.tasks import github +from memory.workers.tasks.github import ( + _build_content, + _needs_reindex, + _serialize_issue_data, + _deserialize_issue_data, +) +from memory.parsers.github import GithubIssueData, GithubComment +from memory.common.db import connection as db_connection + + +@pytest.fixture(autouse=True) +def reset_db_cache(): + """Reset the cached database engine between tests. + + The db connection module caches the engine globally, which can cause + issues when test databases are created/dropped between tests. + """ + # Reset before test + db_connection._engine = None + db_connection._session_factory = None + db_connection._scoped_session = None + yield + # Reset after test + db_connection._engine = None + db_connection._session_factory = None + db_connection._scoped_session = None + + +@pytest.fixture +def mock_github_comment() -> GithubComment: + """Mock comment data.""" + return GithubComment( + id=1001, + author="commenter", + body="This is a comment on the issue.", + created_at="2024-01-01T12:30:00Z", + updated_at="2024-01-01T12:30:00Z", + ) + + +@pytest.fixture +def mock_issue_data(mock_github_comment) -> GithubIssueData: + """Mock issue data for testing.""" + return GithubIssueData( + kind="issue", + number=42, + title="Test Issue Title", + body="This is the issue body with some content to test.", + state="open", + author="testuser", + labels=["bug", "help wanted"], + assignees=["developer1"], + milestone="v1.0", + created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime(2024, 1, 2, 10, 0, 0, tzinfo=timezone.utc), + comment_count=1, + comments=[mock_github_comment], + diff_summary=None, + project_fields=None, + content_hash="abc123hash", + ) + + +@pytest.fixture +def mock_pr_data() -> GithubIssueData: + """Mock PR data for testing.""" + return GithubIssueData( + kind="pr", + number=123, + title="Add new feature", + body="This PR adds a new feature to the project.", + state="open", + author="contributor", + labels=["enhancement"], + assignees=["reviewer1", "reviewer2"], + milestone="v2.0", + created_at=datetime(2024, 1, 5, 9, 0, 0, tzinfo=timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime(2024, 1, 6, 14, 0, 0, tzinfo=timezone.utc), + comment_count=0, + comments=[], + diff_summary="+100 -50", + project_fields={"Status": "In Progress", "Priority": "High"}, + content_hash="pr123hash", + ) + + +@pytest.fixture +def mock_closed_issue_data() -> GithubIssueData: + """Mock closed issue data.""" + return GithubIssueData( + kind="issue", + number=10, + title="Fixed Bug", + body="This bug has been fixed.", + state="closed", + author="reporter", + labels=["bug", "fixed"], + assignees=[], + milestone=None, + created_at=datetime(2023, 12, 1, 12, 0, 0, tzinfo=timezone.utc), + closed_at=datetime(2023, 12, 15, 18, 0, 0, tzinfo=timezone.utc), + merged_at=None, + github_updated_at=datetime(2023, 12, 15, 18, 0, 0, tzinfo=timezone.utc), + comment_count=0, + comments=[], + diff_summary=None, + project_fields=None, + content_hash="closedhash", + ) + + +@pytest.fixture +def github_account(db_session) -> GithubAccount: + """Create a GitHub account for testing.""" + account = GithubAccount( + name="Test Account", + auth_type="pat", + access_token="ghp_test_token_12345", + active=True, + ) + db_session.add(account) + db_session.commit() + return account + + +@pytest.fixture +def inactive_github_account(db_session) -> GithubAccount: + """Create an inactive GitHub account.""" + account = GithubAccount( + name="Inactive Account", + auth_type="pat", + access_token="ghp_inactive_token", + active=False, + ) + db_session.add(account) + db_session.commit() + return account + + +@pytest.fixture +def github_repo(db_session, github_account) -> GithubRepo: + """Create a GitHub repo for testing.""" + repo = GithubRepo( + account_id=github_account.id, + owner="testorg", + name="testrepo", + track_issues=True, + track_prs=True, + track_comments=True, + track_project_fields=False, + labels_filter=[], + state_filter=None, + tags=["github", "test"], + check_interval=60, + full_sync_interval=1440, + active=True, + last_sync_at=None, + last_full_sync_at=None, + ) + db_session.add(repo) + db_session.commit() + return repo + + +@pytest.fixture +def inactive_github_repo(db_session, github_account) -> GithubRepo: + """Create an inactive GitHub repo.""" + repo = GithubRepo( + account_id=github_account.id, + owner="testorg", + name="inactiverepo", + track_issues=True, + track_prs=True, + active=False, + ) + db_session.add(repo) + db_session.commit() + return repo + + +@pytest.fixture +def github_repo_with_project_fields(db_session, github_account) -> GithubRepo: + """Create a GitHub repo with project field tracking enabled.""" + repo = GithubRepo( + account_id=github_account.id, + owner="testorg", + name="projectrepo", + track_issues=True, + track_prs=True, + track_comments=True, + track_project_fields=True, + labels_filter=[], + state_filter=None, + tags=["project"], + check_interval=60, + full_sync_interval=1440, + active=True, + last_sync_at=None, + last_full_sync_at=None, + ) + db_session.add(repo) + db_session.commit() + return repo + + +@pytest.fixture +def mock_github_client(): + """Mock GitHub client for testing.""" + client = Mock() + client.fetch_issues.return_value = iter([]) + client.fetch_prs.return_value = iter([]) + client.fetch_project_fields.return_value = None + client.fetch_pr_project_fields.return_value = None + return client + + +# ============================================================================= +# Tests for helper functions +# ============================================================================= + + +def test_build_content_basic(mock_issue_data): + """Test content building from issue data.""" + content = _build_content(mock_issue_data) + + assert "# Test Issue Title" in content + assert "This is the issue body with some content to test." in content + assert "**commenter**: This is a comment on the issue." in content + + +def test_build_content_no_comments(): + """Test content building with no comments.""" + data = GithubIssueData( + kind="issue", + number=1, + title="Simple Issue", + body="Body text", + state="open", + author="user", + labels=[], + assignees=[], + milestone=None, + created_at=datetime.now(timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime.now(timezone.utc), + comment_count=0, + comments=[], + diff_summary=None, + project_fields=None, + content_hash="hash", + ) + content = _build_content(data) + + assert "# Simple Issue" in content + assert "Body text" in content + assert "---" not in content # No comment separator + + +def test_serialize_deserialize_issue_data(mock_issue_data): + """Test serialization and deserialization roundtrip.""" + serialized = _serialize_issue_data(mock_issue_data) + deserialized = _deserialize_issue_data(serialized) + + assert deserialized["kind"] == mock_issue_data["kind"] + assert deserialized["number"] == mock_issue_data["number"] + assert deserialized["title"] == mock_issue_data["title"] + assert deserialized["body"] == mock_issue_data["body"] + assert deserialized["state"] == mock_issue_data["state"] + assert deserialized["author"] == mock_issue_data["author"] + assert deserialized["labels"] == mock_issue_data["labels"] + assert deserialized["created_at"] == mock_issue_data["created_at"] + assert deserialized["github_updated_at"] == mock_issue_data["github_updated_at"] + + +def test_serialize_handles_none_dates(): + """Test serialization handles None dates correctly.""" + data = GithubIssueData( + kind="issue", + number=1, + title="Test", + body="Body", + state="open", + author="user", + labels=[], + assignees=[], + milestone=None, + created_at=datetime.now(timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime.now(timezone.utc), + comment_count=0, + comments=[], + diff_summary=None, + project_fields=None, + content_hash="hash", + ) + serialized = _serialize_issue_data(data) + + assert serialized["closed_at"] is None + assert serialized["merged_at"] is None + + +# ============================================================================= +# Tests for _needs_reindex +# ============================================================================= + + +def test_needs_reindex_content_hash_changed(github_repo, db_session): + """Test reindex triggered by content hash change.""" + existing = GithubItem( + repo_path="testorg/testrepo", + repo_id=github_repo.id, + number=42, + kind="issue", + title="Old Title", + content_hash="oldhash", + github_updated_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + project_fields=None, + modality="text", + sha256=b"x" * 32, + ) + db_session.add(existing) + db_session.commit() + + new_data = GithubIssueData( + kind="issue", + number=42, + title="New Title", + body="New body", + state="open", + author="user", + labels=[], + assignees=[], + milestone=None, + created_at=datetime.now(timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + comment_count=0, + comments=[], + diff_summary=None, + project_fields=None, + content_hash="newhash", # Different hash + ) + + assert _needs_reindex(existing, new_data) is True + + +def test_needs_reindex_github_updated_at_newer(github_repo, db_session): + """Test reindex triggered by newer github_updated_at.""" + existing = GithubItem( + repo_path="testorg/testrepo", + repo_id=github_repo.id, + number=42, + kind="issue", + title="Title", + content_hash="samehash", + github_updated_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + project_fields=None, + modality="text", + sha256=b"x" * 32, + ) + db_session.add(existing) + db_session.commit() + + new_data = GithubIssueData( + kind="issue", + number=42, + title="Title", + body="Body", + state="open", + author="user", + labels=[], + assignees=[], + milestone=None, + created_at=datetime.now(timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime(2024, 1, 2, tzinfo=timezone.utc), # Newer + comment_count=0, + comments=[], + diff_summary=None, + project_fields=None, + content_hash="samehash", + ) + + assert _needs_reindex(existing, new_data) is True + + +def test_needs_reindex_project_fields_changed(github_repo, db_session): + """Test reindex triggered by project field changes.""" + existing = GithubItem( + repo_path="testorg/testrepo", + repo_id=github_repo.id, + number=42, + kind="issue", + title="Title", + content_hash="samehash", + github_updated_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + project_fields={"Status": "Todo"}, + modality="text", + sha256=b"x" * 32, + ) + db_session.add(existing) + db_session.commit() + + new_data = GithubIssueData( + kind="issue", + number=42, + title="Title", + body="Body", + state="open", + author="user", + labels=[], + assignees=[], + milestone=None, + created_at=datetime.now(timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + comment_count=0, + comments=[], + diff_summary=None, + project_fields={"Status": "In Progress"}, # Changed + content_hash="samehash", + ) + + assert _needs_reindex(existing, new_data) is True + + +def test_needs_reindex_no_changes(github_repo, db_session): + """Test no reindex when nothing changed.""" + existing = GithubItem( + repo_path="testorg/testrepo", + repo_id=github_repo.id, + number=42, + kind="issue", + title="Title", + content_hash="samehash", + github_updated_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + project_fields={"Status": "Todo"}, + modality="text", + sha256=b"x" * 32, + ) + db_session.add(existing) + db_session.commit() + + new_data = GithubIssueData( + kind="issue", + number=42, + title="Title", + body="Body", + state="open", + author="user", + labels=[], + assignees=[], + milestone=None, + created_at=datetime.now(timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime(2024, 1, 1, tzinfo=timezone.utc), # Same + comment_count=0, + comments=[], + diff_summary=None, + project_fields={"Status": "Todo"}, # Same + content_hash="samehash", # Same + ) + + assert _needs_reindex(existing, new_data) is False + + +# ============================================================================= +# Tests for sync_github_item +# ============================================================================= + + +def test_sync_github_item_new_issue(mock_issue_data, github_repo, db_session, qdrant): + """Test syncing a new GitHub issue.""" + serialized = _serialize_issue_data(mock_issue_data) + + result = github.sync_github_item(github_repo.id, serialized) + + assert result["status"] == "processed" + + # Verify item was created + item = ( + db_session.query(GithubItem) + .filter_by(repo_path="testorg/testrepo", number=42) + .first() + ) + assert item is not None + assert item.title == "Test Issue Title" + assert item.kind == "issue" + assert item.state == "open" + assert item.author == "testuser" + assert "bug" in item.labels + assert "github" in item.tags # From repo tags + assert "bug" in item.tags # From issue labels + + +def test_sync_github_item_new_pr(mock_pr_data, github_repo, db_session, qdrant): + """Test syncing a new GitHub PR.""" + serialized = _serialize_issue_data(mock_pr_data) + + result = github.sync_github_item(github_repo.id, serialized) + + assert result["status"] == "processed" + + # Verify item was created + item = ( + db_session.query(GithubItem) + .filter_by(repo_path="testorg/testrepo", number=123, kind="pr") + .first() + ) + assert item is not None + assert item.title == "Add new feature" + assert item.kind == "pr" + assert item.diff_summary == "+100 -50" + assert item.project_status == "In Progress" + assert item.project_priority == "High" + + +def test_sync_github_item_repo_not_found(mock_issue_data, db_session): + """Test syncing with non-existent repo.""" + serialized = _serialize_issue_data(mock_issue_data) + + result = github.sync_github_item(99999, serialized) + + assert result["status"] == "error" + assert "Repo not found" in result["error"] + + +def test_sync_github_item_existing_unchanged( + mock_issue_data, github_repo, db_session, qdrant +): + """Test syncing existing item with no changes.""" + # Create existing item + serialized = _serialize_issue_data(mock_issue_data) + github.sync_github_item(github_repo.id, serialized) + + # Sync again with same data + result = github.sync_github_item(github_repo.id, serialized) + + assert result["status"] == "unchanged" + + +def test_sync_github_item_existing_updated(github_repo, db_session, qdrant): + """Test syncing existing item with content changes.""" + from memory.workers.tasks.content_processing import create_content_hash + + # Create existing item directly in the test database + existing_item = GithubItem( + repo_path="testorg/testrepo", + repo_id=github_repo.id, + number=99, + kind="issue", + title="Original Title", + content="# Original Title\n\nOriginal body", + state="open", + author="user", + labels=["bug"], + assignees=[], + milestone=None, + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + github_updated_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + comment_count=0, + content_hash="originalhash", + modality="text", + mime_type="text/markdown", + sha256=create_content_hash("# Original Title\n\nOriginal body"), + size=100, + tags=["github", "test", "bug"], + ) + db_session.add(existing_item) + db_session.commit() + + # Update with new content + updated_data = GithubIssueData( + kind="issue", + number=99, + title="Updated Title", + body="Updated body with more content", + state="open", + author="user", + labels=["bug", "fixed"], + assignees=["dev1"], + milestone=None, + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime(2024, 1, 5, tzinfo=timezone.utc), + comment_count=0, + comments=[], + diff_summary=None, + project_fields=None, + content_hash="updatedhash", # Different hash triggers reindex + ) + serialized = _serialize_issue_data(updated_data) + result = github.sync_github_item(github_repo.id, serialized) + + assert result["status"] == "processed" + + # Verify item was updated - query fresh from DB + db_session.expire_all() + item = ( + db_session.query(GithubItem) + .filter_by(repo_path="testorg/testrepo", number=99) + .first() + ) + assert item.title == "Updated Title" + assert "fixed" in item.labels + assert "dev1" in item.assignees + + +# ============================================================================= +# Tests for sync_github_repo +# ============================================================================= + + +@patch("memory.workers.tasks.github.GithubClient") +def test_sync_github_repo_success( + mock_client_class, mock_issue_data, github_repo, db_session +): + """Test successful repo sync.""" + mock_client = Mock() + mock_client.fetch_issues.return_value = iter([mock_issue_data]) + mock_client.fetch_prs.return_value = iter([]) + mock_client_class.return_value = mock_client + + with patch("memory.workers.tasks.github.sync_github_item") as mock_sync_item: + mock_sync_item.delay.return_value = Mock(id="task-123") + + result = github.sync_github_repo(github_repo.id) + + assert result["status"] == "completed" + assert result["sync_type"] == "incremental" + assert result["repo_path"] == "testorg/testrepo" + assert result["issues_synced"] == 1 + assert result["prs_synced"] == 0 + assert result["task_ids"] == ["task-123"] + + # Verify sync_github_item was called + mock_sync_item.delay.assert_called_once() + + +@patch("memory.workers.tasks.github.GithubClient") +def test_sync_github_repo_with_prs( + mock_client_class, mock_issue_data, mock_pr_data, github_repo, db_session +): + """Test repo sync with both issues and PRs.""" + mock_client = Mock() + mock_client.fetch_issues.return_value = iter([mock_issue_data]) + mock_client.fetch_prs.return_value = iter([mock_pr_data]) + mock_client_class.return_value = mock_client + + with patch("memory.workers.tasks.github.sync_github_item") as mock_sync_item: + mock_sync_item.delay.side_effect = [Mock(id="task-1"), Mock(id="task-2")] + + result = github.sync_github_repo(github_repo.id) + + assert result["issues_synced"] == 1 + assert result["prs_synced"] == 1 + assert len(result["task_ids"]) == 2 + + +def test_sync_github_repo_not_found(db_session): + """Test sync with non-existent repo.""" + result = github.sync_github_repo(99999) + + assert result["status"] == "error" + assert "Repo not found or inactive" in result["error"] + + +def test_sync_github_repo_inactive(inactive_github_repo, db_session): + """Test sync with inactive repo.""" + result = github.sync_github_repo(inactive_github_repo.id) + + assert result["status"] == "error" + assert "Repo not found or inactive" in result["error"] + + +def test_sync_github_repo_inactive_account(db_session, inactive_github_account): + """Test sync with inactive account.""" + repo = GithubRepo( + account_id=inactive_github_account.id, + owner="testorg", + name="repo", + active=True, + ) + db_session.add(repo) + db_session.commit() + + result = github.sync_github_repo(repo.id) + + assert result["status"] == "error" + assert "Account not found or inactive" in result["error"] + + +@pytest.mark.parametrize( + "check_interval_minutes,seconds_since_check,should_skip", + [ + (60, 30, True), # 60min interval, checked 30s ago -> skip + (60, 3000, True), # 60min interval, checked 50min ago -> skip + (60, 4000, False), # 60min interval, checked 66min ago -> don't skip + (30, 1000, True), # 30min interval, checked 16min ago -> skip + (30, 2000, False), # 30min interval, checked 33min ago -> don't skip + ], +) +@patch("memory.workers.tasks.github.GithubClient") +def test_sync_github_repo_check_interval( + mock_client_class, + check_interval_minutes, + seconds_since_check, + should_skip, + github_account, + db_session, +): + """Test sync respects check interval.""" + from sqlalchemy import text + + # Setup mock client for non-skipped cases + mock_client = Mock() + mock_client.fetch_issues.return_value = iter([]) + mock_client.fetch_prs.return_value = iter([]) + mock_client_class.return_value = mock_client + + # Create repo with specific check interval + repo = GithubRepo( + account_id=github_account.id, + owner="testorg", + name="intervalrepo", + track_issues=True, + track_prs=True, + check_interval=check_interval_minutes, + active=True, + ) + db_session.add(repo) + db_session.flush() + + # Set last_sync_at + last_sync_time = datetime.now(timezone.utc) - timedelta(seconds=seconds_since_check) + db_session.execute( + text("UPDATE github_repos SET last_sync_at = :timestamp WHERE id = :repo_id"), + {"timestamp": last_sync_time, "repo_id": repo.id}, + ) + db_session.commit() + + result = github.sync_github_repo(repo.id) + + if should_skip: + assert result["status"] == "skipped_recent_check" + mock_client_class.assert_not_called() + else: + assert result["status"] == "completed" + + +@patch("memory.workers.tasks.github.GithubClient") +def test_sync_github_repo_force_full(mock_client_class, github_repo, db_session): + """Test force_full bypasses check interval.""" + from sqlalchemy import text + + mock_client = Mock() + mock_client.fetch_issues.return_value = iter([]) + mock_client.fetch_prs.return_value = iter([]) + mock_client_class.return_value = mock_client + + # Set recent last_sync_at + last_sync_time = datetime.now(timezone.utc) - timedelta(seconds=30) + db_session.execute( + text("UPDATE github_repos SET last_sync_at = :timestamp WHERE id = :repo_id"), + {"timestamp": last_sync_time, "repo_id": github_repo.id}, + ) + db_session.commit() + + result = github.sync_github_repo(github_repo.id, force_full=True) + + assert result["status"] == "completed" + assert result["sync_type"] == "full" + + +@patch("memory.workers.tasks.github.GithubClient") +def test_sync_github_repo_full_sync_for_project_fields( + mock_client_class, github_repo_with_project_fields, db_session +): + """Test full sync triggered for project fields when never synced before.""" + mock_client = Mock() + mock_client.fetch_issues.return_value = iter([]) + mock_client.fetch_prs.return_value = iter([]) + mock_client.fetch_project_fields.return_value = None + mock_client.fetch_pr_project_fields.return_value = None + mock_client_class.return_value = mock_client + + result = github.sync_github_repo(github_repo_with_project_fields.id) + + assert result["status"] == "completed" + assert result["sync_type"] == "full" + + # Verify fetch_issues was called with state="open" for full sync + mock_client.fetch_issues.assert_called_once() + call_args = mock_client.fetch_issues.call_args + assert call_args[0][3] == "open" # state argument + + +@patch("memory.workers.tasks.github.GithubClient") +def test_sync_github_repo_updates_timestamps(mock_client_class, github_repo, db_session): + """Test that sync updates last_sync_at timestamp.""" + mock_client = Mock() + mock_client.fetch_issues.return_value = iter([]) + mock_client.fetch_prs.return_value = iter([]) + mock_client_class.return_value = mock_client + + assert github_repo.last_sync_at is None + + github.sync_github_repo(github_repo.id) + + db_session.refresh(github_repo) + assert github_repo.last_sync_at is not None + + +@patch("memory.workers.tasks.github.GithubClient") +def test_sync_github_repo_with_labels_filter( + mock_client_class, github_account, db_session +): + """Test sync passes labels filter to client.""" + mock_client = Mock() + mock_client.fetch_issues.return_value = iter([]) + mock_client.fetch_prs.return_value = iter([]) + mock_client_class.return_value = mock_client + + repo = GithubRepo( + account_id=github_account.id, + owner="testorg", + name="filtered", + labels_filter=["bug", "critical"], + track_issues=True, + track_prs=False, + active=True, + ) + db_session.add(repo) + db_session.commit() + + github.sync_github_repo(repo.id) + + # Verify labels filter was passed + mock_client.fetch_issues.assert_called_once() + call_args = mock_client.fetch_issues.call_args + assert call_args[0][4] == ["bug", "critical"] # labels argument + + +@patch("memory.workers.tasks.github.GithubClient") +def test_sync_github_repo_issues_only(mock_client_class, github_account, db_session): + """Test sync with only issues tracking enabled.""" + mock_client = Mock() + mock_client.fetch_issues.return_value = iter([]) + mock_client.fetch_prs.return_value = iter([]) + mock_client_class.return_value = mock_client + + repo = GithubRepo( + account_id=github_account.id, + owner="testorg", + name="issuesonly", + track_issues=True, + track_prs=False, + active=True, + ) + db_session.add(repo) + db_session.commit() + + github.sync_github_repo(repo.id) + + mock_client.fetch_issues.assert_called_once() + mock_client.fetch_prs.assert_not_called() + + +@patch("memory.workers.tasks.github.GithubClient") +def test_sync_github_repo_prs_only(mock_client_class, github_account, db_session): + """Test sync with only PRs tracking enabled.""" + mock_client = Mock() + mock_client.fetch_issues.return_value = iter([]) + mock_client.fetch_prs.return_value = iter([]) + mock_client_class.return_value = mock_client + + repo = GithubRepo( + account_id=github_account.id, + owner="testorg", + name="prsonly", + track_issues=False, + track_prs=True, + active=True, + ) + db_session.add(repo) + db_session.commit() + + github.sync_github_repo(repo.id) + + mock_client.fetch_issues.assert_not_called() + mock_client.fetch_prs.assert_called_once() + + +# ============================================================================= +# Tests for sync_all_github_repos +# ============================================================================= + + +@patch("memory.workers.tasks.github.sync_github_repo") +def test_sync_all_github_repos(mock_sync_repo, db_session): + """Test syncing all active repos.""" + # Create accounts and repos + account1 = GithubAccount( + name="Account 1", auth_type="pat", access_token="token1", active=True + ) + account2 = GithubAccount( + name="Account 2", auth_type="pat", access_token="token2", active=True + ) + db_session.add_all([account1, account2]) + db_session.flush() + + repo1 = GithubRepo( + account_id=account1.id, owner="org1", name="repo1", active=True + ) + repo2 = GithubRepo( + account_id=account2.id, owner="org2", name="repo2", active=True + ) + repo3 = GithubRepo( + account_id=account1.id, owner="org1", name="inactive", active=False + ) + db_session.add_all([repo1, repo2, repo3]) + db_session.commit() + + mock_sync_repo.delay.side_effect = [Mock(id="task-1"), Mock(id="task-2")] + + result = github.sync_all_github_repos() + + assert len(result) == 2 # Only active repos + assert result[0]["repo_path"] == "org1/repo1" + assert result[0]["task_id"] == "task-1" + assert result[1]["repo_path"] == "org2/repo2" + assert result[1]["task_id"] == "task-2" + + +@patch("memory.workers.tasks.github.sync_github_repo") +def test_sync_all_github_repos_inactive_account(mock_sync_repo, db_session): + """Test that repos with inactive accounts are not synced.""" + active_account = GithubAccount( + name="Active", auth_type="pat", access_token="token", active=True + ) + inactive_account = GithubAccount( + name="Inactive", auth_type="pat", access_token="token", active=False + ) + db_session.add_all([active_account, inactive_account]) + db_session.flush() + + repo1 = GithubRepo( + account_id=active_account.id, owner="org", name="repo1", active=True + ) + repo2 = GithubRepo( + account_id=inactive_account.id, owner="org", name="repo2", active=True + ) + db_session.add_all([repo1, repo2]) + db_session.commit() + + mock_sync_repo.delay.return_value = Mock(id="task-1") + + result = github.sync_all_github_repos() + + assert len(result) == 1 + assert result[0]["repo_path"] == "org/repo1" + + +def test_sync_all_github_repos_no_active_repos(db_session): + """Test sync_all when no active repos exist.""" + # Create only inactive repo + account = GithubAccount( + name="Account", auth_type="pat", access_token="token", active=True + ) + db_session.add(account) + db_session.flush() + + inactive_repo = GithubRepo( + account_id=account.id, owner="org", name="inactive", active=False + ) + db_session.add(inactive_repo) + db_session.commit() + + result = github.sync_all_github_repos() + + assert result == [] + + +# ============================================================================= +# Tests for project field extraction +# ============================================================================= + + +def test_project_status_extraction(github_repo, db_session, qdrant): + """Test project status is extracted from project_fields.""" + data = GithubIssueData( + kind="issue", + number=50, + title="Project Issue", + body="Body", + state="open", + author="user", + labels=[], + assignees=[], + milestone=None, + created_at=datetime.now(timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime.now(timezone.utc), + comment_count=0, + comments=[], + diff_summary=None, + project_fields={"Status": "Done", "Priority": "Low", "Custom Field": "Value"}, + content_hash="hash", + ) + serialized = _serialize_issue_data(data) + github.sync_github_item(github_repo.id, serialized) + + item = db_session.query(GithubItem).filter_by(number=50).first() + assert item.project_status == "Done" + assert item.project_priority == "Low" + assert item.project_fields == { + "Status": "Done", + "Priority": "Low", + "Custom Field": "Value", + } + + +def test_project_fields_case_insensitive(github_repo, db_session, qdrant): + """Test project field extraction is case insensitive.""" + data = GithubIssueData( + kind="issue", + number=51, + title="Case Test", + body="Body", + state="open", + author="user", + labels=[], + assignees=[], + milestone=None, + created_at=datetime.now(timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime.now(timezone.utc), + comment_count=0, + comments=[], + diff_summary=None, + project_fields={"PROJECT STATUS": "In Review", "item priority": "Medium"}, + content_hash="hash", + ) + serialized = _serialize_issue_data(data) + github.sync_github_item(github_repo.id, serialized) + + item = db_session.query(GithubItem).filter_by(number=51).first() + assert item.project_status == "In Review" + assert item.project_priority == "Medium" + + +# ============================================================================= +# Tests for tag merging +# ============================================================================= + + +@pytest.mark.parametrize( + "repo_tags,issue_labels,expected_tags", + [ + (["github"], ["bug"], ["github", "bug"]), + (["tag1", "tag2"], ["label1", "label2"], ["tag1", "tag2", "label1", "label2"]), + ([], ["bug"], ["bug"]), + (["github"], [], ["github"]), + ([], [], []), + ], +) +def test_tag_merging(repo_tags, issue_labels, expected_tags, github_account, db_session, qdrant): + """Test tags are merged from repo and issue labels.""" + repo = GithubRepo( + account_id=github_account.id, + owner="testorg", + name="tagrepo", + tags=repo_tags, + active=True, + ) + db_session.add(repo) + db_session.commit() + + data = GithubIssueData( + kind="issue", + number=60, + title="Tag Test", + body="Body", + state="open", + author="user", + labels=issue_labels, + assignees=[], + milestone=None, + created_at=datetime.now(timezone.utc), + closed_at=None, + merged_at=None, + github_updated_at=datetime.now(timezone.utc), + comment_count=0, + comments=[], + diff_summary=None, + project_fields=None, + content_hash="hash", + ) + serialized = _serialize_issue_data(data) + github.sync_github_item(repo.id, serialized) + + item = db_session.query(GithubItem).filter_by(number=60).first() + assert item.tags == expected_tags diff --git a/tools/run_celery_task.py b/tools/run_celery_task.py index 9ca9fa0..fcbf4f5 100644 --- a/tools/run_celery_task.py +++ b/tools/run_celery_task.py @@ -15,6 +15,8 @@ Usage: python run_celery_task.py blogs sync-webpage --url "https://example.com" python run_celery_task.py comic sync-all-comics python run_celery_task.py forums sync-lesswrong --since-date "2025-01-01" --min-karma 10 --limit 50 --cooldown 0.5 --max-items 1000 + python run_celery_task.py github sync-all-repos + python run_celery_task.py github sync-repo --repo-id 1 --force-full """ import json @@ -51,8 +53,10 @@ from memory.common.celery_app import ( UPDATE_METADATA_FOR_SOURCE_ITEMS, SETUP_GIT_NOTES, TRACK_GIT_CHANGES, - BACKUP_TO_S3_DIRECTORY, + BACKUP_PATH, BACKUP_ALL, + SYNC_GITHUB_REPO, + SYNC_ALL_GITHUB_REPOS, app, ) @@ -100,9 +104,13 @@ TASK_MAPPINGS = { "track_git_changes": TRACK_GIT_CHANGES, }, "backup": { - "backup_to_s3_directory": BACKUP_TO_S3_DIRECTORY, + "backup_path": BACKUP_PATH, "backup_all": BACKUP_ALL, }, + "github": { + "sync_all_repos": SYNC_ALL_GITHUB_REPOS, + "sync_repo": SYNC_GITHUB_REPO, + }, } QUEUE_MAPPINGS = { "email": "email", @@ -200,9 +208,9 @@ def backup_all(ctx): @backup.command("path") @click.option("--path", required=True, help="Path to backup") @click.pass_context -def backup_to_s3_directory(ctx, path): +def backup_path_cmd(ctx, path): """Backup a specific path.""" - execute_task(ctx, "backup", "backup_to_s3_directory", path=path) + execute_task(ctx, "backup", "backup_path", path=path) @cli.group() @@ -533,5 +541,28 @@ def forums_sync_lesswrong_post(ctx, url): execute_task(ctx, "forums", "sync_lesswrong_post", url=url) +@cli.group() +@click.pass_context +def github(ctx): + """GitHub-related tasks.""" + pass + + +@github.command("sync-all-repos") +@click.pass_context +def github_sync_all_repos(ctx): + """Sync all active GitHub repos.""" + execute_task(ctx, "github", "sync_all_repos") + + +@github.command("sync-repo") +@click.option("--repo-id", type=int, required=True, help="GitHub repo ID") +@click.option("--force-full", is_flag=True, help="Force a full sync instead of incremental") +@click.pass_context +def github_sync_repo(ctx, repo_id, force_full): + """Sync a specific GitHub repo.""" + execute_task(ctx, "github", "sync_repo", repo_id=repo_id, force_full=force_full) + + if __name__ == "__main__": cli()