diff --git a/socket_basics/core/notification/github_pr_notifier.py b/socket_basics/core/notification/github_pr_notifier.py index 555d03e..8c258bc 100644 --- a/socket_basics/core/notification/github_pr_notifier.py +++ b/socket_basics/core/notification/github_pr_notifier.py @@ -100,6 +100,18 @@ def notify(self, facts: Dict[str, Any]) -> None: # Update existing comments with new section content for comment_id, updated_body in comment_updates.items(): + # Detect whether content actually changed before making the API call + original_body = next( + (c.get('body', '') for c in existing_comments if c.get('id') == comment_id), + '', + ) + if original_body == updated_body: + logger.info( + 'GithubPRNotifier: comment %s content unchanged; skipping update', + comment_id, + ) + continue + success = self._update_comment(pr_number, comment_id, updated_body) if success: logger.info('GithubPRNotifier: updated existing comment %s', comment_id) diff --git a/socket_basics/core/triage.py b/socket_basics/core/triage.py new file mode 100644 index 0000000..59bf6ce --- /dev/null +++ b/socket_basics/core/triage.py @@ -0,0 +1,269 @@ +"""Triage filtering for Socket Security Basics. + +Streams the full scan from the Socket API to obtain alert keys, fetches +triage entries, and filters local scan components whose alerts have been +triaged (state: ignore or monitor). +""" + +import logging +from typing import Any, Dict, List, Set, Tuple + +logger = logging.getLogger(__name__) + +# Triage states that cause a finding to be removed from reports +_SUPPRESSED_STATES = {"ignore", "monitor"} + + +# ------------------------------------------------------------------ +# API helpers +# ------------------------------------------------------------------ + +def fetch_triage_data(sdk: Any, org_slug: str) -> List[Dict[str, Any]]: + """Fetch all triage alert entries from the Socket API, handling pagination. + + Args: + sdk: Initialized socketdev SDK instance. + org_slug: Organization slug for the API call. + + Returns: + List of triage entry dicts. + """ + all_entries: List[Dict[str, Any]] = [] + page = 1 + per_page = 100 + + while True: + try: + response = sdk.triage.list_alert_triage( + org_slug, + {"per_page": per_page, "page": page}, + ) + except Exception as exc: + # Handle insufficient permissions gracefully so the scan + # continues without triage filtering. + exc_name = type(exc).__name__ + if "AccessDenied" in exc_name or "Forbidden" in exc_name: + logger.info( + "Triage API access denied (insufficient permissions). " + "Skipping triage filtering for this run." + ) + else: + logger.warning("Failed to fetch triage data (page %d): %s", page, exc) + break + + if not isinstance(response, dict): + logger.warning("Unexpected triage API response type: %s", type(response)) + break + + results = response.get("results") or [] + all_entries.extend(results) + + next_page = response.get("nextPage") + if next_page is None: + break + page = int(next_page) + + logger.debug("Fetched %d triage entries for org %s", len(all_entries), org_slug) + return all_entries + + +def stream_full_scan_alerts( + sdk: Any, org_slug: str, full_scan_id: str +) -> Dict[str, List[Dict[str, Any]]]: + """Stream a full scan and extract alert keys grouped by artifact. + + Returns: + Mapping of artifact ID to list of alert dicts. Each alert dict + contains at minimum ``key`` and ``type``. The artifact metadata + (name, version, type, etc.) is included under a ``_artifact`` key + in every alert dict for downstream matching. + """ + try: + # use_types=False returns a plain dict keyed by artifact ID + response = sdk.fullscans.stream(org_slug, full_scan_id, use_types=False) + except Exception as exc: + exc_name = type(exc).__name__ + if "AccessDenied" in exc_name or "Forbidden" in exc_name: + logger.info( + "Full scan stream access denied (insufficient permissions). " + "Skipping triage filtering for this run." + ) + else: + logger.warning("Failed to stream full scan %s: %s", full_scan_id, exc) + return {} + + if not isinstance(response, dict): + logger.warning("Unexpected full scan stream response type: %s", type(response)) + return {} + + artifact_alerts: Dict[str, List[Dict[str, Any]]] = {} + for artifact_id, artifact in response.items(): + if not isinstance(artifact, dict): + continue + alerts = artifact.get("alerts") or [] + if not alerts: + continue + meta = { + "artifact_id": artifact_id, + "artifact_name": artifact.get("name"), + "artifact_version": artifact.get("version"), + "artifact_type": artifact.get("type"), + "artifact_namespace": artifact.get("namespace"), + "artifact_subpath": artifact.get("subPath") or artifact.get("subpath"), + } + enriched = [] + for a in alerts: + if isinstance(a, dict) and a.get("key"): + enriched.append({**a, "_artifact": meta}) + if enriched: + artifact_alerts[artifact_id] = enriched + + total_alerts = sum(len(v) for v in artifact_alerts.values()) + logger.debug( + "Streamed full scan %s: %d artifact(s), %d alert(s) with keys", + full_scan_id, + len(artifact_alerts), + total_alerts, + ) + return artifact_alerts + + +# ------------------------------------------------------------------ +# TriageFilter +# ------------------------------------------------------------------ + +class TriageFilter: + """Cross-references Socket alert keys against triage entries and + maps triaged alerts back to local scan components.""" + + def __init__( + self, + triage_entries: List[Dict[str, Any]], + artifact_alerts: Dict[str, List[Dict[str, Any]]], + ) -> None: + # Build set of suppressed alert keys + self.triaged_keys: Set[str] = set() + for entry in triage_entries: + state = (entry.get("state") or "").lower() + key = entry.get("alert_key") + if state in _SUPPRESSED_STATES and key: + self.triaged_keys.add(key) + + # Flatten all Socket alerts for lookup + self._socket_alerts: List[Dict[str, Any]] = [] + for alerts in artifact_alerts.values(): + self._socket_alerts.extend(alerts) + + # Build a mapping from (artifact_id, alert_type) to triaged status + # for fast lookups when matching against local components + self._triaged_by_artifact: Dict[str, Set[str]] = {} + for alert in self._socket_alerts: + if alert.get("key") in self.triaged_keys: + art_id = alert.get("_artifact", {}).get("artifact_id", "") + alert_type = alert.get("type") or "" + self._triaged_by_artifact.setdefault(art_id, set()).add(alert_type) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def filter_components( + self, components: List[Dict[str, Any]] + ) -> Tuple[List[Dict[str, Any]], int]: + """Remove triaged alerts from local components. + + Matches local components to Socket artifacts by component ID, then + checks each local alert against the set of triaged alert types for + that artifact. + + Returns: + (filtered_components, triaged_count) + """ + if not self.triaged_keys: + return components, 0 + + # Build lookup: component id -> set of triaged Socket alert types + triaged_types_by_component = self._map_components_to_triaged_types(components) + + if not triaged_types_by_component: + logger.debug( + "No local components matched Socket artifacts with triaged alerts" + ) + return components, 0 + + filtered: List[Dict[str, Any]] = [] + triaged_count = 0 + + for comp in components: + comp_id = comp.get("id") or "" + triaged_types = triaged_types_by_component.get(comp_id) + + if triaged_types is None: + # Component had no triaged alerts; keep as-is + filtered.append(comp) + continue + + remaining_alerts: List[Dict[str, Any]] = [] + for alert in comp.get("alerts", []): + if self._local_alert_is_triaged(alert, triaged_types): + triaged_count += 1 + else: + remaining_alerts.append(alert) + + if remaining_alerts: + new_comp = dict(comp) + new_comp["alerts"] = remaining_alerts + filtered.append(new_comp) + + return filtered, triaged_count + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _map_components_to_triaged_types( + self, components: List[Dict[str, Any]] + ) -> Dict[str, Set[str]]: + """Map local component IDs to the set of triaged Socket alert types. + + Matches by component ``id`` (which is typically a hash that Socket + also uses as the artifact ID). + """ + local_ids = {comp.get("id") for comp in components if comp.get("id")} + result: Dict[str, Set[str]] = {} + for comp_id in local_ids: + triaged = self._triaged_by_artifact.get(comp_id) + if triaged: + result[comp_id] = triaged + return result + + @staticmethod + def _local_alert_is_triaged( + alert: Dict[str, Any], triaged_types: Set[str] + ) -> bool: + """Check if a local alert matches any of the triaged Socket alert types. + + Socket alert ``type`` values (e.g. ``badEncoding``, ``cve``) are + compared against the local alert's ``type`` field. When the local + alert type is too generic (``"generic"`` or ``"vulnerability"``), + we fall back to matching on ``title``, ``props.ruleId``, or + ``props.vulnerabilityId``. + """ + # Direct type match + local_type = alert.get("type") or "" + if local_type and local_type not in ("generic", "vulnerability"): + return local_type in triaged_types + + # Fallback: match candidate fields against triaged types + props = alert.get("props") or {} + candidates = { + v for v in ( + alert.get("title"), + props.get("ruleId"), + props.get("detectorName"), + props.get("vulnerabilityId"), + props.get("cveId"), + ) + if v + } + return bool(candidates & triaged_types) diff --git a/socket_basics/socket_basics.py b/socket_basics/socket_basics.py index a7f7f04..5ca722e 100644 --- a/socket_basics/socket_basics.py +++ b/socket_basics/socket_basics.py @@ -17,7 +17,7 @@ import sys import os from pathlib import Path -from typing import Dict, Any, Optional +from typing import Dict, Any, List, Optional import hashlib try: # Python 3.11+ @@ -356,9 +356,10 @@ def submit_socket_facts(self, socket_facts_path: Path, results: Dict[str, Any]) logger.error(f"Error creating full scan: {error_msg}") raise Exception(f"Error creating full scan: {error_msg}") - # Extract the scan ID and HTML URL from the response - scan_id = getattr(res, 'id', None) - html_url = getattr(res, 'html_url', None) + # SDK CreateFullScanResponse nests metadata under .data + data = getattr(res, 'data', None) or res + scan_id = getattr(data, 'id', None) + html_url = getattr(data, 'html_report_url', None) or getattr(data, 'html_url', None) logger.debug(f"Extracted from object: scan_id={scan_id}, html_url={html_url}") if scan_id: @@ -378,6 +379,283 @@ def submit_socket_facts(self, socket_facts_path: Path, results: Dict[str, Any]) return results + def apply_triage_filter(self, results: Dict[str, Any]) -> Dict[str, Any]: + """Filter out triaged alerts and regenerate notifications. + + Streams the full scan from the Socket API to obtain alert keys, + cross-references them with triage entries, removes suppressed + alerts from local components, regenerates connector notifications, + and injects a triage summary into github_pr content. + + Args: + results: Current scan results dict (components + notifications). + + Returns: + Updated results dict with triaged findings removed. + """ + socket_api_key = self.config.get('socket_api_key') + socket_org = self.config.get('socket_org') + full_scan_id = results.get('full_scan_id') + + if not socket_api_key or not socket_org: + logger.debug("Skipping triage filter: missing socket_api_key or socket_org") + return results + + if not full_scan_id: + logger.debug("Skipping triage filter: no full_scan_id in results") + return results + + # Import SDK and triage helpers + try: + from socketdev import socketdev + except ImportError: + logger.debug("socketdev SDK not available; skipping triage filter") + return results + + try: + from .core.triage import TriageFilter, fetch_triage_data, stream_full_scan_alerts + except ImportError: + from socket_basics.core.triage import TriageFilter, fetch_triage_data, stream_full_scan_alerts + + sdk = socketdev(token=socket_api_key, timeout=100) + + # Fetch triage entries and stream full scan alert keys in sequence + triage_entries = fetch_triage_data(sdk, socket_org) + if not triage_entries: + logger.info("No triage entries found; skipping filter") + return results + + suppressed_count = sum( + 1 for e in triage_entries + if (e.get("state") or "").lower() in ("ignore", "monitor") + ) + logger.info( + "Fetched %d triage entries (%d with suppressed state)", + len(triage_entries), + suppressed_count, + ) + + artifact_alerts = stream_full_scan_alerts(sdk, socket_org, full_scan_id) + if not artifact_alerts: + logger.info("No alert keys returned from full scan stream; skipping filter") + return results + + triage_filter = TriageFilter(triage_entries, artifact_alerts) + original_components = results.get('components', []) + original_alert_count = sum( + len(c.get('alerts', [])) for c in original_components + ) + filtered_components, triaged_count = triage_filter.filter_components( + original_components + ) + + if triaged_count == 0: + logger.info( + "Triage filter matched 0 of %d finding(s); no changes applied", + original_alert_count, + ) + return results + + remaining_alert_count = sum( + len(c.get('alerts', [])) for c in filtered_components + ) + logger.info( + "Triage filter removed %d finding(s); %d finding(s) remain", + triaged_count, + remaining_alert_count, + ) + results['components'] = filtered_components + results['triaged_count'] = triaged_count + + # Regenerate notifications from the filtered components + self._regenerate_notifications(results, filtered_components, triaged_count) + + return results + + def _regenerate_notifications( + self, + results: Dict[str, Any], + filtered_components: List[Dict[str, Any]], + triaged_count: int, + ) -> None: + """Regenerate connector notifications from filtered components. + + Groups components by their connector origin (via the ``generatedBy`` + field on alerts), calls each connector's ``generate_notifications``, + merges the results, and injects a triage summary into github_pr + content. + + Always replaces ``results['notifications']`` so stale pre-filter + notifications are never forwarded to notifiers. + """ + connector_components: Dict[str, List[Dict[str, Any]]] = {} + unmapped_count = 0 + for comp in filtered_components: + mapped = False + for alert in comp.get('alerts', []): + gen = alert.get('generatedBy') or '' + connector_name = self._connector_name_from_generated_by(gen) + if connector_name: + connector_components.setdefault(connector_name, []).append(comp) + mapped = True + break # one mapping per component is enough + if not mapped: + unmapped_count += 1 + + if unmapped_count: + logger.debug( + "Triage regen: %d component(s) could not be mapped to a connector", + unmapped_count, + ) + + logger.info( + "Regenerating notifications for %d connector(s): %s", + len(connector_components), + ", ".join(connector_components.keys()) or "(none)", + ) + + merged_notifications: Dict[str, list] = {} + + for connector_name, comps in connector_components.items(): + connector = self.connector_manager.loaded_connectors.get(connector_name) + if connector is None: + logger.warning( + "Connector %s not in loaded_connectors (available: %s); " + "cannot regenerate its notifications", + connector_name, + ", ".join(self.connector_manager.loaded_connectors.keys()), + ) + continue + + if not hasattr(connector, 'generate_notifications'): + logger.debug("Connector %s has no generate_notifications", connector_name) + continue + + try: + if connector_name == 'trivy': + item_name, scan_type = self._derive_trivy_params(comps) + notifs = connector.generate_notifications(comps, item_name, scan_type) + else: + notifs = connector.generate_notifications(comps) + except Exception: + logger.exception("Failed to regenerate notifications for %s", connector_name) + continue + + if not isinstance(notifs, dict): + continue + + notifier_keys = [k for k, v in notifs.items() if v] + logger.debug( + "Connector %s produced notifications for: %s", + connector_name, + ", ".join(notifier_keys) or "(empty)", + ) + + for notifier_key, payload in notifs.items(): + if notifier_key not in merged_notifications: + merged_notifications[notifier_key] = payload + elif isinstance(merged_notifications[notifier_key], list) and isinstance(payload, list): + merged_notifications[notifier_key].extend(payload) + + # Inject triage summary into github_pr notification content + full_scan_url = results.get('full_scan_html_url', '') + self._inject_triage_summary(merged_notifications, triaged_count, full_scan_url) + + # Always replace notifications so stale pre-filter content is never + # forwarded to notifiers. An empty dict is valid and means every + # finding was triaged. + results['notifications'] = merged_notifications + + @staticmethod + def _connector_name_from_generated_by(generated_by: str) -> str | None: + """Map a generatedBy value back to its connector name.""" + gb = generated_by.lower() + if gb.startswith('opengrep') or gb.startswith('sast'): + return 'opengrep' + if gb == 'trufflehog': + return 'trufflehog' + if gb.startswith('trivy'): + return 'trivy' + if gb == 'socket-tier1': + return 'socket_tier1' + return None + + def _derive_trivy_params( + self, components: List[Dict[str, Any]] + ) -> tuple: + """Derive item_name and scan_type for Trivy notification regeneration.""" + scan_type = 'image' + for comp in components: + for alert in comp.get('alerts', []): + props = alert.get('props') or {} + st = props.get('scanType', '') + if st: + scan_type = st + break + if scan_type != 'image': + break + + item_name = "Unknown" + images_str = ( + self.config.get('container_images', '') + or self.config.get('container_images_to_scan', '') + or self.config.get('docker_images', '') + ) + if images_str: + if isinstance(images_str, list): + item_name = images_str[0] if images_str else "Unknown" + else: + images = [img.strip() for img in str(images_str).split(',') if img.strip()] + item_name = images[0] if images else "Unknown" + else: + dockerfiles = self.config.get('dockerfiles', '') + if dockerfiles: + if isinstance(dockerfiles, list): + item_name = dockerfiles[0] if dockerfiles else "Unknown" + else: + docker_list = [df.strip() for df in str(dockerfiles).split(',') if df.strip()] + item_name = docker_list[0] if docker_list else "Unknown" + + if scan_type == 'vuln' and item_name == "Unknown": + try: + item_name = os.path.basename(str(self.config.workspace)) + except Exception: + item_name = "Workspace" + + return item_name, scan_type + + @staticmethod + def _inject_triage_summary( + notifications: Dict[str, list], + triaged_count: int, + full_scan_url: str, + ) -> None: + """Insert a triage summary line into github_pr notification content.""" + gh_items = notifications.get('github_pr') + if not gh_items or not isinstance(gh_items, list): + return + + dashboard_link = full_scan_url or "https://socket.dev/dashboard" + summary_line = ( + f"\n> :white_check_mark: **{triaged_count} finding(s) triaged** " + f"via [Socket Dashboard]({dashboard_link}) and removed from this report.\n" + ) + + for item in gh_items: + if not isinstance(item, dict) or 'content' not in item: + continue + content = item['content'] + # Insert after the first markdown heading line (# Title) + lines = content.split('\n') + insert_idx = 0 + for i, line in enumerate(lines): + if line.strip().startswith('# '): + insert_idx = i + 1 + break + lines.insert(insert_idx, summary_line) + item['content'] = '\n'.join(lines) + + def main(): """Main entry point""" parser = parse_cli_args() @@ -429,6 +707,12 @@ def main(): except Exception: logger.exception("Failed to submit socket facts file") + # Filter out triaged alerts before notifying + try: + results = scanner.apply_triage_filter(results) + except Exception: + logger.exception("Failed to apply triage filter") + # Optionally upload to S3 if requested try: enable_s3 = getattr(args, 'enable_s3_upload', False) or config.get('enable_s3_upload', False) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_triage.py b/tests/test_triage.py new file mode 100644 index 0000000..a05e6e5 --- /dev/null +++ b/tests/test_triage.py @@ -0,0 +1,592 @@ +"""Tests for socket_basics.core.triage module.""" + +import logging +import pytest +from socket_basics.core.triage import ( + TriageFilter, + fetch_triage_data, + stream_full_scan_alerts, +) + + +# --------------------------------------------------------------------------- +# Fixtures / helpers +# --------------------------------------------------------------------------- + +ARTIFACT_ID = "abc123" + + +def _make_component( + comp_id: str = ARTIFACT_ID, + name: str = "lodash", + comp_type: str = "npm", + version: str = "4.17.21", + alerts: list | None = None, +) -> dict: + return { + "id": comp_id, + "name": name, + "version": version, + "type": comp_type, + "qualifiers": {"ecosystem": comp_type, "version": version}, + "alerts": alerts or [], + } + + +def _make_local_alert( + title: str = "badEncoding", + alert_type: str = "badEncoding", + severity: str = "high", + rule_id: str | None = None, + detector_name: str | None = None, + cve_id: str | None = None, + generated_by: str = "opengrep-python", +) -> dict: + props: dict = {} + if rule_id: + props["ruleId"] = rule_id + if detector_name: + props["detectorName"] = detector_name + if cve_id: + props["cveId"] = cve_id + return { + "title": title, + "type": alert_type, + "severity": severity, + "generatedBy": generated_by, + "props": props, + } + + +def _make_triage_entry( + alert_key: str, + state: str = "ignore", +) -> dict: + return { + "uuid": "test-uuid", + "alert_key": alert_key, + "state": state, + "note": "", + "organization_id": "test-org", + } + + +def _make_artifact_alerts( + artifact_id: str = ARTIFACT_ID, + alerts: list[dict] | None = None, + name: str = "lodash", + version: str = "4.17.21", + pkg_type: str = "npm", +) -> dict[str, list[dict]]: + """Build an artifact_alerts mapping with enriched _artifact metadata.""" + meta = { + "artifact_id": artifact_id, + "artifact_name": name, + "artifact_version": version, + "artifact_type": pkg_type, + "artifact_namespace": None, + "artifact_subpath": None, + } + enriched = [{**a, "_artifact": meta} for a in (alerts or [])] + return {artifact_id: enriched} + + +def _socket_alert(key: str, alert_type: str) -> dict: + """Create a minimal Socket alert dict (as returned by the full scan stream).""" + return {"key": key, "type": alert_type} + + +# --------------------------------------------------------------------------- +# TriageFilter construction +# --------------------------------------------------------------------------- + +class TestTriageFilterInit: + def test_builds_triaged_keys_for_ignore(self): + entries = [_make_triage_entry("hash-1", state="ignore")] + artifact_alerts = _make_artifact_alerts( + alerts=[_socket_alert("hash-1", "badEncoding")] + ) + tf = TriageFilter(entries, artifact_alerts) + assert "hash-1" in tf.triaged_keys + + def test_builds_triaged_keys_for_monitor(self): + entries = [_make_triage_entry("hash-2", state="monitor")] + artifact_alerts = _make_artifact_alerts( + alerts=[_socket_alert("hash-2", "cve")] + ) + tf = TriageFilter(entries, artifact_alerts) + assert "hash-2" in tf.triaged_keys + + def test_excludes_block_warn_inherit_states(self): + entries = [ + _make_triage_entry("h1", state="block"), + _make_triage_entry("h2", state="warn"), + _make_triage_entry("h3", state="inherit"), + ] + artifact_alerts = _make_artifact_alerts( + alerts=[ + _socket_alert("h1", "a"), + _socket_alert("h2", "b"), + _socket_alert("h3", "c"), + ] + ) + tf = TriageFilter(entries, artifact_alerts) + assert tf.triaged_keys == set() + + def test_builds_triaged_by_artifact_mapping(self): + entries = [_make_triage_entry("hash-1", state="ignore")] + artifact_alerts = _make_artifact_alerts( + artifact_id="art-1", + alerts=[_socket_alert("hash-1", "badEncoding")], + ) + tf = TriageFilter(entries, artifact_alerts) + assert "art-1" in tf._triaged_by_artifact + assert "badEncoding" in tf._triaged_by_artifact["art-1"] + + def test_no_entries_means_empty_triaged_keys(self): + tf = TriageFilter([], {}) + assert tf.triaged_keys == set() + + def test_entry_without_alert_key_ignored(self): + entries = [{"state": "ignore", "alert_key": None}] + tf = TriageFilter(entries, {}) + assert tf.triaged_keys == set() + + +# --------------------------------------------------------------------------- +# TriageFilter._local_alert_is_triaged +# --------------------------------------------------------------------------- + +class TestLocalAlertIsTriaged: + def test_direct_type_match(self): + triaged_types = {"badEncoding"} + alert = _make_local_alert(alert_type="badEncoding") + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is True + + def test_direct_type_no_match(self): + triaged_types = {"badEncoding"} + alert = _make_local_alert(alert_type="cve") + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is False + + def test_generic_type_falls_back_to_title(self): + triaged_types = {"badEncoding"} + alert = _make_local_alert(title="badEncoding", alert_type="generic") + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is True + + def test_vulnerability_type_falls_back_to_cve(self): + triaged_types = {"CVE-2024-1234"} + alert = _make_local_alert( + title="Some Vuln", alert_type="vulnerability", cve_id="CVE-2024-1234" + ) + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is True + + def test_generic_type_falls_back_to_rule_id(self): + triaged_types = {"python.lang.security.audit.xss"} + alert = _make_local_alert( + title="XSS", alert_type="generic", + rule_id="python.lang.security.audit.xss", + ) + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is True + + def test_generic_type_falls_back_to_detector_name(self): + triaged_types = {"AWS"} + alert = _make_local_alert( + title="AWS Key", alert_type="generic", detector_name="AWS" + ) + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is True + + def test_no_fallback_candidates_returns_false(self): + triaged_types = {"something"} + alert = {"type": "generic", "props": {}} + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is False + + +# --------------------------------------------------------------------------- +# TriageFilter.filter_components +# --------------------------------------------------------------------------- + +class TestFilterComponents: + def test_removes_triaged_alert_by_type(self): + """Component ID matches artifact, triaged alert type matches local alert type.""" + entries = [_make_triage_entry("hash-1")] + artifact_alerts = _make_artifact_alerts( + alerts=[_socket_alert("hash-1", "badEncoding")] + ) + tf = TriageFilter(entries, artifact_alerts) + + comp = _make_component( + comp_id=ARTIFACT_ID, + alerts=[ + _make_local_alert(alert_type="badEncoding"), + _make_local_alert(title="kept", alert_type="otherIssue"), + ], + ) + filtered, count = tf.filter_components([comp]) + assert count == 1 + assert len(filtered) == 1 + assert len(filtered[0]["alerts"]) == 1 + assert filtered[0]["alerts"][0]["title"] == "kept" + + def test_removes_component_when_all_alerts_triaged(self): + entries = [_make_triage_entry("hash-1")] + artifact_alerts = _make_artifact_alerts( + alerts=[_socket_alert("hash-1", "badEncoding")] + ) + tf = TriageFilter(entries, artifact_alerts) + + comp = _make_component( + comp_id=ARTIFACT_ID, + alerts=[_make_local_alert(alert_type="badEncoding")], + ) + filtered, count = tf.filter_components([comp]) + assert count == 1 + assert len(filtered) == 0 + + def test_no_triage_entries_returns_original(self): + tf = TriageFilter([], {}) + comp = _make_component(alerts=[_make_local_alert()]) + filtered, count = tf.filter_components([comp]) + assert count == 0 + assert filtered == [comp] + + def test_component_id_mismatch_keeps_all_alerts(self): + """When local component ID doesn't match any artifact, nothing is filtered.""" + entries = [_make_triage_entry("hash-1")] + artifact_alerts = _make_artifact_alerts( + artifact_id="different-artifact", + alerts=[_socket_alert("hash-1", "badEncoding")], + ) + tf = TriageFilter(entries, artifact_alerts) + + comp = _make_component( + comp_id="unrelated-comp-id", + alerts=[_make_local_alert(alert_type="badEncoding")], + ) + filtered, count = tf.filter_components([comp]) + assert count == 0 + assert len(filtered) == 1 + + def test_multiple_components_mixed(self): + entries = [_make_triage_entry("hash-1")] + artifact_alerts = _make_artifact_alerts( + artifact_id="art-a", + alerts=[_socket_alert("hash-1", "badEncoding")], + ) + tf = TriageFilter(entries, artifact_alerts) + + comp1 = _make_component( + comp_id="art-a", name="a", + alerts=[_make_local_alert(alert_type="badEncoding")], + ) + comp2 = _make_component( + comp_id="art-b", name="b", + alerts=[_make_local_alert(alert_type="otherIssue")], + ) + comp3 = _make_component( + comp_id="art-a", name="c", + alerts=[ + _make_local_alert(alert_type="badEncoding"), + _make_local_alert(title="keepMe", alert_type="keepMe"), + ], + ) + + filtered, count = tf.filter_components([comp1, comp2, comp3]) + assert count == 2 + assert len(filtered) == 2 + names = [c["name"] for c in filtered] + assert "a" not in names + assert "b" in names + assert "c" in names + + def test_multiple_triaged_alert_types_on_same_artifact(self): + entries = [ + _make_triage_entry("hash-1", state="ignore"), + _make_triage_entry("hash-2", state="monitor"), + ] + artifact_alerts = _make_artifact_alerts( + alerts=[ + _socket_alert("hash-1", "badEncoding"), + _socket_alert("hash-2", "cve"), + ], + ) + tf = TriageFilter(entries, artifact_alerts) + + comp = _make_component( + comp_id=ARTIFACT_ID, + alerts=[ + _make_local_alert(alert_type="badEncoding"), + _make_local_alert(alert_type="cve"), + _make_local_alert(title="safe", alert_type="safe"), + ], + ) + filtered, count = tf.filter_components([comp]) + assert count == 2 + assert len(filtered[0]["alerts"]) == 1 + assert filtered[0]["alerts"][0]["type"] == "safe" + + +# --------------------------------------------------------------------------- +# stream_full_scan_alerts +# --------------------------------------------------------------------------- + +class TestStreamFullScanAlerts: + def test_parses_artifacts_and_alerts(self): + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + return { + "artifact-1": { + "name": "lodash", + "version": "4.17.21", + "type": "npm", + "namespace": None, + "alerts": [ + {"key": "hash-a", "type": "badEncoding"}, + {"key": "hash-b", "type": "cve"}, + ], + }, + "artifact-2": { + "name": "express", + "version": "4.18.0", + "type": "npm", + "namespace": None, + "alerts": [], + }, + } + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "my-org", "scan-123") + assert "artifact-1" in result + assert "artifact-2" not in result # empty alerts filtered out + assert len(result["artifact-1"]) == 2 + assert result["artifact-1"][0]["key"] == "hash-a" + assert result["artifact-1"][0]["_artifact"]["artifact_name"] == "lodash" + + def test_skips_alerts_without_key(self): + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + return { + "art-1": { + "name": "pkg", + "version": "1.0.0", + "type": "npm", + "alerts": [ + {"key": "hash-a", "type": "badEncoding"}, + {"type": "noKey"}, # missing key + {"key": "", "type": "emptyKey"}, # empty key + ], + }, + } + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "org", "scan") + assert len(result["art-1"]) == 1 + + def test_access_denied_returns_empty(self, caplog): + class APIAccessDenied(Exception): + pass + + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + raise APIAccessDenied("Forbidden") + + class FakeSDK: + fullscans = FakeFullscansAPI() + + with caplog.at_level(logging.DEBUG): + result = stream_full_scan_alerts(FakeSDK(), "org", "scan") + + assert result == {} + info_msgs = [r for r in caplog.records if r.levelno == logging.INFO] + assert any("access denied" in m.message.lower() for m in info_msgs) + + def test_api_error_returns_empty(self): + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + raise RuntimeError("Network failure") + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "org", "scan") + assert result == {} + + def test_non_dict_response_returns_empty(self): + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + return "unexpected string" + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "org", "scan") + assert result == {} + + def test_subpath_handling(self): + """Supports both camelCase and lowercase subpath field names.""" + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + return { + "art-1": { + "name": "pkg", + "version": "1.0", + "type": "npm", + "subPath": "src/lib", + "alerts": [{"key": "k1", "type": "t1"}], + }, + } + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "org", "scan") + assert result["art-1"][0]["_artifact"]["artifact_subpath"] == "src/lib" + + +# --------------------------------------------------------------------------- +# fetch_triage_data +# --------------------------------------------------------------------------- + +class TestFetchTriageData: + def test_single_page(self): + class FakeTriageAPI: + def list_alert_triage(self, org, params): + return {"results": [{"alert_key": "a", "state": "ignore"}], "nextPage": None} + + class FakeSDK: + triage = FakeTriageAPI() + + entries = fetch_triage_data(FakeSDK(), "my-org") + assert len(entries) == 1 + assert entries[0]["alert_key"] == "a" + + def test_pagination(self): + class FakeTriageAPI: + def __init__(self): + self.call_count = 0 + + def list_alert_triage(self, org, params): + self.call_count += 1 + if params.get("page") == 1: + return {"results": [{"alert_key": "a"}], "nextPage": 2} + return {"results": [{"alert_key": "b"}], "nextPage": None} + + class FakeSDK: + triage = FakeTriageAPI() + + entries = fetch_triage_data(FakeSDK(), "my-org") + assert len(entries) == 2 + + def test_api_error_returns_partial(self): + class FakeTriageAPI: + def __init__(self): + self.calls = 0 + + def list_alert_triage(self, org, params): + self.calls += 1 + if self.calls == 1: + return {"results": [{"alert_key": "a"}], "nextPage": 2} + raise RuntimeError("API error") + + class FakeSDK: + triage = FakeTriageAPI() + + entries = fetch_triage_data(FakeSDK(), "my-org") + assert len(entries) == 1 + + def test_access_denied_returns_empty_and_logs_info(self, caplog): + """Insufficient permissions should log an info message (not an error) and return empty.""" + + class APIAccessDenied(Exception): + pass + + class FakeTriageAPI: + def list_alert_triage(self, org, params): + raise APIAccessDenied("Insufficient permissions.") + + class FakeSDK: + triage = FakeTriageAPI() + + with caplog.at_level(logging.DEBUG): + entries = fetch_triage_data(FakeSDK(), "my-org") + + assert entries == [] + info_messages = [r for r in caplog.records if r.levelno == logging.INFO] + assert any("access denied" in m.message.lower() for m in info_messages) + error_messages = [r for r in caplog.records if r.levelno >= logging.ERROR] + assert not error_messages + + +# --------------------------------------------------------------------------- +# SecurityScanner._connector_name_from_generated_by +# --------------------------------------------------------------------------- + +class TestConnectorNameMapping: + def test_opengrep_variants(self): + from socket_basics.socket_basics import SecurityScanner + assert SecurityScanner._connector_name_from_generated_by("opengrep-python") == "opengrep" + assert SecurityScanner._connector_name_from_generated_by("sast-generic") == "opengrep" + + def test_trufflehog(self): + from socket_basics.socket_basics import SecurityScanner + assert SecurityScanner._connector_name_from_generated_by("trufflehog") == "trufflehog" + + def test_trivy_variants(self): + from socket_basics.socket_basics import SecurityScanner + assert SecurityScanner._connector_name_from_generated_by("trivy-dockerfile") == "trivy" + assert SecurityScanner._connector_name_from_generated_by("trivy-image") == "trivy" + assert SecurityScanner._connector_name_from_generated_by("trivy-npm") == "trivy" + + def test_socket_tier1(self): + from socket_basics.socket_basics import SecurityScanner + assert SecurityScanner._connector_name_from_generated_by("socket-tier1") == "socket_tier1" + + def test_unknown_returns_none(self): + from socket_basics.socket_basics import SecurityScanner + assert SecurityScanner._connector_name_from_generated_by("unknown-tool") is None + + +# --------------------------------------------------------------------------- +# SecurityScanner._inject_triage_summary +# --------------------------------------------------------------------------- + +class TestInjectTriageSummary: + def test_injects_after_heading(self): + from socket_basics.socket_basics import SecurityScanner + + notifications = { + "github_pr": [ + { + "title": "SAST Findings", + "content": "\n# SAST Python Findings\n### Summary\nSome content\n", + } + ] + } + SecurityScanner._inject_triage_summary(notifications, 3, "https://socket.dev/scan/123") + + content = notifications["github_pr"][0]["content"] + assert "3 finding(s) triaged" in content + assert "Socket Dashboard" in content + lines = content.split("\n") + heading_idx = next(i for i, l in enumerate(lines) if l.strip().startswith("# ")) + summary_idx = next(i for i, l in enumerate(lines) if "triaged" in l) + assert summary_idx > heading_idx + + def test_no_github_pr_key_is_noop(self): + from socket_basics.socket_basics import SecurityScanner + + notifications = {"slack": [{"title": "t", "content": "c"}]} + SecurityScanner._inject_triage_summary(notifications, 5, "") + assert "github_pr" not in notifications + + def test_uses_default_dashboard_link(self): + from socket_basics.socket_basics import SecurityScanner + + notifications = { + "github_pr": [{"title": "t", "content": "# Title\nBody"}] + } + SecurityScanner._inject_triage_summary(notifications, 1, "") + assert "https://socket.dev/dashboard" in notifications["github_pr"][0]["content"]