diff --git a/src/groundlight/experimental_api.py b/src/groundlight/experimental_api.py index 67bc6648..60c568c4 100644 --- a/src/groundlight/experimental_api.py +++ b/src/groundlight/experimental_api.py @@ -8,6 +8,7 @@ """ import json +import time from io import BufferedReader, BytesIO from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -40,6 +41,7 @@ ) from urllib3.response import HTTPResponse +from groundlight.edge.config import EdgeEndpointConfig from groundlight.images import parse_supported_image_types from groundlight.internalapi import _generate_request_id from groundlight.optional_imports import Image, np @@ -817,3 +819,72 @@ def make_generic_api_request( # noqa: PLR0913 # pylint: disable=too-many-argume auth_settings=["ApiToken"], _preload_content=False, # This returns the urllib3 response rather than trying any type of processing ) + + def _edge_base_url(self) -> str: + """Return the scheme+host+port of the configured endpoint, without the /device-api path.""" + from urllib.parse import urlparse, urlunparse + + parsed = urlparse(self.configuration.host) + return urlunparse((parsed.scheme, parsed.netloc, "", "", "", "")) + + def get_edge_config(self) -> EdgeEndpointConfig: + """Retrieve the active edge endpoint configuration. + + Only works when the client is pointed at an edge endpoint + (via GROUNDLIGHT_ENDPOINT or the endpoint constructor arg). + """ + url = f"{self._edge_base_url()}/edge-config" + headers = self.get_raw_headers() + response = requests.get(url, headers=headers, verify=self.configuration.verify_ssl) + response.raise_for_status() + return EdgeEndpointConfig.from_payload(response.json()) + + def get_edge_detector_readiness(self) -> dict[str, bool]: + """Check which configured detectors have inference pods ready to serve. + + Only works when the client is pointed at an edge endpoint. + + :return: Dict mapping detector_id to readiness (True/False). + """ + url = f"{self._edge_base_url()}/edge-detector-readiness" + headers = self.get_raw_headers() + response = requests.get(url, headers=headers, verify=self.configuration.verify_ssl) + response.raise_for_status() + return {det_id: info["ready"] for det_id, info in response.json().items()} + + def set_edge_config( + self, + config: EdgeEndpointConfig, + mode: str = "REPLACE", + timeout_sec: float = 300, + poll_interval_sec: float = 1, + ) -> EdgeEndpointConfig: + """Send a new edge endpoint configuration and wait until all detectors are ready. + + Only works when the client is pointed at an edge endpoint. + + :param config: The new configuration to apply. + :param mode: Currently only "REPLACE" is supported. + :param timeout_sec: Max seconds to wait for all detectors to become ready. + :param poll_interval_sec: How often to poll readiness while waiting. + :return: The applied configuration as reported by the edge endpoint. + """ + if mode != "REPLACE": + raise ValueError(f"Unsupported mode: {mode!r}. Currently only 'REPLACE' is supported.") + + url = f"{self._edge_base_url()}/edge-config" + headers = self.get_raw_headers() + response = requests.put(url, json=config.to_payload(), headers=headers, verify=self.configuration.verify_ssl) + response.raise_for_status() + + desired_ids = {d.detector_id for d in config.detectors if d.detector_id} + deadline = time.time() + timeout_sec + while time.time() < deadline: + readiness = self.get_edge_detector_readiness() + if desired_ids and all(readiness.get(did, False) for did in desired_ids): + return self.get_edge_config() + time.sleep(poll_interval_sec) + + raise TimeoutError( + f"Edge detectors were not all ready within {timeout_sec}s. The edge endpoint may still be converging." + ) diff --git a/test/unit/test_edge_config.py b/test/unit/test_edge_config.py index 469e6061..a3e3b3db 100644 --- a/test/unit/test_edge_config.py +++ b/test/unit/test_edge_config.py @@ -307,3 +307,30 @@ def test_inference_config_validation_errors(): always_return_edge_prediction=True, min_time_between_escalations=-1.0, ) + + +def test_get_edge_config_parses_response(): + """ExperimentalApi.get_edge_config() parses the HTTP response into an EdgeEndpointConfig.""" + from unittest.mock import Mock, patch + + from groundlight import ExperimentalApi + + payload = { + "global_config": {"refresh_rate": REFRESH_RATE_SECONDS}, + "edge_inference_configs": {"default": {"enabled": True}}, + "detectors": [{"detector_id": "det_1", "edge_inference_config": "default"}], + } + + mock_response = Mock() + mock_response.json.return_value = payload + mock_response.raise_for_status = Mock() + + gl = ExperimentalApi() + with patch("requests.get", return_value=mock_response) as mock_get: + config = gl.get_edge_config() + + mock_get.assert_called_once() + assert isinstance(config, EdgeEndpointConfig) + assert config.global_config.refresh_rate == REFRESH_RATE_SECONDS + assert config.edge_inference_configs["default"].name == "default" + assert [d.detector_id for d in config.detectors] == ["det_1"]