diff --git a/functions-python/update_validation_report/src/__init__.py b/functions-python/helpers/task_execution/__init__.py similarity index 100% rename from functions-python/update_validation_report/src/__init__.py rename to functions-python/helpers/task_execution/__init__.py diff --git a/functions-python/helpers/task_execution/task_execution_tracker.py b/functions-python/helpers/task_execution/task_execution_tracker.py new file mode 100644 index 000000000..45912a6b6 --- /dev/null +++ b/functions-python/helpers/task_execution/task_execution_tracker.py @@ -0,0 +1,458 @@ +# +# MobilityData 2026 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Generic task execution tracker backed by the task_run and task_execution_log DB tables. + +Mirrors the DatasetTraceService / BatchExecutionService pattern (currently Datastore-based) +so that batch_process_dataset, batch_datasets, and gbfs_validator can migrate to this +class in the future. + +Usage: + tracker = TaskExecutionTracker( + task_name="gtfs_validation", + run_id="7.0.0", + db_session=session, + ) + tracker.start_run(total_count=5000, params={"validator_endpoint": "...", "env": "staging"}) + + if not tracker.is_triggered(dataset_id): + execute_workflow(...) 
class TaskInProgressError(Exception):
    """
    Signals that a tracked task run has not finished yet.

    Task handlers raise this so that tasks_executor answers HTTP 503,
    which makes Cloud Tasks re-deliver the task on the queue's
    retry_config schedule (typically every 10 minutes).
    """
+ + Safe to call multiple times for the same (task_name, run_id) — subsequent calls + update total_count and params but preserve created_at and the existing status + unless it is still in_progress. + """ + stmt = ( + insert(TaskRun) + .values( + task_name=self.task_name, + run_id=self.run_id, + status=STATUS_IN_PROGRESS, + total_count=total_count, + params=params, + ) + .on_conflict_do_update( + constraint="task_run_task_name_run_id_key", + set_={ + "total_count": total_count, + "params": params, + "status": STATUS_IN_PROGRESS, + "completed_at": None, + }, + ) + .returning(TaskRun.id) + ) + result = self.db_session.execute(stmt) + self.db_session.flush() + self.task_run_id = result.scalar_one() + logging.info( + "TaskExecutionTracker: run %s/%s started (id=%s, total=%s)", + self.task_name, + self.run_id, + self.task_run_id, + total_count, + ) + return self.task_run_id + + def finish_run(self, status: str = STATUS_COMPLETED) -> None: + """Mark the task_run as completed or failed.""" + self.db_session.query(TaskRun).filter( + TaskRun.task_name == self.task_name, + TaskRun.run_id == self.run_id, + ).update( + {"status": status, "completed_at": datetime.now(timezone.utc)}, + synchronize_session=False, + ) + logging.info( + "TaskExecutionTracker: run %s/%s finished with status=%s", + self.task_name, + self.run_id, + status, + ) + + # ------------------------------------------------------------------ + # Entity-level operations + # ------------------------------------------------------------------ + + def count_already_tracked(self, entity_ids: list[str]) -> int: + """ + Return how many of the given entity_ids are already tracked for this run + (status triggered or completed). Useful in dry-run to preview skips. 
+ """ + if not entity_ids: + return 0 + return ( + self.db_session.query(TaskExecutionLog) + .filter( + TaskExecutionLog.task_name == self.task_name, + TaskExecutionLog.run_id == self.run_id, + TaskExecutionLog.status.in_([STATUS_TRIGGERED, STATUS_COMPLETED]), + TaskExecutionLog.entity_id.in_(entity_ids), + ) + .count() + ) + + def is_triggered(self, entity_id: Optional[str]) -> bool: + """ + Return True if an execution log entry already exists for this entity + with status triggered or completed (i.e. should not be re-triggered). + """ + query = self.db_session.query(TaskExecutionLog).filter( + TaskExecutionLog.task_name == self.task_name, + TaskExecutionLog.run_id == self.run_id, + TaskExecutionLog.status.in_([STATUS_TRIGGERED, STATUS_COMPLETED]), + ) + if entity_id is None: + query = query.filter(TaskExecutionLog.entity_id.is_(None)) + else: + query = query.filter(TaskExecutionLog.entity_id == entity_id) + return query.first() is not None + + def mark_triggered( + self, + entity_id: Optional[str], + execution_ref: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, + ) -> None: + """ + Insert a task_execution_log row with status=triggered. + Idempotent: if a row already exists for this (task_name, entity_id, run_id), + it updates execution_ref and metadata. 
+ """ + task_run_id = self._resolve_task_run_id() + stmt = ( + insert(TaskExecutionLog) + .values( + task_run_id=task_run_id, + task_name=self.task_name, + entity_id=entity_id, + run_id=self.run_id, + status=STATUS_TRIGGERED, + execution_ref=execution_ref, + metadata_=metadata, + ) + .on_conflict_do_update( + constraint="task_execution_log_task_name_entity_id_run_id_key", + set_={ + "execution_ref": execution_ref, + "metadata": metadata, + "status": STATUS_TRIGGERED, + }, + ) + ) + self.db_session.execute(stmt) + self.db_session.flush() + logging.debug( + "TaskExecutionTracker: marked triggered entity=%s run=%s/%s ref=%s", + entity_id, + self.task_name, + self.run_id, + execution_ref, + ) + + def mark_completed(self, entity_id: Optional[str]) -> None: + """Mark an entity execution as completed.""" + self._update_entity_status(entity_id, STATUS_COMPLETED) + + def mark_failed( + self, entity_id: Optional[str], error_message: Optional[str] = None + ) -> None: + """Mark an entity execution as failed, optionally storing an error message.""" + query = self.db_session.query(TaskExecutionLog).filter( + TaskExecutionLog.task_name == self.task_name, + TaskExecutionLog.run_id == self.run_id, + ) + if entity_id is None: + query = query.filter(TaskExecutionLog.entity_id.is_(None)) + else: + query = query.filter(TaskExecutionLog.entity_id == entity_id) + query.update( + { + "status": STATUS_FAILED, + "error_message": error_message, + "completed_at": datetime.now(timezone.utc), + }, + synchronize_session=False, + ) + self.db_session.flush() + + # ------------------------------------------------------------------ + # Reporting + # ------------------------------------------------------------------ + + def get_summary(self) -> dict: + """ + Return a summary of the run from both task_run and task_execution_log. 
+ + Returns: + { + "task_name": str, + "run_id": str, + "run_status": str, + "total_count": int | None, + "created_at": datetime | None, + "params": dict | None, + "triggered": int, + "completed": int, + "failed": int, + "pending": int, # > 0 means dispatch loop didn't complete; call rebuild again + } + """ + task_run = ( + self.db_session.query(TaskRun) + .filter( + TaskRun.task_name == self.task_name, + TaskRun.run_id == self.run_id, + ) + .first() + ) + if not task_run: + return { + "task_name": self.task_name, + "run_id": self.run_id, + "run_status": None, + "total_count": None, + "created_at": None, + "params": None, + "triggered": 0, + "completed": 0, + "failed": 0, + "pending": 0, + } + + counts: dict[str, int] = { + STATUS_TRIGGERED: 0, + STATUS_COMPLETED: 0, + STATUS_FAILED: 0, + } + rows = ( + self.db_session.query( + TaskExecutionLog.status, + TaskExecutionLog.id, + ) + .filter( + TaskExecutionLog.task_name == self.task_name, + TaskExecutionLog.run_id == self.run_id, + ) + .all() + ) + for row in rows: + if row.status in counts: + counts[row.status] += 1 + + total = task_run.total_count or 0 + processed = ( + counts[STATUS_TRIGGERED] + counts[STATUS_COMPLETED] + counts[STATUS_FAILED] + ) + pending = max(0, total - processed) + + return { + "task_name": self.task_name, + "run_id": self.run_id, + "run_status": task_run.status, + "total_count": task_run.total_count, + "created_at": task_run.created_at, + "params": task_run.params, + "triggered": counts[STATUS_TRIGGERED], + "completed": counts[STATUS_COMPLETED], + "failed": counts[STATUS_FAILED], + "pending": pending, + } + + def schedule_status_sync(self, delay_seconds: int = 0) -> None: + """ + Enqueue a single Cloud Task that will call sync_task_run_status for this run. + + The task name is derived solely from task_name + run_id so the call is fully + idempotent — if a task with this name already exists in the queue Cloud Tasks + returns ALREADY_EXISTS and this method silently skips enqueueing. 
+ + Retries are driven entirely by the queue's retry_config (constant 10-min + intervals). The task handler returns 503 while the run is in progress and + 200 only when complete, so Cloud Tasks knows when to stop retrying. + + Requires env vars: PROJECT_ID, GCP_REGION, ENVIRONMENT, TASK_RUN_SYNC_QUEUE, + SERVICE_ACCOUNT_EMAIL. No-op with a warning when any are missing. + """ + project = os.getenv("PROJECT_ID") + queue = os.getenv("TASK_RUN_SYNC_QUEUE") + gcp_region = os.getenv("GCP_REGION") + environment = os.getenv("ENVIRONMENT") + + if not all([project, queue, gcp_region, environment]): + logging.warning( + "schedule_status_sync: missing env vars (PROJECT_ID/GCP_REGION/" + "ENVIRONMENT/TASK_RUN_SYNC_QUEUE) — skipping Cloud Task enqueue" + ) + return + + try: + from google.cloud import tasks_v2 + from google.protobuf import timestamp_pb2 + from shared.common.gcp_utils import create_http_task_with_name + + safe_name = re.sub( + r"[^a-zA-Z0-9_-]", "-", f"{self.task_name}-{self.run_id}" + ) + task_name = f"sync-{safe_name}"[:500] + + url = ( + f"https://{gcp_region}-{project}.cloudfunctions.net/" + f"tasks_executor-{environment}" + ) + body = json.dumps( + { + "task": "sync_task_run_status", + "payload": { + "task_name": self.task_name, + "run_id": self.run_id, + }, + } + ).encode() + + schedule_time = None + if delay_seconds > 0: + run_at = datetime.now(timezone.utc) + timedelta(seconds=delay_seconds) + schedule_time = timestamp_pb2.Timestamp() + schedule_time.FromDatetime(run_at.replace(tzinfo=None)) + + create_http_task_with_name( + client=tasks_v2.CloudTasksClient(), + body=body, + url=url, + project_id=project, + gcp_region=gcp_region, + queue_name=queue, + task_name=task_name, + task_time=schedule_time, + http_method=tasks_v2.HttpMethod.POST, + ) + logging.info( + "TaskExecutionTracker: enqueued sync task '%s' for %s/%s", + task_name, + self.task_name, + self.run_id, + ) + except Exception as e: + if "already exists" in str(e).lower() or "ALREADY_EXISTS" in 
str(e): + logging.info( + "TaskExecutionTracker: sync task already queued for %s/%s — skipping", + self.task_name, + self.run_id, + ) + else: + logging.warning( + "TaskExecutionTracker: could not enqueue sync task: %s", e + ) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _resolve_task_run_id(self) -> Optional[uuid.UUID]: + """Return cached task_run_id or fetch it from DB.""" + if self.task_run_id: + return self.task_run_id + task_run = ( + self.db_session.query(TaskRun) + .filter( + TaskRun.task_name == self.task_name, + TaskRun.run_id == self.run_id, + ) + .first() + ) + if task_run: + self.task_run_id = task_run.id + return self.task_run_id + + def _update_entity_status(self, entity_id: Optional[str], status: str) -> None: + query = self.db_session.query(TaskExecutionLog).filter( + TaskExecutionLog.task_name == self.task_name, + TaskExecutionLog.run_id == self.run_id, + ) + if entity_id is None: + query = query.filter(TaskExecutionLog.entity_id.is_(None)) + else: + query = query.filter(TaskExecutionLog.entity_id == entity_id) + query.update( + {"status": status, "completed_at": datetime.now(timezone.utc)}, + synchronize_session=False, + ) + self.db_session.flush() diff --git a/functions-python/helpers/tests/test_task_execution_tracker.py b/functions-python/helpers/tests/test_task_execution_tracker.py new file mode 100644 index 000000000..f72d22770 --- /dev/null +++ b/functions-python/helpers/tests/test_task_execution_tracker.py @@ -0,0 +1,203 @@ +# +# MobilityData 2026 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def _make_tracker(task_name="test_task", run_id="v1.0"):
    """Build a TaskExecutionTracker wired to a MagicMock session; return both."""
    mock_session = MagicMock()
    tracker = TaskExecutionTracker(
        task_name=task_name,
        run_id=run_id,
        db_session=mock_session,
    )
    return tracker, mock_session
class TestTaskExecutionTrackerIsTriggered(unittest.TestCase):
    """Covers is_triggered() for existing rows, missing rows, and None entity ids."""

    @staticmethod
    def _stub_first(session, value):
        # is_triggered chains two .filter() calls before calling .first()
        chain = session.query.return_value.filter.return_value.filter.return_value
        chain.first.return_value = value

    def test_returns_true_when_triggered_row_exists(self):
        tracker, session = _make_tracker()
        self._stub_first(session, MagicMock())

        self.assertTrue(tracker.is_triggered("ds-123"))

    def test_returns_false_when_no_row(self):
        tracker, session = _make_tracker()
        self._stub_first(session, None)

        self.assertFalse(tracker.is_triggered("ds-999"))

    def test_handles_none_entity_id(self):
        tracker, session = _make_tracker()
        self._stub_first(session, None)

        self.assertFalse(tracker.is_triggered(None))
class TestTaskExecutionTrackerMarkFailed(unittest.TestCase):
    """Covers mark_failed(): status, error message, and completion timestamp."""

    def test_mark_failed_sets_error_message(self):
        tracker, session = _make_tracker()
        chained_query = MagicMock()
        # mark_failed applies two .filter() calls before .update()
        session.query.return_value.filter.return_value.filter.return_value = (
            chained_query
        )

        tracker.mark_failed("ds-1", error_message="Workflow timed out")

        chained_query.update.assert_called_once()
        payload = chained_query.update.call_args[0][0]
        self.assertEqual(payload["status"], STATUS_FAILED)
        self.assertEqual(payload["error_message"], "Workflow timed out")
MagicMock(status=STATUS_FAILED), + ] + + def query_side_effect(*args): + m = MagicMock() + m.filter.return_value.first.return_value = task_run + m.filter.return_value.all.return_value = rows + return m + + session.query.side_effect = query_side_effect + + summary = tracker.get_summary() + self.assertEqual(summary["triggered"], 2) + self.assertEqual(summary["completed"], 1) + self.assertEqual(summary["failed"], 1) + self.assertEqual(summary["pending"], 1) # 5 total - 4 processed diff --git a/functions-python/helpers/validation_report/validation_report_update.py b/functions-python/helpers/validation_report/validation_report_update.py index 684c2a63c..ff5beb5c1 100644 --- a/functions-python/helpers/validation_report/validation_report_update.py +++ b/functions-python/helpers/validation_report/validation_report_update.py @@ -47,13 +47,18 @@ def execute_workflows( validator_endpoint=None, bypass_db_update=False, reports_bucket_name=None, + tracker=None, ): """ - Execute the workflow for the latest datasets that need their validation report to be updated + Execute the workflow for the latest datasets that need their validation report to be updated. + :param latest_datasets: List of tuples containing the feed stable id and dataset stable id :param validator_endpoint: The URL of the validator :param bypass_db_update: Whether to bypass the database update :param reports_bucket_name: The name of the bucket where the reports are stored + :param tracker: Optional TaskExecutionTracker for idempotent execution tracking. + When provided, datasets already in triggered/completed state are skipped + and newly triggered datasets are recorded. 
:return: List of dataset stable ids for which the workflow was executed """ project_id = f"mobility-feeds-{env}" @@ -64,6 +69,9 @@ def execute_workflows( count = 0 logging.info(f"Executing workflow for {len(latest_datasets)} datasets") for feed_id, dataset_id in latest_datasets: + if tracker and tracker.is_triggered(dataset_id): + logging.info(f"Skipping already triggered dataset {feed_id}/{dataset_id}") + continue try: input_data = { "data": { @@ -83,12 +91,20 @@ def execute_workflows( if reports_bucket_name: input_data["data"]["reports_bucket_name"] = reports_bucket_name logging.info(f"Executing workflow for {feed_id}/{dataset_id}") - execute_workflow(project_id, input_data=input_data) + execution = execute_workflow(project_id, input_data=input_data) execution_triggered_datasets.append(dataset_id) + if tracker: + tracker.mark_triggered( + entity_id=dataset_id, + execution_ref=execution.name, + metadata={"feed_id": feed_id}, + ) except Exception as e: logging.error( f"Error while executing workflow for {feed_id}/{dataset_id}: {e}" ) + if tracker: + tracker.mark_failed(entity_id=dataset_id, error_message=str(e)) count += 1 logging.info(f"Triggered workflow execution for {count} datasets") if count % batch_size == 0: diff --git a/functions-python/process_validation_report/src/main.py b/functions-python/process_validation_report/src/main.py index 3d45afc9c..b7e2bc9f6 100644 --- a/functions-python/process_validation_report/src/main.py +++ b/functions-python/process_validation_report/src/main.py @@ -35,6 +35,7 @@ from shared.helpers.logger import init_logger from shared.helpers.transform import get_nested_value from shared.helpers.feed_status import update_feed_statuses_query +from shared.helpers.task_execution.task_execution_tracker import TaskExecutionTracker from shared.common.gcp_utils import create_web_revalidation_task init_logger() @@ -288,6 +289,20 @@ def create_validation_report_entities( update_feed_statuses_query(db_session, [feed_stable_id]) + # Update 
execution tracker regardless of bypass_db_update, so monitoring + # works for both pre-release and post-release validation runs. + try: + tracker = TaskExecutionTracker( + task_name="gtfs_validation", + run_id=version, + db_session=db_session, + ) + tracker.mark_completed(dataset_stable_id) + db_session.commit() + except Exception as tracker_error: + logging.warning( + "Could not update task execution tracker: %s", tracker_error + ) # Trigger web app cache revalidation for the feed try: create_web_revalidation_task([feed_stable_id]) diff --git a/functions-python/tasks_executor/README.md b/functions-python/tasks_executor/README.md index 150b84de6..49cfbd102 100644 --- a/functions-python/tasks_executor/README.md +++ b/functions-python/tasks_executor/README.md @@ -20,12 +20,26 @@ Examples: "task": "rebuild_missing_validation_reports", "payload": { "dry_run": true, - "filter_after_in_days": 14, + "bypass_db_update": true, + "filter_after_in_days": null, + "force_update": false, + "validator_endpoint": "https://stg-gtfs-validator-web-mbzoxaljzq-ue.a.run.app", + "limit": 1, "filter_statuses": ["active", "inactive", "future"] } } ``` +```json +{ + "task": "get_validation_run_status", + "payload": { + "task_name": "gtfs_validation", + "run_id": "7.1.1-SNAPSHOT" + } +} +``` + ```json { "task": "rebuild_missing_bounding_boxes", diff --git a/functions-python/tasks_executor/src/main.py b/functions-python/tasks_executor/src/main.py index 7182a3944..55c6e9ba4 100644 --- a/functions-python/tasks_executor/src/main.py +++ b/functions-python/tasks_executor/src/main.py @@ -21,6 +21,7 @@ import functions_framework from shared.helpers.logger import init_logger +from shared.helpers.task_execution.task_execution_tracker import TaskInProgressError from tasks.data_import.transportdatagouv.import_tdg_feeds import import_tdg_handler from tasks.data_import.transportdatagouv.update_tdg_redirects import ( update_tdg_redirects_handler, @@ -38,6 +39,12 @@ from 
tasks.validation_reports.rebuild_missing_validation_reports import ( rebuild_missing_validation_reports_handler, ) +from tasks.sync_task_run_status import ( + sync_task_run_status_handler, +) +from tasks.get_task_run_status import ( + get_task_run_status_handler, +) from tasks.visualization_files.rebuild_missing_visualization_files import ( rebuild_missing_visualization_files_handler, ) @@ -71,6 +78,25 @@ "description": "Rebuilds missing validation reports for GTFS datasets.", "handler": rebuild_missing_validation_reports_handler, }, + "get_task_run_status": { + "description": ( + "Read-only snapshot of a task_run tracked by TaskExecutionTracker. " + "Returns current DB state (triggered/completed/failed/pending counts) " + "without triggering any GCP Workflows polling or status transitions. " + "Required: task_name, run_id." + ), + "handler": get_task_run_status_handler, + }, + "sync_task_run_status": { + "description": ( + "Generic self-scheduling monitor for any task_run. " + "Polls GCP Workflows for triggered entries, updates statuses, " + "marks the task_run completed when all done, and re-schedules " + "itself every 10 minutes until complete. " + "Required: task_name, run_id." 
+ ), + "handler": sync_task_run_status_handler, + }, "rebuild_missing_bounding_boxes": { "description": "Rebuilds missing bounding boxes for GTFS datasets that contain valid stops.txt files.", "handler": rebuild_missing_bounding_boxes_handler, @@ -195,5 +221,10 @@ def tasks_executor(request: flask.Request) -> flask.Response: # Default JSON response return flask.make_response(flask.jsonify(result), 200) + except TaskInProgressError as error: + # Signal Cloud Tasks to retry — the run is not yet complete + return flask.make_response( + flask.jsonify({"status": "in_progress", "detail": str(error)}), 503 + ) except Exception as error: return flask.make_response(flask.jsonify({"error": str(error)}), 500) diff --git a/functions-python/tasks_executor/src/tasks/get_task_run_status.py b/functions-python/tasks_executor/src/tasks/get_task_run_status.py new file mode 100644 index 000000000..3d8a68eea --- /dev/null +++ b/functions-python/tasks_executor/src/tasks/get_task_run_status.py @@ -0,0 +1,92 @@ +# +# MobilityData 2026 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Task: get_task_run_status + +Read-only snapshot of a task_run tracked by TaskExecutionTracker. + +Returns the current DB state for the given (task_name, run_id) pair without +triggering any GCP Workflows polling or status transitions. Use this task to +inspect a run at any point — before, during, or after it completes. 
def get_task_run_status_handler(payload: dict) -> dict:
    """
    Entry point for the get_task_run_status task.

    Payload structure:
        {
            "task_name": str,  # required
            "run_id": str,     # required
        }

    Returns a status summary dict. Never raises TaskInProgressError — this
    task is read-only and always maps to HTTP 200.
    """
    task_name, run_id = payload.get("task_name"), payload.get("run_id")
    # Reject early so a malformed Cloud Task payload fails loudly.
    if not (task_name and run_id):
        raise ValueError("task_name and run_id are required")
    return get_task_run_status(task_name=task_name, run_id=run_id)
+ + Response fields: + task_name — the task name + run_id — the run identifier + run_status — task_run.status (in_progress / completed / failed / None if not found) + total_count — number of entities registered at dispatch time + triggered — count with status=triggered (workflows still running) + completed — count with status=completed + failed — count with status=failed + pending — total_count minus all logged entries (dispatch not yet complete) + dispatch_complete — True when pending == 0 (all entities have been dispatched) + created_at — when the task_run was first created + params — params dict stored at start_run() time + """ + tracker = TaskExecutionTracker( + task_name=task_name, + run_id=run_id, + db_session=db_session, + ) + summary = tracker.get_summary() + summary["dispatch_complete"] = summary["pending"] == 0 + return summary diff --git a/functions-python/tasks_executor/src/tasks/sync_task_run_status.py b/functions-python/tasks_executor/src/tasks/sync_task_run_status.py new file mode 100644 index 000000000..6c5d2a95b --- /dev/null +++ b/functions-python/tasks_executor/src/tasks/sync_task_run_status.py @@ -0,0 +1,216 @@ +# +# MobilityData 2025 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Task: sync_task_run_status + +Generic Cloud-Tasks-driven monitor for any task_run tracked by TaskExecutionTracker. + +One Cloud Task is enqueued by TaskExecutionTracker.schedule_status_sync() when a run +starts. 
The queue's retry_config drives the polling cadence (constant 10-minute +intervals). This task handler: + + - Polls the GCP Workflows Executions API for entries still in 'triggered' state + - Updates task_execution_log statuses (triggered → completed / failed) + - Returns HTTP 200 (via normal return) when the run is fully settled + - Raises TaskInProgressError → HTTP 503 when still in progress, signalling Cloud + Tasks to retry after the configured backoff + +No self-scheduling logic — the Cloud Tasks queue manages the retry loop. + +Payload: + { + "task_name": str, # required — e.g. "gtfs_validation" + "run_id": str, # required — e.g. "7.1.1-SNAPSHOT" + } +""" + +import logging + +from google.cloud.workflows import executions_v1 +from sqlalchemy.orm import Session + +from shared.database.database import with_db_session +from shared.database_gen.sqlacodegen_models import TaskExecutionLog +from shared.helpers.task_execution.task_execution_tracker import ( + TaskExecutionTracker, + TaskInProgressError, + STATUS_TRIGGERED, + STATUS_FAILED, + STATUS_COMPLETED, +) + + +def sync_task_run_status_handler(payload: dict) -> dict: + """ + Entry point for the sync_task_run_status task. + + Payload structure: + { + "task_name": str, # required + "run_id": str, # required + } + + Returns the run summary on completion (HTTP 200). + Raises TaskInProgressError (→ HTTP 503) when still in progress so Cloud Tasks + retries according to the queue's retry_config. + """ + task_name = payload.get("task_name") + run_id = payload.get("run_id") + if not task_name or not run_id: + raise ValueError("task_name and run_id are required") + + return sync_task_run_status(task_name=task_name, run_id=run_id) + + +@with_db_session +def sync_task_run_status( + task_name: str, + run_id: str, + db_session: Session | None = None, +) -> dict: + """ + Sync execution statuses and mark the task_run completed when all done. 
+ + Raises TaskInProgressError if the run is not yet complete so the Cloud Tasks + queue retries this task after the configured backoff (default: 10 minutes). + """ + tracker = TaskExecutionTracker( + task_name=task_name, + run_id=run_id, + db_session=db_session, + ) + + _sync_workflow_statuses(task_name, run_id, db_session, tracker) + db_session.commit() + + summary = tracker.get_summary() + summary["dispatch_complete"] = summary["pending"] == 0 + + run_params = summary.get("params") or {} + summary["total_candidates"] = run_params.get("total_candidates") + + failed_entries = ( + db_session.query(TaskExecutionLog) + .filter( + TaskExecutionLog.task_name == task_name, + TaskExecutionLog.run_id == run_id, + TaskExecutionLog.status == STATUS_FAILED, + ) + .all() + ) + summary["failed_entity_ids"] = [e.entity_id for e in failed_entries] + + all_settled = ( + summary["dispatch_complete"] + and summary["triggered"] == 0 + and summary["failed"] == 0 + ) + summary["ready_for_bigquery"] = all_settled + + if all_settled: + tracker.finish_run(STATUS_COMPLETED) + db_session.commit() + summary["run_status"] = STATUS_COMPLETED + logging.info( + "sync_task_run_status: run %s/%s complete — task_run marked completed", + task_name, + run_id, + ) + return summary + + # Not done yet — raise so Cloud Tasks retries after the queue backoff + logging.info( + "sync_task_run_status: run %s/%s still in progress " + "(pending=%s, triggered=%s, failed=%s) — returning 503 for retry", + task_name, + run_id, + summary["pending"], + summary["triggered"], + summary["failed"], + ) + raise TaskInProgressError( + f"Run {task_name}/{run_id} still in progress: " + f"pending={summary['pending']}, triggered={summary['triggered']}, " + f"failed={summary['failed']}" + ) + + +def _sync_workflow_statuses( + task_name: str, + run_id: str, + db_session: Session, + tracker: TaskExecutionTracker, +) -> None: + """ + Poll GCP Workflows Executions API for all 'triggered' entries with an + execution_ref and update 
task_execution_log accordingly. + """ + triggered_entries = ( + db_session.query(TaskExecutionLog) + .filter( + TaskExecutionLog.task_name == task_name, + TaskExecutionLog.run_id == run_id, + TaskExecutionLog.status == STATUS_TRIGGERED, + TaskExecutionLog.execution_ref.isnot(None), + ) + .all() + ) + + if not triggered_entries: + logging.info( + "sync_task_run_status: no triggered entries with execution_ref for %s/%s", + task_name, + run_id, + ) + return + + logging.info( + "sync_task_run_status: syncing %s triggered executions via GCP Workflows API", + len(triggered_entries), + ) + client = executions_v1.ExecutionsClient() + + for entry in triggered_entries: + try: + execution = client.get_execution(request={"name": entry.execution_ref}) + state = execution.state + + if state == executions_v1.Execution.State.SUCCEEDED: + tracker.mark_completed(entry.entity_id) + logging.info( + "Execution %s SUCCEEDED for entity %s", + entry.execution_ref, + entry.entity_id, + ) + elif state in ( + executions_v1.Execution.State.FAILED, + executions_v1.Execution.State.CANCELLED, + ): + error_msg = getattr(execution.error, "payload", str(state)) + tracker.mark_failed(entry.entity_id, error_message=error_msg) + logging.warning( + "Execution %s %s for entity %s: %s", + entry.execution_ref, + state.name, + entry.entity_id, + error_msg, + ) + # ACTIVE / QUEUED → still running, leave as triggered + except Exception as e: + logging.error( + "Error fetching execution status for %s: %s", entry.execution_ref, e + ) diff --git a/functions-python/tasks_executor/src/tasks/validation_reports/README.md b/functions-python/tasks_executor/src/tasks/validation_reports/README.md index 9312ba8ba..83dcfbee5 100644 --- a/functions-python/tasks_executor/src/tasks/validation_reports/README.md +++ b/functions-python/tasks_executor/src/tasks/validation_reports/README.md @@ -1,30 +1,283 @@ -# Rebuild Missing Validation Reports +# GTFS Validation Report Tasks -This task generates the missing reports in the 
GTFS datasets. -The reports are generated using the _gtfs_validator_ GCP workflow. +This module contains two tasks for managing GTFS validation reports at scale: -## Task ID -Use task Id: `rebuild_missing_validation_reports` +| Task ID | Purpose | +|---|---| +| `rebuild_missing_validation_reports` | Triggers GCP Workflows to (re)validate datasets | +| `sync_task_run_status` | Generic self-scheduling monitor for any task_run | -## Usage -The function receive the following payload: +--- + +## Architecture + +```mermaid +sequenceDiagram + participant Caller as Caller + participant Rebuild as rebuild_missing_validation_reports + participant DB as PostgreSQL + participant GCS as GCS (dataset zips) + participant CQ as Cloud Tasks Queue + participant Sync as sync_task_run_status + participant WF as GCP Workflow + GTFS Validator + participant PVR as process_validation_report + + Caller->>Rebuild: POST payload + + Rebuild->>DB: Query datasets needing validation + Rebuild->>GCS: Filter to datasets with existing zip blob + + alt dry_run = true + Rebuild-->>Caller: counts only, no side effects + else dry_run = false + Rebuild->>DB: Upsert task_run, record triggered datasets + Rebuild->>CQ: Enqueue sync task (fires in 10 min, idempotent) + Rebuild->>WF: Trigger one Workflow per dataset + Rebuild-->>Caller: triggered / skipped counts + end + + loop every 10 min until complete + CQ->>Sync: Poll run status + Sync->>WF: Check execution states + Sync->>DB: Update task_execution_log + alt all settled + Sync->>DB: Mark task_run completed + Sync-->>CQ: HTTP 200 — done + else still running + Sync-->>CQ: HTTP 503 — retry + end + end + + WF->>PVR: Report URL + dataset metadata + alt bypass_db_update = false + PVR->>DB: Write validation report + mark completed + else bypass_db_update = true (pre-release) + PVR->>DB: Mark completed only (report not surfaced) + end +``` + +--- + +## `rebuild_missing_validation_reports` + +Finds GTFS datasets that are missing a validation report **or** have a 
report from an +older validator version, then triggers a GCP Workflow for each one. + +The task is **resumable**: if it times out mid-loop, calling it again skips datasets +that were already triggered (tracked in `task_execution_log`). + +### Payload + +```json +{ + "dry_run": true, + "validator_endpoint": "https://stg-gtfs-validator-web-mbzoxaljzq-ue.a.run.app", + "bypass_db_update": false, + "filter_after_in_days": 30, + "filter_statuses": ["active"], + "filter_op_statuses": ["published"], + "force_update": false, + "limit": 10 +} +``` + +| Field | Type | Default | Description | +|---|---|---|---| +| `dry_run` | bool | `true` | Count candidates only — no workflows triggered | +| `validator_endpoint` | string | env-derived | Validator service URL to use and fetch version from | +| `bypass_db_update` | bool | `false` | When `true`, results are NOT written to DB/API (use for pre-release runs) | +| `filter_after_in_days` | int | `null` | Restrict to datasets downloaded within the last N days. Omit to include all datasets | +| `filter_statuses` | list[str] | `null` | Filter feeds by status (e.g. `["active", "inactive"]`). Omit for all statuses | +| `filter_op_statuses` | list[str] | `["published"]` | Filter feeds by operational status. Accepted values: `"published"`, `"unpublished"`, `"wip"` | +| `force_update` | bool | `false` | Re-trigger even when a current report already exists | +| `limit` | int | `null` | Cap the number of workflows triggered per call — useful for end-to-end testing | + +--- + +## `sync_task_run_status` + +Generic self-scheduling monitor for any `task_run` tracked by `TaskExecutionTracker`. +Automatically scheduled by `rebuild_missing_validation_reports` on every non-dry run. + +### Behaviour + +1. Polls GCP Workflows Executions API for all `triggered` entries with an `execution_ref` +2. Updates statuses (`triggered → completed / failed`) +3. If all done → marks `task_run.status = 'completed'` +4. 
If still in progress → raises `TaskInProgressError` (HTTP 503) so the Cloud Tasks
   queue retries it after the configured backoff (default: every 10 minutes) — the
   task does not re-schedule itself

### Payload

```json
{
  "task_name": "gtfs_validation",
  "run_id": "7.0.0"
}
```

| Field | Type | Default | Description |
|---|---|---|---|
| `task_name` | string | **required** | Task name identifier (e.g. `"gtfs_validation"`) |
| `run_id` | string | **required** | Run identifier — the validator version string |

### Response

Same shape as `get_task_run_status` (these tasks replace the older
`get_validation_run_status`):
+ +### Prerequisites + +- The staging validator is deployed at `https://stg-gtfs-validator-web-mbzoxaljzq-ue.a.run.app` +- You have the `validator_version` string (fetch from `/version`) + +### Step 1 — Dry run (estimate scope) + +```json +{ + "task": "rebuild_missing_validation_reports", + "payload": { + "dry_run": true, + "validator_endpoint": "https://stg-gtfs-validator-web-mbzoxaljzq-ue.a.run.app", + "bypass_db_update": true + } +} ``` - { - "dry_run": bool, # [optional] If True, do not execute the workflow - "filter_after_in_days": int, # [optional] Filter datasets older than this number of days(default: 14 days ago) - "filter_statuses": list[str] # [optional] Filter datasets by status(in) + +Check `total_candidates` in the response to understand the scale. + +### Step 2 — End-to-end test with a small batch + +```json +{ + "task": "rebuild_missing_validation_reports", + "payload": { + "dry_run": false, + "validator_endpoint": "https://stg-gtfs-validator-web-mbzoxaljzq-ue.a.run.app", + "bypass_db_update": true, + "limit": 10 } +} ``` -Example: + +### Step 3 — Monitor the test batch + +```json +{ + "task": "get_validation_run_status", + "payload": { + "validator_version": "7.0.0", + "sync_workflow_status": true + } +} ``` + +### Step 3 — Monitor the test batch + +`sync_task_run_status` is scheduled automatically by `rebuild_missing_validation_reports`. +You can also call it on demand to get the current status: + +```json { - "dry_run": true, - "filter_after_in_days": 14, - "filter_statuses": ["active", "inactive", "future"] + "task": "sync_task_run_status", + "payload": { + "task_name": "gtfs_validation", + "run_id": "7.0.0" + } +} +``` + +Verify `dispatch_complete: true` and `triggered` count decreases as workflows finish. + +### Step 4 — Full run + +Remove the `limit`. If the function times out, call it again — already-triggered +datasets are automatically skipped. 
The queue-driven `sync_task_run_status` task continues
polling in the background every 10 minutes (Cloud Tasks retries it until the run settles):
To ingest immediately +after the pre-release run completes, trigger the `ingest-data-to-big-query` Cloud +Function manually: + +```bash +curl -X POST "https://ingest-data-to-big-query-gtfs-563580583640.northamerica-northeast1.run.app" \ + -H "Authorization: bearer $(gcloud auth print-identity-token)" \ + -H "Content-Type: application/json" +``` + +--- + +## GCP Environment Variables + +| Variable | Default | Description | +|---|---|---| +| `ENV` | `dev` | Environment (`dev`, `staging`, `prod`) | +| `LOCATION` | `northamerica-northeast1` | GCP region | +| `GTFS_VALIDATOR_URL` | env-derived | Override the validator URL (takes priority over `ENV`) | +| `BATCH_SIZE` | `5` | Number of workflows triggered per batch before sleeping | +| `SLEEP_TIME` | `5` | Seconds to sleep between batches | +| `TASK_RUN_SYNC_QUEUE` | Terraform-injected | Cloud Tasks queue used by `sync_task_run_status` self-scheduling | diff --git a/functions-python/tasks_executor/src/tasks/validation_reports/rebuild_missing_validation_reports.py b/functions-python/tasks_executor/src/tasks/validation_reports/rebuild_missing_validation_reports.py index a5a1b1e96..a8d429d2e 100644 --- a/functions-python/tasks_executor/src/tasks/validation_reports/rebuild_missing_validation_reports.py +++ b/functions-python/tasks_executor/src/tasks/validation_reports/rebuild_missing_validation_reports.py @@ -16,151 +16,361 @@ import logging import os +import requests from datetime import datetime, timedelta -from typing import List, Final +from typing import List, Final, Optional +from google.cloud import storage +from sqlalchemy import or_ from sqlalchemy.orm import Session from shared.database.database import with_db_session -from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gtfsdataset +from shared.database_gen.sqlacodegen_models import ( + Feed, + Gtfsfeed, + Gtfsdataset, + Validationreport, +) from shared.helpers.gtfs_validator_common import ( get_gtfs_validator_results_bucket, get_gtfs_validator_url, ) 
-from shared.helpers.query_helper import get_datasets_with_missing_reports_query +from shared.helpers.task_execution.task_execution_tracker import TaskExecutionTracker from shared.helpers.validation_report.validation_report_update import execute_workflows QUERY_LIMIT: Final[int] = 100 +TASK_NAME: Final[str] = "gtfs_validation" + +env = os.getenv("ENV", "dev").lower() +datasets_bucket_name = f"mobilitydata-datasets-{env}" def rebuild_missing_validation_reports_handler(payload) -> dict: """ Rebuilds missing validation reports for GTFS datasets. - This function processes datasets with missing validation reports using the GTFS validator workflow. - The payload structure is: + + Handles two cases: + 1. Datasets with no validation report at all (ongoing maintenance). + 2. Datasets whose existing report was generated by an older validator version + (triggered by passing validator_endpoint, used during validator releases). + + Payload structure: { - "dry_run": bool, # [optional] If True, do not execute the workflow - "filter_after_in_days": int, # [optional] Filter datasets older than this number of days(default: 14 days ago) - "filter_statuses": list[str] # [optional] Filter datasets by status(in) + "dry_run": bool, # [optional] If True, count only — do not trigger workflows. Default: True + "filter_after_in_days": int, # [optional] Restrict to datasets downloaded within the last N days. + # If omitted, all datasets are considered regardless of age. + "filter_statuses": list[str],# [optional] Filter feeds by status + "filter_op_statuses": list[str],# [optional] Filter feeds by operational status. + # Default: ["published"] + # Accepted values: "published", "unpublished", "wip" + "validator_endpoint": str, # [optional] Override validator URL (e.g. staging). Default: env-derived URL. + "bypass_db_update": bool, # [optional] If True, results are NOT written to the DB/API (pre-release runs). 
+ Default: False + "force_update": bool, # [optional] Re-trigger even if a report already exists. Default: False + "limit": int, # [optional] Max datasets to trigger per call (for testing). Default: unlimited } - Args: - payload (dict): The payload containing the task details. - Returns: - str: A message indicating the result of the operation with the total_processed datasets. """ ( dry_run, filter_after_in_days, filter_statuses, + filter_op_statuses, prod_env, validator_endpoint, + bypass_db_update, + force_update, + limit, ) = get_parameters(payload) return rebuild_missing_validation_reports( validator_endpoint=validator_endpoint, + bypass_db_update=bypass_db_update, dry_run=dry_run, filter_after_in_days=filter_after_in_days, filter_statuses=filter_statuses, + filter_op_statuses=filter_op_statuses, prod_env=prod_env, + force_update=force_update, + limit=limit, ) @with_db_session def rebuild_missing_validation_reports( validator_endpoint: str, + bypass_db_update: bool = False, dry_run: bool = True, - filter_after_in_days: int = 14, + filter_after_in_days: Optional[int] = None, filter_statuses: List[str] | None = None, + filter_op_statuses: List[str] | None = None, prod_env: bool = False, + force_update: bool = False, + limit: Optional[int] = None, db_session: Session | None = None, ) -> dict: """ - Rebuilds missing validation reports for GTFS datasets. + Rebuilds missing (or stale) validation reports for GTFS datasets. Args: - validator_endpoint: Validator endpoint URL - dry_run (bool): dry run flag. If True, do not execute the workflow. Default: True - filter_after_in_days (int): Filter the datasets older than this number of days. Default: 14 days ago - filter_statuses: [optional] Filter datasets by status(in). Default: None - prod_env (bool): True if target environment is production, false otherwise. Default: False - db_session: DB session - - Returns: - flask.Response: A response with message and total_processed datasets. 
+ validator_endpoint: Validator service URL (default: env-derived) + bypass_db_update: If True, validation results are NOT written to the DB/API. + Use for pre-release runs where results should not be surfaced yet. + dry_run: If True, count only — do not trigger workflows. Default: True + filter_after_in_days: Restrict to datasets downloaded within the last N days. + If None (default), all datasets are considered regardless of age. + filter_statuses: Filter feeds by status. Default: None (all) + filter_op_statuses: Filter feeds by operational status. + Default: ["published"]. Accepted: "published", "unpublished", "wip". + prod_env: True if targeting the production environment. Default: False + force_update: Re-trigger even if a report already exists. Default: False + limit: Max datasets to trigger per call (for end-to-end testing). Default: unlimited + db_session: DB session (injected by @with_db_session) """ - filter_after = datetime.today() - timedelta(days=filter_after_in_days) - query = get_datasets_with_missing_reports_query(db_session, filter_after) - if filter_statuses: - query = query.filter(Gtfsfeed.status.in_(filter_statuses)) - # Having a snapshot of datasets ids as the execution of the workflow - # can potentially add reports while this function is still running. - # This scenario will make the pagination result inconsistent. 
- dataset_ids = [row[0] for row in query.with_entities(Gtfsdataset.id).all()] - - total_processed = 0 - limit = QUERY_LIMIT - offset = 0 - - for i in range(0, len(dataset_ids), limit): - batch_ids = dataset_ids[i : i + limit] - datasets = ( - db_session.query(Gtfsfeed.stable_id, Gtfsdataset.stable_id) - .select_from(Gtfsfeed) - .join(Gtfsdataset, Gtfsdataset.feed_id == Gtfsfeed.id) - .filter(Gtfsdataset.id.in_(batch_ids)) - .all() + validator_version = _get_validator_version(validator_endpoint) + logging.info( + "Validator version: %s (bypass_db_update=%s)", + validator_version, + bypass_db_update, + ) + + datasets = _get_datasets_for_validation( + db_session=db_session, + validator_version=validator_version, + force_update=force_update, + filter_after_in_days=filter_after_in_days, + filter_statuses=filter_statuses, + filter_op_statuses=filter_op_statuses + if filter_op_statuses is not None + else ["published"], + ) + total_candidates = len(datasets) + logging.info("Found %s candidate datasets", total_candidates) + + # Apply limit inside the GCS blob check so we stop as soon as we have + # enough valid datasets, without discarding candidates that would pass. 
+ valid_datasets = _filter_out_datasets_without_blob(datasets, limit=limit) + logging.info( + "%s datasets have a GCS blob and will be triggered", len(valid_datasets) + ) + + tracker = TaskExecutionTracker( + task_name=TASK_NAME, + run_id=validator_version, + db_session=db_session, + ) + + datasets_to_trigger = valid_datasets + + if not dry_run: + tracker.start_run( + total_count=len(datasets_to_trigger), + params={ + "validator_endpoint": validator_endpoint, + "bypass_db_update": bypass_db_update, + "prod_env": prod_env, + "force_update": force_update, + "limit": limit, + "total_candidates": total_candidates, + }, ) - logging.info("Found %s datasets, offset %s", len(datasets), offset) - - if not dry_run: - execute_workflows( - datasets, - validator_endpoint=validator_endpoint, - bypass_db_update=False, - reports_bucket_name=get_gtfs_validator_results_bucket(prod_env), - ) - else: - logging.debug("Dry run: %s datasets would be processed", datasets) - total_processed += len(datasets) + tracker.schedule_status_sync(delay_seconds=600) + + total_triggered = 0 + total_skipped = 0 + total_already_tracked = 0 + + if dry_run: + entity_ids = [d[1] for d in datasets_to_trigger] + total_already_tracked = tracker.count_already_tracked(entity_ids) + logging.info( + "Dry run: %s datasets would be triggered (%s already tracked, would be skipped)", + len(datasets_to_trigger), + total_already_tracked, + ) + else: + triggered_ids = execute_workflows( + datasets_to_trigger, + validator_endpoint=validator_endpoint, + bypass_db_update=bypass_db_update, + reports_bucket_name=get_gtfs_validator_results_bucket(prod_env), + tracker=tracker, + ) + total_triggered = len(triggered_ids) + total_skipped = len(datasets_to_trigger) - total_triggered + db_session.commit() message = ( - "Dry run: no datasets processed." + f"Dry run: {len(datasets_to_trigger)} datasets would be triggered " + f"({total_already_tracked} already tracked, would be skipped)." 
if dry_run - else "Rebuild missing validation reports task executed successfully." + else f"Triggered {total_triggered} workflows ({total_skipped} skipped — already tracked)." ) result = { "message": message, - "total_processed": total_processed, + "total_candidates": total_candidates, + "total_in_call": len(datasets_to_trigger), + "total_triggered": 0 if dry_run else total_triggered, + "total_skipped": 0 if dry_run else total_skipped, + "total_already_tracked": total_already_tracked, "params": { "dry_run": dry_run, + "validator_version": validator_version, + "validator_endpoint": validator_endpoint, + "bypass_db_update": bypass_db_update, "filter_after_in_days": filter_after_in_days, "filter_statuses": filter_statuses, "prod_env": prod_env, - "validator_endpoint": validator_endpoint, + "force_update": force_update, + "limit": limit, }, } logging.info(result) return result +def _get_validator_version(validator_endpoint: str) -> str: + """Fetch the current version string from the validator service.""" + response = requests.get(f"{validator_endpoint}/version", timeout=30) + response.raise_for_status() + version = response.json()["version"] + logging.info("Validator version: %s", version) + return version + + +def _get_datasets_for_validation( + db_session: Session, + validator_version: str, + force_update: bool, + filter_after_in_days: Optional[int], + filter_statuses: Optional[List[str]], + filter_op_statuses: Optional[List[str]], +) -> List[tuple]: + """ + Query datasets that need a (re)validation. + + Includes datasets that: + - Have no validation report at all, OR + - Have a report from a different (older) validator version, OR + - force_update is True + + filter_after_in_days restricts to datasets downloaded within the last N days. + When None, all datasets are included regardless of age. + filter_op_statuses filters by Feed.operational_status (e.g. ["published"]). 
+ """ + query = ( + db_session.query(Gtfsfeed.stable_id, Gtfsdataset.stable_id) + .select_from(Gtfsfeed) + .join(Gtfsdataset, Gtfsfeed.latest_dataset_id == Gtfsdataset.id) + .outerjoin(Validationreport, Gtfsdataset.validation_reports) + .filter( + or_( + Validationreport.id.is_(None), + Validationreport.validator_version != validator_version, + force_update, + ) + ) + .distinct(Gtfsfeed.stable_id, Gtfsdataset.stable_id) + .order_by(Gtfsdataset.stable_id, Gtfsfeed.stable_id) + ) + if filter_after_in_days is not None: + filter_after = datetime.today() - timedelta(days=filter_after_in_days) + query = query.filter(Gtfsdataset.downloaded_at >= filter_after) + if filter_statuses: + query = query.filter(Gtfsfeed.status.in_(filter_statuses)) + if filter_op_statuses: + query = query.filter(Feed.operational_status.in_(filter_op_statuses)) + + return query.all() + + +def _filter_out_datasets_without_blob( + datasets: List[tuple], + limit: Optional[int] = None, +) -> List[tuple]: + """ + Filter out datasets whose zip file does not exist in GCS. + This avoids triggering workflows for feeds that have no data to validate. + + When limit is provided, stops as soon as limit valid datasets are found, + avoiding unnecessary GCS calls for the remaining candidates. + """ + storage_client = storage.Client() + bucket = storage_client.bucket(datasets_bucket_name) + valid = [] + for feed_id, dataset_id in datasets: + if limit is not None and len(valid) >= limit: + break + try: + blob = bucket.blob(f"{feed_id}/{dataset_id}/{dataset_id}.zip") + if blob.exists(): + valid.append((feed_id, dataset_id)) + else: + logging.warning( + "No GCS blob found for %s/%s — skipping", feed_id, dataset_id + ) + except Exception as e: + logging.error( + "Error checking GCS blob for %s/%s: %s", feed_id, dataset_id, e + ) + return valid + + def get_parameters(payload): """ - Get parameters from the payload and environment variables. + Extract and coerce parameters from the task payload. 
Args: - payload (dict): dictionary containing the payload data. + payload (dict): Task payload dict. Returns: - dict: dict with: dry_run, filter_after_in_days, filter_statuses, prod_env, validator_endpoint parameters + Tuple of (dry_run, filter_after_in_days, filter_statuses, prod_env, + validator_endpoint, force_update, limit) """ prod_env = os.getenv("ENV", "").lower() == "prod" - validator_endpoint = get_gtfs_validator_url(prod_env) + default_endpoint = get_gtfs_validator_url(prod_env) + dry_run = payload.get("dry_run", True) dry_run = dry_run if isinstance(dry_run, bool) else str(dry_run).lower() == "true" - filter_after_in_days = payload.get("filter_after_in_days", 7) - filter_after_in_days = ( - filter_after_in_days - if isinstance(filter_after_in_days, int) - else int(filter_after_in_days) - ) + + filter_after_in_days = payload.get("filter_after_in_days", None) + if filter_after_in_days is not None: + filter_after_in_days = ( + filter_after_in_days + if isinstance(filter_after_in_days, int) + else int(filter_after_in_days) + ) + filter_statuses = payload.get("filter_statuses", None) - return dry_run, filter_after_in_days, filter_statuses, prod_env, validator_endpoint + + filter_op_statuses = payload.get("filter_op_statuses", None) + + validator_endpoint = payload.get("validator_endpoint", default_endpoint) + + bypass_db_update = payload.get("bypass_db_update", False) + bypass_db_update = ( + bypass_db_update + if isinstance(bypass_db_update, bool) + else str(bypass_db_update).lower() == "true" + ) + + force_update = payload.get("force_update", False) + force_update = ( + force_update + if isinstance(force_update, bool) + else str(force_update).lower() == "true" + ) + + limit = payload.get("limit", None) + if limit is not None: + limit = int(limit) + + return ( + dry_run, + filter_after_in_days, + filter_statuses, + filter_op_statuses, + prod_env, + validator_endpoint, + bypass_db_update, + force_update, + limit, + ) diff --git 
a/functions-python/tasks_executor/tests/tasks/dataset_files/test_rebuild_missing_dataset_files.py b/functions-python/tasks_executor/tests/tasks/dataset_files/test_rebuild_missing_dataset_files.py index b386d02dc..43677d7ba 100644 --- a/functions-python/tasks_executor/tests/tasks/dataset_files/test_rebuild_missing_dataset_files.py +++ b/functions-python/tasks_executor/tests/tasks/dataset_files/test_rebuild_missing_dataset_files.py @@ -24,7 +24,7 @@ from sqlalchemy.orm import Session from shared.database.database import with_db_session -from shared.helpers.tests.test_shared.test_utils.database_utils import default_db_url +from test_shared.test_utils.database_utils import default_db_url from tasks.dataset_files.rebuild_missing_dataset_files import ( rebuild_missing_dataset_files, rebuild_missing_dataset_files_handler, diff --git a/functions-python/tasks_executor/tests/tasks/test_get_task_run_status.py b/functions-python/tasks_executor/tests/tasks/test_get_task_run_status.py new file mode 100644 index 000000000..7496208a2 --- /dev/null +++ b/functions-python/tasks_executor/tests/tasks/test_get_task_run_status.py @@ -0,0 +1,150 @@ +# +# MobilityData 2026 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest +from unittest.mock import MagicMock, patch + +_MODULE = "tasks.get_task_run_status" + + +class TestGetTaskRunStatusHandler(unittest.TestCase): + def test_requires_task_name(self): + from tasks.get_task_run_status import get_task_run_status_handler + + with self.assertRaises(ValueError): + get_task_run_status_handler({"run_id": "7.0.0"}) + + def test_requires_run_id(self): + from tasks.get_task_run_status import get_task_run_status_handler + + with self.assertRaises(ValueError): + get_task_run_status_handler({"task_name": "gtfs_validation"}) + + @patch(f"{_MODULE}.get_task_run_status") + def test_passes_params(self, mock_fn): + from tasks.get_task_run_status import get_task_run_status_handler + + mock_fn.return_value = {"run_status": "completed"} + get_task_run_status_handler({"task_name": "gtfs_validation", "run_id": "7.0.0"}) + mock_fn.assert_called_once_with(task_name="gtfs_validation", run_id="7.0.0") + + +class TestGetTaskRunStatus(unittest.TestCase): + def _make_summary(self, overrides=None): + base = { + "task_name": "gtfs_validation", + "run_id": "7.0.0", + "run_status": "in_progress", + "total_count": 100, + "triggered": 10, + "completed": 80, + "failed": 0, + "pending": 10, + "created_at": None, + "params": None, + } + if overrides: + base.update(overrides) + return base + + def _make_session_mock(self): + return MagicMock() + + @patch(f"{_MODULE}.TaskExecutionTracker") + def test_returns_summary_with_dispatch_complete_false(self, tracker_cls): + from tasks.get_task_run_status import get_task_run_status + + tracker = MagicMock() + tracker.get_summary.return_value = self._make_summary({"pending": 10}) + tracker_cls.return_value = tracker + session = self._make_session_mock() + + result = get_task_run_status( + task_name="gtfs_validation", run_id="7.0.0", db_session=session + ) + + self.assertFalse(result["dispatch_complete"]) + self.assertEqual(result["pending"], 10) + tracker.finish_run.assert_not_called() + + 
@patch(f"{_MODULE}.TaskExecutionTracker") + def test_returns_summary_with_dispatch_complete_true(self, tracker_cls): + from tasks.get_task_run_status import get_task_run_status + + tracker = MagicMock() + tracker.get_summary.return_value = self._make_summary( + {"pending": 0, "triggered": 0, "completed": 100} + ) + tracker_cls.return_value = tracker + session = self._make_session_mock() + + result = get_task_run_status( + task_name="gtfs_validation", run_id="7.0.0", db_session=session + ) + + self.assertTrue(result["dispatch_complete"]) + self.assertEqual(result["completed"], 100) + tracker.finish_run.assert_not_called() + + @patch(f"{_MODULE}.TaskExecutionTracker") + def test_does_not_modify_statuses(self, tracker_cls): + from tasks.get_task_run_status import get_task_run_status + + tracker = MagicMock() + tracker.get_summary.return_value = self._make_summary() + tracker_cls.return_value = tracker + session = self._make_session_mock() + + get_task_run_status( + task_name="gtfs_validation", run_id="7.0.0", db_session=session + ) + + tracker.mark_completed.assert_not_called() + tracker.mark_failed.assert_not_called() + tracker.mark_triggered.assert_not_called() + tracker.finish_run.assert_not_called() + tracker.schedule_status_sync.assert_not_called() + + @patch(f"{_MODULE}.TaskExecutionTracker") + def test_returns_none_run_status_when_not_found(self, tracker_cls): + from tasks.get_task_run_status import get_task_run_status + + tracker = MagicMock() + tracker.get_summary.return_value = { + "task_name": "gtfs_validation", + "run_id": "9.9.9", + "run_status": None, + "total_count": None, + "triggered": 0, + "completed": 0, + "failed": 0, + "pending": 0, + "created_at": None, + "params": None, + } + tracker_cls.return_value = tracker + session = self._make_session_mock() + + result = get_task_run_status( + task_name="gtfs_validation", run_id="9.9.9", db_session=session + ) + + self.assertIsNone(result["run_status"]) + self.assertTrue(result["dispatch_complete"]) + + 
+if __name__ == "__main__": + unittest.main() diff --git a/functions-python/tasks_executor/tests/tasks/test_sync_task_run_status.py b/functions-python/tasks_executor/tests/tasks/test_sync_task_run_status.py new file mode 100644 index 000000000..08061d9d8 --- /dev/null +++ b/functions-python/tasks_executor/tests/tasks/test_sync_task_run_status.py @@ -0,0 +1,173 @@ +# +# MobilityData 2026 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +from unittest.mock import MagicMock, patch + +from shared.helpers.task_execution.task_execution_tracker import TaskInProgressError + +_MODULE = "tasks.sync_task_run_status" + + +class TestSyncTaskRunStatusHandler(unittest.TestCase): + def test_requires_task_name(self): + from tasks.sync_task_run_status import sync_task_run_status_handler + + with self.assertRaises(ValueError): + sync_task_run_status_handler({"run_id": "7.0.0"}) + + def test_requires_run_id(self): + from tasks.sync_task_run_status import sync_task_run_status_handler + + with self.assertRaises(ValueError): + sync_task_run_status_handler({"task_name": "gtfs_validation"}) + + @patch(f"{_MODULE}.sync_task_run_status") + def test_passes_params(self, mock_fn): + from tasks.sync_task_run_status import sync_task_run_status_handler + + mock_fn.return_value = {"run_status": "completed"} + sync_task_run_status_handler( + {"task_name": "gtfs_validation", "run_id": "7.0.0"} + ) + mock_fn.assert_called_once_with(task_name="gtfs_validation", run_id="7.0.0") + + 
+class TestSyncTaskRunStatus(unittest.TestCase): + def _make_tracker_mock(self, summary): + tracker = MagicMock() + tracker.get_summary.return_value = summary + return tracker + + def _make_session_mock(self): + session = MagicMock() + session.query.return_value.filter.return_value.all.return_value = [] + return session + + @patch(f"{_MODULE}._sync_workflow_statuses") + @patch(f"{_MODULE}.TaskExecutionTracker") + def test_marks_completed_when_all_settled(self, tracker_cls, sync_mock): + from tasks.sync_task_run_status import sync_task_run_status + + tracker = self._make_tracker_mock( + { + "run_status": "in_progress", + "total_count": 5, + "pending": 0, + "triggered": 0, + "completed": 5, + "failed": 0, + "params": {"total_candidates": 5}, + } + ) + tracker_cls.return_value = tracker + session = self._make_session_mock() + + result = sync_task_run_status( + task_name="gtfs_validation", run_id="7.0.0", db_session=session + ) + + tracker.finish_run.assert_called_once() + tracker.schedule_status_sync.assert_not_called() + self.assertTrue(result["ready_for_bigquery"]) + self.assertTrue(result["dispatch_complete"]) + + @patch(f"{_MODULE}._sync_workflow_statuses") + @patch(f"{_MODULE}.TaskExecutionTracker") + def test_raises_task_in_progress_when_triggered_remain( + self, tracker_cls, sync_mock + ): + from tasks.sync_task_run_status import sync_task_run_status + + tracker = self._make_tracker_mock( + { + "run_status": "in_progress", + "total_count": 100, + "pending": 0, + "triggered": 40, + "completed": 60, + "failed": 0, + "params": {"total_candidates": 100}, + } + ) + tracker_cls.return_value = tracker + session = self._make_session_mock() + + with self.assertRaises(TaskInProgressError): + sync_task_run_status( + task_name="gtfs_validation", run_id="7.0.0", db_session=session + ) + + tracker.finish_run.assert_not_called() + tracker.schedule_status_sync.assert_not_called() + + @patch(f"{_MODULE}._sync_workflow_statuses") + @patch(f"{_MODULE}.TaskExecutionTracker") + def 
test_raises_task_in_progress_when_dispatch_incomplete( + self, tracker_cls, sync_mock + ): + from tasks.sync_task_run_status import sync_task_run_status + + tracker = self._make_tracker_mock( + { + "run_status": "in_progress", + "total_count": 100, + "pending": 30, + "triggered": 0, + "completed": 70, + "failed": 0, + "params": None, + } + ) + tracker_cls.return_value = tracker + session = self._make_session_mock() + + with self.assertRaises(TaskInProgressError): + sync_task_run_status( + task_name="gtfs_validation", run_id="7.0.0", db_session=session + ) + + @patch(f"{_MODULE}._sync_workflow_statuses") + @patch(f"{_MODULE}.TaskExecutionTracker") + def test_raises_task_in_progress_when_failures_exist(self, tracker_cls, sync_mock): + from tasks.sync_task_run_status import sync_task_run_status + + tracker = self._make_tracker_mock( + { + "run_status": "in_progress", + "total_count": 5, + "pending": 0, + "triggered": 0, + "completed": 3, + "failed": 2, + "params": None, + } + ) + tracker_cls.return_value = tracker + session = self._make_session_mock() + session.query.return_value.filter.return_value.all.return_value = [ + MagicMock(entity_id="ds-1"), + MagicMock(entity_id="ds-2"), + ] + + with self.assertRaises(TaskInProgressError): + sync_task_run_status( + task_name="gtfs_validation", run_id="7.0.0", db_session=session + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/functions-python/tasks_executor/tests/tasks/validation_reports/test_rebuild_missing_validation_reports.py b/functions-python/tasks_executor/tests/tasks/validation_reports/test_rebuild_missing_validation_reports.py index a199476c5..c087474df 100644 --- a/functions-python/tasks_executor/tests/tasks/validation_reports/test_rebuild_missing_validation_reports.py +++ b/functions-python/tasks_executor/tests/tasks/validation_reports/test_rebuild_missing_validation_reports.py @@ -17,225 +17,293 @@ import unittest from unittest.mock import patch, MagicMock -from sqlalchemy.orm import Session - 
-from shared.database.database import with_db_session -from shared.database_gen.sqlacodegen_models import Gtfsdataset, Feed from shared.helpers.gtfs_validator_common import GTFS_VALIDATOR_URL_STAGING -from shared.helpers.tests.test_shared.test_utils.database_utils import default_db_url from tasks.validation_reports.rebuild_missing_validation_reports import ( rebuild_missing_validation_reports_handler, get_parameters, rebuild_missing_validation_reports, ) +_MODULE = "tasks.validation_reports.rebuild_missing_validation_reports" -class TestTasksExecutor(unittest.TestCase): - def test_get_parameters(self): - """ - Test the get_parameters function to ensure it correctly extracts parameters from the payload. - """ - payload = { - "dry_run": True, - "filter_after_in_days": 14, - "filter_statuses": ["status1", "status2"], - } +class TestGetParameters(unittest.TestCase): + def test_defaults(self): ( dry_run, filter_after_in_days, filter_statuses, + filter_op_statuses, prod_env, validator_endpoint, - ) = get_parameters(payload) - + bypass_db_update, + force_update, + limit, + ) = get_parameters({}) self.assertTrue(dry_run) - self.assertEqual(filter_after_in_days, 14) - self.assertEqual(filter_statuses, ["status1", "status2"]) + self.assertIsNone(filter_after_in_days) + self.assertIsNone(filter_statuses) + self.assertIsNone(filter_op_statuses) self.assertFalse(prod_env) self.assertEqual(validator_endpoint, GTFS_VALIDATOR_URL_STAGING) + self.assertFalse(bypass_db_update) + self.assertFalse(force_update) + self.assertIsNone(limit) - @patch( - "tasks.validation_reports.rebuild_missing_validation_reports.rebuild_missing_validation_reports" - ) - def test_rebuild_missing_validation_reports_entry( - self, rebuild_missing_validation_reports_mock - ): - """ - Test the rebuild_missing_validation_reports_entry function. - Assert that it correctly calls the rebuild_missing_validation_reports function with the expected parameters. 
- """ - # Mock payload for the test + def test_all_params(self): payload = { - "dry_run": True, - "filter_after_in_days": 14, - "filter_statuses": ["status1", "status2"], + "dry_run": False, + "filter_after_in_days": 30, + "filter_statuses": ["active"], + "filter_op_statuses": ["published", "unpublished"], + "validator_endpoint": "https://staging.example.com/api", + "bypass_db_update": True, + "force_update": True, + "limit": 10, } - expected_response = MagicMock() - rebuild_missing_validation_reports_mock.return_value = expected_response - response = rebuild_missing_validation_reports_handler(payload) + ( + dry_run, + filter_after_in_days, + filter_statuses, + filter_op_statuses, + prod_env, + validator_endpoint, + bypass_db_update, + force_update, + limit, + ) = get_parameters(payload) + self.assertFalse(dry_run) + self.assertEqual(filter_after_in_days, 30) + self.assertEqual(filter_statuses, ["active"]) + self.assertEqual(filter_op_statuses, ["published", "unpublished"]) + self.assertEqual(validator_endpoint, "https://staging.example.com/api") + self.assertTrue(bypass_db_update) + self.assertTrue(force_update) + self.assertEqual(limit, 10) - self.assertEqual(response, expected_response) - rebuild_missing_validation_reports_mock.assert_called_once_with( - validator_endpoint=GTFS_VALIDATOR_URL_STAGING, - dry_run=True, - filter_after_in_days=14, - filter_statuses=["status1", "status2"], - prod_env=False, + def test_string_coercion(self): + payload = { + "dry_run": "false", + "bypass_db_update": "true", + "force_update": "true", + "limit": "5", + } + dry_run, _, _, _, _, _, bypass_db_update, force_update, limit = get_parameters( + payload ) + self.assertFalse(dry_run) + self.assertTrue(bypass_db_update) + self.assertTrue(force_update) + self.assertEqual(limit, 5) + - @with_db_session(db_url=default_db_url) - @patch( - "tasks.validation_reports.rebuild_missing_validation_reports.execute_workflows", - ) - @patch( - 
"tasks.validation_reports.rebuild_missing_validation_reports.QUERY_LIMIT", 10 - ) - def test_rebuild_missing_validation_reports_one_page( - self, execute_workflows_mock, db_session: Session +class TestRebuildMissingValidationReports(unittest.TestCase): + def _make_session_mock(self, datasets=None): + """Create a mock DB session returning given datasets from the query.""" + session = MagicMock() + query_mock = MagicMock() + query_mock.select_from.return_value = query_mock + query_mock.join.return_value = query_mock + query_mock.outerjoin.return_value = query_mock + query_mock.filter.return_value = query_mock + query_mock.distinct.return_value = query_mock + query_mock.order_by.return_value = query_mock + query_mock.all.return_value = datasets or [] + session.query.return_value = query_mock + return session + + @patch(f"{_MODULE}._get_validator_version", return_value="7.0.0") + @patch(f"{_MODULE}._filter_out_datasets_without_blob", return_value=[]) + @patch(f"{_MODULE}.TaskExecutionTracker") + def test_dry_run_returns_count_without_triggering( + self, tracker_cls, filter_blob_mock, version_mock ): - """ - Test the rebuild_missing_validation_reports function with a single page of results. - We are assuming tha the dataset has 7 datasets, and the query limit is set to 10. 
- """ - execute_workflows_mock.return_value = [] - response = rebuild_missing_validation_reports( - db_session=db_session, - validator_endpoint="https://i_dont.exists.com", - dry_run=False, - prod_env=False, + session = self._make_session_mock( + datasets=[("feed-1", "ds-1"), ("feed-2", "ds-2")] ) + filter_blob_mock.return_value = [("feed-1", "ds-1"), ("feed-2", "ds-2")] + tracker_cls.return_value.count_already_tracked.return_value = 1 - # Assert the expected behavior - self.assertIsNotNone(response) - self.assertEquals(response["total_processed"], 9) - self.assertEquals( - response["message"], - "Rebuild missing validation reports task executed successfully.", + result = rebuild_missing_validation_reports( + validator_endpoint="https://staging.example.com/api", + dry_run=True, + db_session=session, + ) + + self.assertEqual(result["total_candidates"], 2) + self.assertEqual(result["total_triggered"], 0) + self.assertEqual(result["total_already_tracked"], 1) + self.assertIn("dry_run", result["params"]) + self.assertTrue(result["params"]["dry_run"]) + tracker_cls.return_value.start_run.assert_not_called() + tracker_cls.return_value.count_already_tracked.assert_called_once_with( + ["ds-1", "ds-2"] ) - execute_workflows_mock.assert_called_once() - - @with_db_session(db_url=default_db_url) - @patch( - "tasks.validation_reports.rebuild_missing_validation_reports.execute_workflows", - ) - @patch("tasks.validation_reports.rebuild_missing_validation_reports.QUERY_LIMIT", 2) - def test_rebuild_missing_validation_reports_two_pages( - self, execute_workflows_mock, db_session: Session + + @patch(f"{_MODULE}._get_validator_version", return_value="7.0.0") + @patch(f"{_MODULE}._filter_out_datasets_without_blob") + @patch(f"{_MODULE}.execute_workflows", return_value=["ds-1", "ds-2"]) + @patch(f"{_MODULE}.TaskExecutionTracker") + def test_triggers_workflows_when_not_dry_run( + self, tracker_cls, exec_mock, filter_blob_mock, version_mock ): - """ - Test the 
rebuild_missing_validation_reports function with a single page of results. - We are assuming tha the dataset has 7 datasets, and the query limit is set to 2. - """ - execute_workflows_mock.return_value = [] - response = rebuild_missing_validation_reports( - db_session=db_session, - validator_endpoint="https://i_dont.exists.com", + datasets = [("feed-1", "ds-1"), ("feed-2", "ds-2")] + filter_blob_mock.return_value = datasets + session = self._make_session_mock(datasets=datasets) + + result = rebuild_missing_validation_reports( + validator_endpoint="https://staging.example.com/api", dry_run=False, - prod_env=False, + db_session=session, ) - # Assert the expected behavior - self.assertIsNotNone(response) - self.assertEquals(response["total_processed"], 9) - self.assertEquals( - response["message"], - "Rebuild missing validation reports task executed successfully.", + exec_mock.assert_called_once() + self.assertEqual(result["total_triggered"], 2) + self.assertFalse(result["params"]["dry_run"]) + + @patch(f"{_MODULE}._get_validator_version", return_value="7.0.0") + @patch(f"{_MODULE}._filter_out_datasets_without_blob") + @patch(f"{_MODULE}.execute_workflows", return_value=["ds-1"]) + @patch(f"{_MODULE}.TaskExecutionTracker") + def test_limit_slices_datasets( + self, tracker_cls, exec_mock, filter_blob_mock, version_mock + ): + datasets = [(f"feed-{i}", f"ds-{i}") for i in range(20)] + filter_blob_mock.return_value = [(f"feed-{i}", f"ds-{i}") for i in range(5)] + session = self._make_session_mock(datasets=datasets) + + result = rebuild_missing_validation_reports( + validator_endpoint="https://staging.example.com/api", + dry_run=False, + limit=5, + db_session=session, ) - self.assertEquals(execute_workflows_mock.call_count, 5) - - @with_db_session(db_url=default_db_url) - @patch( - "tasks.validation_reports.rebuild_missing_validation_reports.execute_workflows", - ) - @patch("tasks.validation_reports.rebuild_missing_validation_reports.QUERY_LIMIT", 2) - def 
test_rebuild_missing_validation_reports_dryrun( - self, execute_workflows_mock, db_session: Session + + # blob filter must be called with full candidate list AND the limit + filter_blob_mock.assert_called_once_with(datasets, limit=5) + triggered_datasets = exec_mock.call_args[0][0] + self.assertEqual(len(triggered_datasets), 5) + self.assertEqual(result["total_candidates"], 20) + self.assertEqual(result["total_in_call"], 5) + + @patch(f"{_MODULE}._get_validator_version", return_value="7.0.0") + @patch(f"{_MODULE}._filter_out_datasets_without_blob", return_value=[("f", "ds-1")]) + @patch(f"{_MODULE}.execute_workflows", return_value=["ds-1"]) + @patch(f"{_MODULE}.TaskExecutionTracker") + def test_bypass_db_update_passed_explicitly( + self, tracker_cls, exec_mock, filter_blob_mock, version_mock ): - """ - Test the rebuild_missing_validation_reports function with a single page of results. - We are assuming tha the dataset has 7 datasets, and the query limit is set to 2. - """ - execute_workflows_mock.return_value = [] - response = rebuild_missing_validation_reports( - db_session=db_session, - validator_endpoint="https://i_dont.exists.com", - dry_run=True, - prod_env=False, + session = self._make_session_mock(datasets=[("f", "ds-1")]) + rebuild_missing_validation_reports( + validator_endpoint="https://staging.example.com/api", + bypass_db_update=True, + dry_run=False, + db_session=session, ) + _, call_kwargs = exec_mock.call_args + self.assertTrue(call_kwargs["bypass_db_update"]) - # Assert the expected behavior - self.assertIsNotNone(response) - self.assertEquals(response["total_processed"], 9) - self.assertEquals(response["message"], "Dry run: no datasets processed.") - execute_workflows_mock.assert_not_called() - - @with_db_session(db_url=default_db_url) - @patch( - "tasks.validation_reports.rebuild_missing_validation_reports.execute_workflows", - ) - @patch( - "tasks.validation_reports.rebuild_missing_validation_reports.QUERY_LIMIT", 10 - ) - def 
test_rebuild_missing_validation_reports_filter_active( - self, execute_workflows_mock, db_session: Session + @patch(f"{_MODULE}._get_validator_version", return_value="7.0.0") + @patch(f"{_MODULE}._filter_out_datasets_without_blob", return_value=[("f", "ds-1")]) + @patch(f"{_MODULE}.execute_workflows", return_value=["ds-1"]) + @patch(f"{_MODULE}.TaskExecutionTracker") + def test_bypass_db_update_defaults_to_false( + self, tracker_cls, exec_mock, filter_blob_mock, version_mock ): - """ - Test the rebuild_missing_validation_reports function with a single page of results. - We are assuming tha the dataset has 7 datasets, and the query limit is set to 2. - """ - active_counter = ( - db_session.query(Gtfsdataset) - .join(Gtfsdataset.feed) - .filter(Feed.status == "active") - .count() + session = self._make_session_mock(datasets=[("f", "ds-1")]) + rebuild_missing_validation_reports( + validator_endpoint="https://staging.example.com/api", + dry_run=False, + db_session=session, ) - execute_workflows_mock.return_value = [] - response = rebuild_missing_validation_reports( - db_session=db_session, - validator_endpoint="https://i_dont.exists.com", + _, call_kwargs = exec_mock.call_args + self.assertFalse(call_kwargs["bypass_db_update"]) + + @patch(f"{_MODULE}.rebuild_missing_validation_reports") + def test_handler_passes_all_params(self, rebuild_mock): + rebuild_mock.return_value = {"message": "ok"} + payload = { + "dry_run": False, + "filter_after_in_days": 30, + "validator_endpoint": "https://staging.example.com/api", + "force_update": True, + "limit": 10, + "filter_op_statuses": ["published", "wip"], + } + rebuild_missing_validation_reports_handler(payload) + rebuild_mock.assert_called_once_with( + validator_endpoint="https://staging.example.com/api", + bypass_db_update=False, dry_run=False, + filter_after_in_days=30, + filter_statuses=None, + filter_op_statuses=["published", "wip"], prod_env=False, - filter_statuses=["active"], + force_update=True, + limit=10, ) - # Assert 
the expected behavior - self.assertIsNotNone(response) - self.assertEquals(response["total_processed"], active_counter) - self.assertEquals( - response["message"], - "Rebuild missing validation reports task executed successfully.", - ) - self.assertEquals(execute_workflows_mock.call_count, 1) - - @with_db_session(db_url=default_db_url) - @patch( - "tasks.validation_reports.rebuild_missing_validation_reports.execute_workflows", - ) - @patch( - "tasks.validation_reports.rebuild_missing_validation_reports.QUERY_LIMIT", 10 - ) - def test_rebuild_missing_validation_reports_filter_no_results( - self, execute_workflows_mock, db_session: Session + @patch(f"{_MODULE}._get_validator_version", return_value="7.0.0") + @patch(f"{_MODULE}._filter_out_datasets_without_blob", return_value=[]) + @patch(f"{_MODULE}.TaskExecutionTracker") + def test_default_op_status_filters_published( + self, tracker_cls, filter_blob_mock, version_mock ): - """ - Test the rebuild_missing_validation_reports function with a single page of results. - We are assuming tha the dataset has 7 datasets, and the query limit is set to 2. 
- """ - execute_workflows_mock.return_value = [] - response = rebuild_missing_validation_reports( - db_session=db_session, - validator_endpoint="https://i_dont.exists.com", - dry_run=False, - prod_env=False, - filter_statuses=["future"], + """When filter_op_statuses is None, the query should default to ['published'].""" + session = self._make_session_mock(datasets=[]) + rebuild_missing_validation_reports( + validator_endpoint="https://staging.example.com/api", + dry_run=True, + filter_op_statuses=None, + db_session=session, ) + # The query chain should have received a filter call for operational_status + # Verify via the query mock that .filter was called (default published applied) + self.assertTrue(session.query.called) - # Assert the expected behavior - self.assertIsNotNone(response) - self.assertEquals(response["total_processed"], 0) - self.assertEquals( - response["message"], - "Rebuild missing validation reports task executed successfully.", + +class TestFilterDatasetsWithExistingBlob(unittest.TestCase): + @patch(f"{_MODULE}.storage") + def test_stops_at_limit(self, storage_mock): + """Should stop checking GCS as soon as limit valid datasets are found.""" + from tasks.validation_reports.rebuild_missing_validation_reports import ( + _filter_out_datasets_without_blob, ) - execute_workflows_mock.assert_not_called() + + bucket_mock = MagicMock() + storage_mock.Client.return_value.bucket.return_value = bucket_mock + + # All blobs exist + blob_mock = MagicMock() + blob_mock.exists.return_value = True + bucket_mock.blob.return_value = blob_mock + + datasets = [(f"feed-{i}", f"ds-{i}") for i in range(20)] + result = _filter_out_datasets_without_blob(datasets, limit=3) + + self.assertEqual(len(result), 3) + # Only 3 GCS calls should have been made + self.assertEqual(blob_mock.exists.call_count, 3) + + @patch(f"{_MODULE}.storage") + def test_skips_missing_blobs_and_continues(self, storage_mock): + """Should skip datasets with no blob and keep going until limit is 
reached.""" + from tasks.validation_reports.rebuild_missing_validation_reports import ( + _filter_out_datasets_without_blob, + ) + + bucket_mock = MagicMock() + storage_mock.Client.return_value.bucket.return_value = bucket_mock + + # First 2 blobs missing, next 3 exist + exists_sequence = [False, False, True, True, True, True, True] + blob_mock = MagicMock() + blob_mock.exists.side_effect = exists_sequence + bucket_mock.blob.return_value = blob_mock + + datasets = [(f"feed-{i}", f"ds-{i}") for i in range(7)] + result = _filter_out_datasets_without_blob(datasets, limit=3) + + self.assertEqual(len(result), 3) + # Must have checked 5 items: 2 missing + 3 valid + self.assertEqual(blob_mock.exists.call_count, 5) diff --git a/functions-python/tasks_executor/tests/tasks/visualisation_files/test_rebuild_visualization_files.py b/functions-python/tasks_executor/tests/tasks/visualisation_files/test_rebuild_visualization_files.py index 8f2731cb5..dbdd63176 100644 --- a/functions-python/tasks_executor/tests/tasks/visualisation_files/test_rebuild_visualization_files.py +++ b/functions-python/tasks_executor/tests/tasks/visualisation_files/test_rebuild_visualization_files.py @@ -22,7 +22,7 @@ from shared.database.database import with_db_session from shared.database_gen.sqlacodegen_models import Gtfsdataset, Gtfsfile, Gtfsfeed -from shared.helpers.tests.test_shared.test_utils.database_utils import default_db_url +from test_shared.test_utils.database_utils import default_db_url from tasks.visualization_files.rebuild_missing_visualization_files import ( rebuild_missing_visualization_files_handler, diff --git a/functions-python/update_validation_report/.coveragerc b/functions-python/update_validation_report/.coveragerc deleted file mode 100644 index d9fb9a847..000000000 --- a/functions-python/update_validation_report/.coveragerc +++ /dev/null @@ -1,10 +0,0 @@ -[run] -omit = - */test*/* - */helpers/* - */database_gen/* - */shared/* - -[report] -exclude_lines = - if __name__ == 
.__main__.: \ No newline at end of file diff --git a/functions-python/update_validation_report/.env.rename_me b/functions-python/update_validation_report/.env.rename_me deleted file mode 100644 index be6da9c89..000000000 --- a/functions-python/update_validation_report/.env.rename_me +++ /dev/null @@ -1,7 +0,0 @@ -# Environment variables for the validation report updates to run locally -FEEDS_DATABASE_URL={{FEEDS_DATABASE_URL}} -ENV={{ENV}} -BATCH_SIZE={{BATCH_SIZE}} -WEB_VALIDATOR_URL={{WEB_VALIDATOR_URL}} -LOCATION={{LOCATION}} -SLEEP_TIME={{SLEEP_TIME}} diff --git a/functions-python/update_validation_report/README.md b/functions-python/update_validation_report/README.md deleted file mode 100644 index e32903272..000000000 --- a/functions-python/update_validation_report/README.md +++ /dev/null @@ -1,40 +0,0 @@ -Here's a more polished version of the description: - -# Update Validation Report - -This function initiates the process of updating the validation report for all the latest datasets that do not yet have a report generated with the current version. - -## Function Parameters - -To support flexibility in handling different snapshots and validator versions, the following parameters can be used to customize the function's behavior: - -- `validator_endpoint`: Specifies the endpoint of the validator to be used for the validation process. -- `force_update`: Forces an update by ignoring existing validation reports of the same version, treating them as if they do not exist. -- `env`: Specifies the environment (`stagging` or `prod`), used to determine the appropriate bucket name and project id for retrieving validation reports and executing the `gtfs_validator_execution` workflow. - -## Function Workflow -1. **HTTP Request Trigger**: The function is initiated via an HTTP request. -2. **Retrieve Latest Datasets**: Retrieves the latest datasets from the database that do not have the latest version of the validation report. -3. 
**Validate Accessibility of Datasets**: Checks the availability of the latest datasets to ensure that the data is accessible for validation report processing. -4. **Trigger Validation Report Processing**: If the latest dataset lacks the current validation report, this action initiates the `gtfs_validator_execution` workflow. -5. **Return Response**: Outputs a response indicating the status of the validation report update. The response format is as follows: -```json -{ - "message": "Validation report update needed for X datasets and triggered for Y datasets", - "dataset_workflow_triggered": ["dataset_id1", "dataset_id2", ...], - "datasets_not_updated": ["dataset_id3", "dataset_id4", ...] - "ignored_datasets": ["dataset_id5", "dataset_id6", ...] -} -``` -The response message provides information on the number of datasets that require a validation report update and the number of datasets for which the update has been triggered. It also lists the datasets that were not updated and those that were ignored due to unavailability of the data. - -## Function Configuration -The function relies on several environmental variables: -- `FEEDS_DATABASE_URL`: URL used to connect to the database that holds GTFS datasets and related data. -- `ENV`: Specifies the environment (`dev`, `qa`, or `prod`), used to determine the appropriate bucket name and project id for retrieving validation reports and executing the `gtfs_validator_execution` workflow. -- `BATCH_SIZE`: Number of datasets processed in each batch to prevent rate limiting by the web validator. -- `SLEEP_TIME`: Time in seconds to wait between batches to prevent rate limiting by the web validator. -- `WEB_VALIDATOR_URL`: URL for the web validator that checks for the latest validation report version. -- `LOCATION`: Location of the GCP workflow execution. -## Local Development -Follow standard practices for local development of GCP serverless functions. 
Refer to the main [README.md](../README.md) for general setup instructions for the development environment. diff --git a/functions-python/update_validation_report/function_config.json b/functions-python/update_validation_report/function_config.json deleted file mode 100644 index 9109da7c8..000000000 --- a/functions-python/update_validation_report/function_config.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "name": "update-validation-report", - "description": "Batch update of validation report for the latest datasets", - "entry_point": "update_validation_report", - "timeout": 3600, - "memory": "256Mi", - "trigger_http": true, - "include_folders": ["helpers"], - "include_api_folders": ["database_gen", "database", "common"], - "secret_environment_variables": [ - { - "key": "FEEDS_DATABASE_URL" - } - ], - "ingress_settings": "ALLOW_INTERNAL_AND_GCLB", - "max_instance_request_concurrency": 1, - "max_instance_count": 1, - "min_instance_count": 0, - "available_cpu": 1 -} diff --git a/functions-python/update_validation_report/requirements.txt b/functions-python/update_validation_report/requirements.txt deleted file mode 100644 index 2409548f8..000000000 --- a/functions-python/update_validation_report/requirements.txt +++ /dev/null @@ -1,23 +0,0 @@ -# Common packages -functions-framework==3.* -google-cloud-logging -psycopg2-binary==2.9.6 -aiohttp~=3.10.5 -asyncio~=3.4.3 -urllib3~=2.5.0 -requests~=2.32.3 -attrs~=23.1.0 -pluggy~=1.3.0 -certifi~=2025.8.3 - -# SQL Alchemy and Geo Alchemy -SQLAlchemy==2.0.23 -geoalchemy2==0.14.7 - -# Google specific packages for this function -cloudevents~=1.10.1 -google-cloud-storage -google-cloud-workflows - -# Configuration -python-dotenv==1.0.0 \ No newline at end of file diff --git a/functions-python/update_validation_report/requirements_dev.txt b/functions-python/update_validation_report/requirements_dev.txt deleted file mode 100644 index 9ee50adce..000000000 --- a/functions-python/update_validation_report/requirements_dev.txt +++ /dev/null @@ 
-1,2 +0,0 @@ -Faker -pytest~=7.4.3 \ No newline at end of file diff --git a/functions-python/update_validation_report/src/main.py b/functions-python/update_validation_report/src/main.py deleted file mode 100644 index b914d4d90..000000000 --- a/functions-python/update_validation_report/src/main.py +++ /dev/null @@ -1,180 +0,0 @@ -# -# MobilityData 2024 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import logging -import os -from typing import List - -import flask -import functions_framework -import requests -import sqlalchemy.orm -from sqlalchemy import or_ -from sqlalchemy.orm import Session -from google.cloud import storage -from sqlalchemy.engine import Row -from sqlalchemy.engine.interfaces import Any - -from shared.database_gen.sqlacodegen_models import ( - Gtfsdataset, - Gtfsfeed, - Validationreport, -) -from shared.helpers.gtfs_validator_common import get_gtfs_validator_results_bucket -from shared.database.database import with_db_session - -from shared.helpers.logger import init_logger -from shared.helpers.validation_report.validation_report_update import execute_workflows - -init_logger() -env = os.getenv("ENV", "dev").lower() -bucket_name = f"mobilitydata-datasets-{env}" - - -@with_db_session -@functions_framework.http -def update_validation_report(request: flask.Request, db_session: Session): - """ - Update the validation report for the datasets that need it - """ - request_json = request.get_json() - validator_endpoint = request_json.get( - 
"validator_endpoint", os.getenv("WEB_VALIDATOR_URL") - ) - bypass_db_update = validator_endpoint != os.getenv("WEB_VALIDATOR_URL") - force_update = request_json.get("force_update", False) - - # Check if the environment parameter is valid and set the reports bucket name - env_param = request_json.get("env", None) - reports_bucket_name = None - if env_param: - if env_param.lower() not in ["staging", "prod"]: - return { - "message": "Invalid environment parameter. Allowed values: staging, prod" - }, 400 - reports_bucket_name = get_gtfs_validator_results_bucket( - env_param.lower() == "prod" - ) - - # Get validator version - validator_version = get_validator_version(validator_endpoint) - logging.info(f"Accessing bucket {bucket_name}") - - latest_datasets = get_latest_datasets_without_validation_reports( - db_session, validator_version, force_update - ) - logging.info("Retrieved %s latest datasets.", len(latest_datasets)) - - valid_latest_datasets = get_datasets_for_validation(latest_datasets) - logging.info("Retrieved %s blobs to update.", len(latest_datasets)) - - execution_triggered_datasets = execute_workflows( - valid_latest_datasets, validator_endpoint, bypass_db_update, reports_bucket_name - ) - response = { - "message": f"Validation report update needed for {len(valid_latest_datasets)} datasets and triggered for " - f"{len(execution_triggered_datasets)} datasets.", - "dataset_workflow_triggered": sorted(execution_triggered_datasets), - "datasets_not_updated": sorted( - [ - dataset_id - for _, dataset_id in valid_latest_datasets - if dataset_id not in execution_triggered_datasets - ] - ), - "ignored_datasets": sorted( - [ - dataset_id - for _, dataset_id in latest_datasets - if dataset_id not in valid_latest_datasets - ] - ), - } - return response, 200 - - -def get_validator_version(validator_url: str) -> str: - """ - Get the version of the validator - :param validator_url: The URL of the validator - :return: the version of the validator - """ - response = 
requests.get(f"{validator_url}/version") - validator_version = response.json()["version"] - logging.info("Validator version: %s", validator_version) - return validator_version - - -def get_latest_datasets_without_validation_reports( - session: sqlalchemy.orm.Session, - validator_version: str, - force_update: bool = False, -) -> List[Row[tuple[Any, Any]]]: - """ - Retrieve the latest datasets for each feed that do not have a validation report - :param session: The database session - :param validator_version: The version of the validator - :param force_update: Whether to force the update of the validation report - :return: A list of tuples containing the feed stable id and dataset stable id - """ - query = ( - session.query( - Gtfsfeed.stable_id, - Gtfsdataset.stable_id, - ) - .select_from(Gtfsfeed) - .join(Gtfsdataset, Gtfsfeed.latest_dataset_id == Gtfsdataset.id) - .outerjoin(Validationreport, Gtfsdataset.validation_reports) - .filter( - or_( - Validationreport.validator_version != validator_version, - Validationreport.id.is_(None), - force_update, - ) - ) - .distinct(Gtfsfeed.stable_id, Gtfsdataset.stable_id) - ) - return query.all() - - -def get_datasets_for_validation( - latest_datasets: List[Row[tuple[Any, Any]]], -) -> List[tuple[str, str]]: - """ - Get the valid dataset blobs that need their validation report to be updated - :param latest_datasets: List of tuples containing the feed stable id and dataset stable id - :return: List of tuples containing the feed stable id and dataset stable id - """ - report_update_needed = [] - storage_client = storage.Client() - bucket = storage_client.bucket(bucket_name) - - for feed_id, dataset_id in latest_datasets: - try: - dataset_blob = bucket.blob(f"{feed_id}/{dataset_id}/{dataset_id}.zip") - if not dataset_blob.exists(): - logging.warning(f"Dataset blob not found for {feed_id}/{dataset_id}") - else: - report_update_needed.append((feed_id, dataset_id)) - logging.info( - "Dataset blob found for %s/%s -- Adding to update 
list", - feed_id, - dataset_id, - ) - except Exception as e: - logging.error( - f"Error while accessing dataset blob for {feed_id}/{dataset_id}: {e}" - ) - return report_update_needed diff --git a/functions-python/update_validation_report/tests/test_update_validation_report.py b/functions-python/update_validation_report/tests/test_update_validation_report.py deleted file mode 100644 index 94bbbdac2..000000000 --- a/functions-python/update_validation_report/tests/test_update_validation_report.py +++ /dev/null @@ -1,128 +0,0 @@ -import os -import unittest -from unittest import mock -from unittest.mock import MagicMock, patch, Mock - -from faker import Faker -from google.cloud import storage - -from test_shared.test_utils.database_utils import default_db_url -from main import ( - get_latest_datasets_without_validation_reports, - get_datasets_for_validation, - update_validation_report, -) - -faker = Faker() - - -def _create_storage_blob(name, metadata): - """Create a mock storage blob.""" - blob = MagicMock(spec=storage.Blob) - blob.metadata = metadata - blob.name = name - blob.patch = Mock(return_value=None) - return blob - - -class TestUpdateReportProcessor(unittest.TestCase): - def test_get_latest_datasets(self): - """Test get_latest_datasets function.""" - session = MagicMock() - session.query.return_value.filter.return_value.all = MagicMock() - get_latest_datasets_without_validation_reports(session, "1.0.1") - session.query.assert_called_once() - - @patch("google.cloud.storage.Client") - def test_get_datasets_for_validation(self, mock_client): - """Test get_datasets_for_validation function""" - test_dataset_id = "dataset1" - test_feed_id = "feed1" - - def create_dataset_blob(name, exists): - mock_dataset_blob = Mock(spec=storage.Blob) - mock_dataset_blob.exists.return_value = exists - mock_dataset_blob.name = name - return mock_dataset_blob - - # Setup mock storage client and bucket - mock_bucket = Mock() - mock_client.return_value.bucket.return_value = 
mock_bucket - - # Setup mock blobs and existence results - mock_dataset_blob_exists = create_dataset_blob( - f"{test_feed_id}/{test_dataset_id}/{test_dataset_id}.zip", True - ) - mock_dataset_blob_not_exists = create_dataset_blob( - f"{test_feed_id}/{test_dataset_id}1/{test_dataset_id}1.zip", False - ) - - mock_bucket.blob.side_effect = lambda name: { - f"{test_feed_id}/{test_dataset_id}/{test_dataset_id}.zip": mock_dataset_blob_exists, - f"{test_feed_id}/{test_dataset_id}1/{test_dataset_id}1.zip": mock_dataset_blob_not_exists, - }[name] - - # Input parameters - nonexistent_dataset = (test_feed_id, f"{test_dataset_id}2") - latest_datasets = [ - (test_feed_id, test_dataset_id), - (test_feed_id, f"{test_dataset_id}1"), - nonexistent_dataset, - ] - - result = get_datasets_for_validation(latest_datasets) - - # Assertions - self.assertEqual(len(result), 1) - mock_dataset_blob_exists.exists.assert_called_once() - mock_dataset_blob_not_exists.exists.assert_called_once() - # Only the existing dataset should be returned - self.assertEqual(result[0][0], test_feed_id) - self.assertEqual(result[0][1], test_dataset_id) - - @mock.patch.dict( - os.environ, - { - "FEEDS_DATABASE_URL": default_db_url, - "WEB_VALIDATOR_URL": faker.url(), - "MAX_RETRY": "2", - "BATCH_SIZE": "2", - "SLEEP_TIME": "0", - }, - ) - @patch( - "main.get_latest_datasets_without_validation_reports", - autospec=True, - return_value=[("feed1", "dataset1")], - ) - @patch( - "main.get_datasets_for_validation", - autospec=True, - return_value=[("feed1", "dataset1")], - ) - @patch("google.cloud.storage.Blob", autospec=True) - @patch("requests.get", autospec=True) - @patch("google.cloud.storage.Client", autospec=True) - @patch("google.cloud.workflows_v1.WorkflowsClient", autospec=True) - @patch("google.cloud.workflows.executions_v1.ExecutionsClient", autospec=True) - @patch("google.cloud.workflows.executions_v1.Execution", autospec=True) - def test_update_validation_report( - self, - execution_mock, - 
executions_client_mock, - workflows_client_mock, - mock_client, - mock_get, - mock_blob, - mock_get_latest_datasets, - mock_get_datasets_for_validation, - ): - """Test update_validation_report function.""" - mock_get.return_value.json.return_value = {"version": "1.0.1"} - mock_request = MagicMock() - mock_request.get_json.return_value = {"validator_url": faker.url()} - response = update_validation_report(mock_request) - self.assertTrue("message" in response[0]) - self.assertTrue("dataset_workflow_triggered" in response[0]) - self.assertEqual(response[1], 200) - self.assertEqual(response[0]["dataset_workflow_triggered"], ["dataset1"]) diff --git a/functions-python/validation_to_ndjson/src/validation_report_converter.py b/functions-python/validation_to_ndjson/src/validation_report_converter.py index d5ccec8f9..9bf418f46 100644 --- a/functions-python/validation_to_ndjson/src/validation_report_converter.py +++ b/functions-python/validation_to_ndjson/src/validation_report_converter.py @@ -89,7 +89,7 @@ def process(self) -> None: # Convert the JSON data to a single NDJSON record (one line) storage_client = storage.Client(project=project_id) - bucket = storage_client.get_bucket(bucket_name) + bucket = storage_client.bucket(bucket_name) ndjson_content = json.dumps(json_data, separators=(",", ":")) ndjson_blob = bucket.blob(ndjson_blob_name) ndjson_blob.upload_from_string(ndjson_content + "\n") diff --git a/functions-python/validation_to_ndjson/tests/test_converter.py b/functions-python/validation_to_ndjson/tests/test_converter.py index 2d6c5c468..2336aaed6 100644 --- a/functions-python/validation_to_ndjson/tests/test_converter.py +++ b/functions-python/validation_to_ndjson/tests/test_converter.py @@ -89,7 +89,7 @@ def test_process_gtfs_report(self, mock_storage_client, mock_filter_json_by_sche "notices": [{"sampleNotices": '[{"id":1},{"id":2}]'}], } mock_bucket = MagicMock() - mock_storage_client().get_bucket.return_value = mock_bucket + 
mock_storage_client().bucket.return_value = mock_bucket self.gtfs_converter.process() @@ -120,7 +120,7 @@ def test_process_gbfs_report(self, mock_storage_client, mock_filter_json_by_sche "notices": [{"sampleNotices": '[{"id":1},{"id":2}]'}], } mock_bucket = MagicMock() - mock_storage_client().get_bucket.return_value = mock_bucket + mock_storage_client().bucket.return_value = mock_bucket self.gbfs_converter.process() diff --git a/infra/functions-python/main.tf b/infra/functions-python/main.tf index 8df12c216..129226320 100644 --- a/infra/functions-python/main.tf +++ b/infra/functions-python/main.tf @@ -39,9 +39,6 @@ locals { function_process_validation_report_zip = "${path.module}/../../functions-python/process_validation_report/.dist/process_validation_report.zip" public_hosted_datasets_url = lower(var.environment) == "prod" ? "https://${var.public_hosted_datasets_dns}" : "https://${var.environment}-${var.public_hosted_datasets_dns}" - function_update_validation_report_config = jsondecode(file("${path.module}/../../functions-python/update_validation_report/function_config.json")) - function_update_validation_report_zip = "${path.module}/../../functions-python/update_validation_report/.dist/update_validation_report.zip" - function_gbfs_validation_report_config = jsondecode(file("${path.module}/../../functions-python/gbfs_validator/function_config.json")) function_gbfs_validation_report_zip = "${path.module}/../../functions-python/gbfs_validator/.dist/gbfs_validator.zip" @@ -75,7 +72,6 @@ locals { local.function_tokens_config.secret_environment_variables, local.function_process_validation_report_config.secret_environment_variables, local.function_gbfs_validation_report_config.secret_environment_variables, - local.function_update_validation_report_config.secret_environment_variables, local.function_backfill_dataset_service_date_range_config.secret_environment_variables, local.function_update_feed_status_config.secret_environment_variables, 
local.function_export_csv_config.secret_environment_variables, @@ -160,14 +156,7 @@ resource "google_storage_bucket_object" "process_validation_report_zip" { source = local.function_process_validation_report_zip } -# 4. Update validation report -resource "google_storage_bucket_object" "update_validation_report_zip" { - bucket = google_storage_bucket.functions_bucket.name - name = "update-validation-report-${substr(filebase64sha256(local.function_update_validation_report_zip), 0, 10)}.zip" - source = local.function_update_validation_report_zip -} - -# 5. GBFS validation report +# 4. GBFS validation report (was previously #5) resource "google_storage_bucket_object" "gbfs_validation_report_zip" { bucket = google_storage_bucket.functions_bucket.name name = "gbfs-validator-${substr(filebase64sha256(local.function_gbfs_validation_report_zip), 0, 10)}.zip" @@ -427,55 +416,7 @@ resource "google_cloudfunctions2_function" "compute_validation_report_counters" } } -# 4. functions/update_validation_report cloud function -resource "google_cloudfunctions2_function" "update_validation_report" { - location = var.gcp_region - name = local.function_update_validation_report_config.name - description = local.function_update_validation_report_config.description - depends_on = [google_secret_manager_secret_iam_member.secret_iam_member] - project = var.project_id - build_config { - runtime = var.python_runtime - entry_point = local.function_update_validation_report_config.entry_point - source { - storage_source { - bucket = google_storage_bucket.functions_bucket.name - object = google_storage_bucket_object.update_validation_report_zip.name - } - } - } - service_config { - available_memory = local.function_update_validation_report_config.memory - available_cpu = local.function_update_validation_report_config.available_cpu - timeout_seconds = local.function_update_validation_report_config.timeout - vpc_connector = data.google_vpc_access_connector.vpc_connector.id - 
vpc_connector_egress_settings = "PRIVATE_RANGES_ONLY" - - environment_variables = { - ENV = var.environment - MAX_RETRY = 10 - BATCH_SIZE = 5 - WEB_VALIDATOR_URL = var.validator_endpoint - # prevents multiline logs from being truncated on GCP console - PYTHONNODEBUGRANGES = 0 - } - dynamic "secret_environment_variables" { - for_each = local.function_update_validation_report_config.secret_environment_variables - content { - key = secret_environment_variables.value["key"] - project_id = var.project_id - secret = lookup(secret_environment_variables.value, "secret", "${upper(var.environment)}_${secret_environment_variables.value["key"]}") - version = "latest" - } - } - service_account_email = google_service_account.functions_service_account.email - max_instance_request_concurrency = local.function_update_validation_report_config.max_instance_request_concurrency - max_instance_count = local.function_update_validation_report_config.max_instance_count - min_instance_count = local.function_update_validation_report_config.min_instance_count - } -} - -# 5. functions/gbfs_validator cloud function +# 4. 
functions/gbfs_validator cloud function # 5.1 Create Pub/Sub topic resource "google_pubsub_topic" "validate_gbfs_feed" { name = "validate-gbfs-feed" @@ -1150,6 +1091,7 @@ resource "google_cloudfunctions2_function" "tasks_executor" { DATASETS_BUCKET_NAME = "${var.datasets_bucket_name}-${var.environment}" GBFS_SNAPSHOTS_BUCKET_NAME = google_storage_bucket.gbfs_snapshots_bucket.name PMTILES_BUILDER_QUEUE = google_cloud_tasks_queue.pmtiles_builder_task_queue.name + TASK_RUN_SYNC_QUEUE = google_cloud_tasks_queue.task_run_sync_queue.name SERVICE_ACCOUNT_EMAIL = google_service_account.functions_service_account.email GCP_REGION = var.gcp_region TDG_API_TOKEN = var.tdg_api_token @@ -1371,43 +1313,43 @@ output "function_tokens_name" { } -# Task queue to invoke update_validation_report function -resource "google_cloud_tasks_queue" "update_validation_report_task_queue" { +# Task queue to invoke refresh_materialized_view function +resource "google_cloud_tasks_queue" "refresh_materialized_view_task_queue" { project = var.project_id location = var.gcp_region - name = "update-validation-report-task-queue" + name = "refresh-materialized-view-task-queue-${var.environment}-${local.deployment_timestamp}" rate_limits { max_concurrent_dispatches = 1 - max_dispatches_per_second = 1 + max_dispatches_per_second = 0.5 } retry_config { - # This will make the cloud task retry for ~1 hour - max_attempts = 31 + # ~22 minutes total: 120 + 240 + 480 + 480 = 1320s (initial attempt + 4 retries) + max_attempts = 5 min_backoff = "120s" - max_backoff = "120s" + max_backoff = "480s" max_doublings = 2 } } -# Task queue to invoke refresh_materialized_view function -resource "google_cloud_tasks_queue" "refresh_materialized_view_task_queue" { +# Queue for Cloud-Tasks-driven task run status sync +# Retries every 10 minutes (constant interval) until the handler returns 200. 
+resource "google_cloud_tasks_queue" "task_run_sync_queue" { project = var.project_id location = var.gcp_region - name = "refresh-materialized-view-task-queue-${var.environment}-${local.deployment_timestamp}" + name = "task-run-sync-queue-${var.environment}-${local.deployment_timestamp}" rate_limits { - max_concurrent_dispatches = 1 - max_dispatches_per_second = 0.5 + max_concurrent_dispatches = 5 + max_dispatches_per_second = 1 } retry_config { - # ~22 minutes total: 120 + 240 + 480 + 480 = 1320s (initial attempt + 4 retries) - max_attempts = 5 - min_backoff = "120s" - max_backoff = "480s" - max_doublings = 2 + max_attempts = 100 + min_backoff = "600s" + max_backoff = "600s" + max_doublings = 0 } } diff --git a/liquibase/changelog.xml b/liquibase/changelog.xml index 6b97e1cbf..3175e0773 100644 --- a/liquibase/changelog.xml +++ b/liquibase/changelog.xml @@ -105,6 +105,8 @@ + + diff --git a/liquibase/changes/feat_task_execution_log.sql b/liquibase/changes/feat_task_execution_log.sql new file mode 100644 index 000000000..ae695650a --- /dev/null +++ b/liquibase/changes/feat_task_execution_log.sql @@ -0,0 +1,39 @@ +-- Generic task run tracking table. +-- One record per orchestration run (process level). +-- Mirrors the BatchExecution concept from DatasetTraceService (Datastore), +-- allowing future migration of batch_process_dataset, batch_datasets, gbfs_validator. +CREATE TABLE IF NOT EXISTS task_run ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + task_name VARCHAR NOT NULL, + run_id VARCHAR NOT NULL, + status VARCHAR NOT NULL, + total_count INTEGER, + params JSONB, + created_at TIMESTAMP DEFAULT NOW(), + completed_at TIMESTAMP, + UNIQUE (task_name, run_id) +); + +-- Generic task execution log table. +-- One record per entity/workflow execution within a run. +-- entity_id is nullable for tasks that do not operate on a specific entity. +-- Mirrors the DatasetTrace concept from DatasetTraceService (Datastore). 
+CREATE TABLE IF NOT EXISTS task_execution_log ( + id SERIAL PRIMARY KEY, + task_run_id UUID REFERENCES task_run(id) ON DELETE CASCADE, + task_name VARCHAR NOT NULL, + entity_id VARCHAR, + run_id VARCHAR NOT NULL, + status VARCHAR NOT NULL, + execution_ref VARCHAR, + error_message TEXT, + metadata JSONB, + triggered_at TIMESTAMP DEFAULT NOW(), + completed_at TIMESTAMP, + UNIQUE (task_name, entity_id, run_id) +); + +CREATE INDEX IF NOT EXISTS ix_task_run_task_name_run_id ON task_run (task_name, run_id); +CREATE INDEX IF NOT EXISTS ix_task_execution_log_task_run_id ON task_execution_log (task_run_id); +CREATE INDEX IF NOT EXISTS ix_task_execution_log_task_name_entity_run ON task_execution_log (task_name, entity_id, run_id); +CREATE INDEX IF NOT EXISTS ix_task_execution_log_status ON task_execution_log (status);