From fc09cfe6a1eccd56c570ee5bdfc30a4b760f8835 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 22 May 2026 17:47:31 -0500 Subject: [PATCH] Refactor search pipeline orchestration --- .../endpoint_modules/resources/thumbnail.py | 2 + .../app/api/v1/endpoint_modules/thumbnails.py | 2 + backend/app/elasticsearch/search.py | 1004 ++++++++++------- .../v1/test_resource_thumbnail_endpoints.py | 30 +- .../tests/api/v1/test_thumbnail_endpoints.py | 30 +- 5 files changed, 667 insertions(+), 401 deletions(-) diff --git a/backend/app/api/v1/endpoint_modules/resources/thumbnail.py b/backend/app/api/v1/endpoint_modules/resources/thumbnail.py index 801882a1..c2a38c36 100644 --- a/backend/app/api/v1/endpoint_modules/resources/thumbnail.py +++ b/backend/app/api/v1/endpoint_modules/resources/thumbnail.py @@ -105,6 +105,8 @@ async def _fast_thumbnail_alias_redirect(resource_id: str) -> RedirectResponse | image_hash = await _current_hot_thumbnail_hash_for_resource(resource_id) if not image_hash: return None + if not await _thumbnail_hash_has_cached_image(image_hash): + return None return _thumbnail_asset_redirect(image_hash) diff --git a/backend/app/api/v1/endpoint_modules/thumbnails.py b/backend/app/api/v1/endpoint_modules/thumbnails.py index 60daff68..00a6487c 100644 --- a/backend/app/api/v1/endpoint_modules/thumbnails.py +++ b/backend/app/api/v1/endpoint_modules/thumbnails.py @@ -35,6 +35,8 @@ async def _get_resource_alias_redirect(resource_id: str) -> Response | None: return None if not image_hash: return None + if not await _thumbnail_hash_has_cached_image(image_hash): + return None return Response( status_code=302, diff --git a/backend/app/elasticsearch/search.py b/backend/app/elasticsearch/search.py index 86919284..858fee60 100644 --- a/backend/app/elasticsearch/search.py +++ b/backend/app/elasticsearch/search.py @@ -4,6 +4,7 @@ import os import re import time +from dataclasses import dataclass from typing import Optional from urllib.parse import urlencode @@ -1084,42 +1085,233 @@ def _build_geospatial_filter(geo_params: dict) -> dict | None: return None -async def search_resources( - query: str = None, - fq: dict = None, - skip: int = 0, - limit: int = 20, - sort: list = None, - search_fields: str | None = None, - include_filters: dict | None = None, - exclude_filters: dict | None = None, - facets: Optional[str] = None, - adv_q: Optional[list] = None, - hydrate_hits: bool = True, -): - """Search resources in Elasticsearch with optional filters, sorting, and spelling - suggestions.""" - # Ensure limit is not zero to avoid division by zero errors - if limit <= 0: - limit = 20 # Default to 20 if limit is zero or negative +@dataclass +class SearchParams: + """Normalized inputs for a resource search request.""" + + query: str | None = None + fq: dict | None = None + skip: int = 0 + limit: int = 20 + sort: list | None = None + search_fields: str | None = None + include_filters: dict | None = None + exclude_filters: dict | None = None + facets: str | None = None + adv_q: list | None = None + hydrate_hits: bool = True + index_name: str = "" + + @classmethod + def from_inputs( + cls, + *, + query: str | None, + fq: dict | None, + skip: int, + limit: int, + sort: list | None, + search_fields: str | None, + include_filters: dict | None, + exclude_filters: dict | None, + facets: str | None, + adv_q: list | None, + hydrate_hits: bool, + ) -> "SearchParams": + normalized_limit = limit if limit > 0 else 20 + return cls( + query=query, + fq=fq, + skip=skip, + limit=normalized_limit, + sort=sort, + search_fields=search_fields, + include_filters=include_filters, + exclude_filters=exclude_filters, + facets=facets, + adv_q=adv_q, + hydrate_hits=hydrate_hits, + index_name=os.getenv("ELASTICSEARCH_INDEX", "btaa_geospatial_api"), + ) - index_name = os.getenv("ELASTICSEARCH_INDEX", "btaa_geospatial_api") - overall_start = time.perf_counter() + @property + def sort_clause(self) -> list: + return self.sort or [{"_score": "desc"}] - try: - # Get the current search criteria - search_criteria = get_search_criteria(query, fq, skip, limit, sort) - logger.debug(f"Search criteria: {search_criteria}") + def criteria(self) -> dict: + return get_search_criteria(self.query, self.fq, self.skip, self.limit, self.sort) + + +@dataclass +class SearchFacetSelection: + aggregations: dict + names: tuple[str, ...] + cache_key: str | None = None + cached_aggregations: dict | None = None + cache_status: str = "disabled" + cache_lookup_ms: float = 0.0 + cache_store_ms: float = 0.0 + + +class FacetService: + """Selects and caches search facet aggregations.""" + + async def prepare(self, params: SearchParams, search_criteria: dict) -> SearchFacetSelection: + allowed_aggs = None + if params.facets: + allowed_aggs = {f.strip() for f in params.facets.split(",") if f.strip()} + + full_aggs = _build_search_aggregations() + selected_aggs = ( + {k: v for k, v in full_aggs.items() if k in allowed_aggs} if allowed_aggs else full_aggs + ) + selected_agg_names = tuple(selected_aggs.keys()) + selection = SearchFacetSelection( + aggregations=selected_aggs, + names=selected_agg_names, + ) + + if not selected_aggs: + return selection + + selection.cache_key = _build_search_facet_cache_key( + index_name=params.index_name, + query=search_criteria.get("query"), + search_fields=params.search_fields, + fq=params.fq, + include_filters=params.include_filters, + exclude_filters=params.exclude_filters, + adv_q=params.adv_q, + selected_aggs=selected_agg_names, + ) + facet_cache_lookup_start = time.perf_counter() + selection.cached_aggregations = await _get_cached_search_aggregations(selection.cache_key) + selection.cache_lookup_ms = (time.perf_counter() - facet_cache_lookup_start) * 1000 + selection.cache_status = "hit" if selection.cached_aggregations is not None else "miss" + return selection + + async def apply_to_response( + self, + response_dict: dict, + selection: SearchFacetSelection, + ) -> None: + if selection.cached_aggregations is not None: + response_dict["aggregations"] = selection.cached_aggregations + return + + if not selection.cache_key or not selection.aggregations: + return + + facet_cache_store_start = time.perf_counter() + await _store_cached_search_aggregations( + selection.cache_key, + response_dict.get("aggregations", {}) or {}, + selection.names, + ) + selection.cache_store_ms = (time.perf_counter() - facet_cache_store_start) * 1000 + + +@dataclass +class SearchFilterPlan: + filter_clauses: list + must_not_clauses: list + bbox_filter_info: dict | None = None + + +@dataclass +class SearchQueryPlan: + search_query: dict + bool_query: dict + overlap_context: dict | None = None + + +class GeoFilterBuilder: + """Builds geospatial filters and optional bbox scoring context.""" + + def build(self, values: dict) -> tuple[dict | None, dict | None]: + logger.debug("Building geo filter from values: %s", values) + geo_filter = _build_geospatial_filter(values) + if not geo_filter: + logger.warning(f"Failed to build geo filter from values: {values}") + return None, None + + logger.debug("Geo filter built successfully: %s", geo_filter) + return geo_filter, self._bbox_filter_info(values) + + def _bbox_filter_info(self, values: dict) -> dict | None: + if not ( + values.get("type") == "bbox" and values.get("top_left") and values.get("bottom_right") + ): + return None + + bbox_bounds = _normalize_geo_bbox_bounds(values["top_left"], values["bottom_right"]) + if bbox_bounds and len(bbox_bounds["lon_ranges"]) == 1: + return { + "bounds": bbox_bounds, + "field": values.get("field", "dcat_centroid"), + "min_overlap_ratio": values.get("min_overlap_ratio"), + } + return None + + +class SearchQueryBuilder: + """Builds the Elasticsearch query body for resource search.""" + + def __init__( + self, + params: SearchParams, + search_criteria: dict, + facet_selection: SearchFacetSelection, + ): + self.params = params + self.search_criteria = search_criteria + self.facet_selection = facet_selection + self.geo_filter_builder = GeoFilterBuilder() + + def build(self) -> SearchQueryPlan: + filter_plan = self._build_filters() + must_clauses, should_clauses, combined_must_not = self._build_query_clauses( + filter_plan.must_not_clauses + ) + bool_query = self._build_bool_query( + filter_plan.filter_clauses, + must_clauses, + should_clauses, + combined_must_not, + ) + base_query, overlap_context = self._build_base_query( + bool_query, + filter_plan.filter_clauses, + filter_plan.bbox_filter_info, + ) + + search_query = { + **base_query, + "from": self.params.skip, + "size": self.params.limit, + "sort": self.params.sort_clause, + "track_total_hits": True, + } + if self.facet_selection.cached_aggregations is None and self.facet_selection.aggregations: + search_query["aggs"] = self.facet_selection.aggregations + + suggest = self._build_suggest() + if suggest: + search_query["suggest"] = suggest + + return SearchQueryPlan( + search_query=search_query, + bool_query=bool_query, + overlap_context=overlap_context, + ) - # Construct the filter query (legacy fq + new include/exclude) + def _build_filters(self) -> SearchFilterPlan: filter_clauses = [] must_not_clauses = [] - - # Track bbox filter for spatial scoring bbox_filter_info = None - if fq: - for field, values in fq.items(): + if self.params.fq: + for field, values in self.params.fq.items(): resolved_field = _resolve_filter_field(field) logger.debug( f"Processing filter - Field: {field}, " @@ -1130,69 +1322,30 @@ async def search_resources( else: filter_clauses.append({"term": {resolved_field: values}}) - if include_filters: - for field, values in include_filters.items(): + if self.params.include_filters: + for field, values in self.params.include_filters.items(): resolved_field = _resolve_filter_field(field) - # Handle geospatial queries if field == "geo" and isinstance(values, dict): - logger.debug("Building geo filter from values: %s", values) - geo_filter = _build_geospatial_filter(values) + geo_filter, geo_bbox_info = self.geo_filter_builder.build(values) if geo_filter: - logger.debug("Geo filter built successfully: %s", geo_filter) filter_clauses.append(geo_filter) - # Track bbox filter for spatial scoring - if ( - values.get("type") == "bbox" - and values.get("top_left") - and values.get("bottom_right") - ): - bbox_bounds = _normalize_geo_bbox_bounds( - values["top_left"], - values["bottom_right"], - ) - if bbox_bounds and len(bbox_bounds["lon_ranges"]) == 1: - bbox_filter_info = { - "bounds": bbox_bounds, - "field": values.get("field", "dcat_centroid"), - "min_overlap_ratio": values.get("min_overlap_ratio"), - } - else: - logger.warning(f"Failed to build geo filter from values: {values}") - # Handle year range queries + if geo_bbox_info: + bbox_filter_info = geo_bbox_info elif field == "year_range" and isinstance(values, dict): - # Expecting start and end keys - year_range_filter = {"range": {"gbl_indexYear_im": {}}} - if "start" in values: - try: - year_range_filter["range"]["gbl_indexYear_im"]["gte"] = int( - values["start"] - ) - except (ValueError, TypeError): - pass - if "end" in values: - try: - year_range_filter["range"]["gbl_indexYear_im"]["lte"] = int( - values["end"] - ) - except (ValueError, TypeError): - pass - - if year_range_filter["range"]["gbl_indexYear_im"]: + year_range_filter = self._build_year_range_filter(values) + if year_range_filter: filter_clauses.append(year_range_filter) - elif field in ("geo_global", "geo_or_near_global") and isinstance(values, list): if values and str(values[0]).lower() == "true": filter_clauses.append({"term": {resolved_field: True}}) elif isinstance(values, list): - # Use terms to match if ANY of the specified values are present - # This matches the behavior of legacy fq filters (OR logic) filter_clauses.append({"terms": {resolved_field: values}}) else: filter_clauses.append({"term": {resolved_field: values}}) - if exclude_filters: - for field, values in exclude_filters.items(): + if self.params.exclude_filters: + for field, values in self.params.exclude_filters.items(): resolved_field = _resolve_filter_field(field) if isinstance(values, list): @@ -1200,118 +1353,107 @@ async def search_resources( else: must_not_clauses.append({"term": {resolved_field: values}}) - # Optionally filter which aggs to include - allowed_aggs = None - if facets: - allowed_aggs = {f.strip() for f in facets.split(",") if f.strip()} + return SearchFilterPlan( + filter_clauses=filter_clauses, + must_not_clauses=must_not_clauses, + bbox_filter_info=bbox_filter_info, + ) - full_aggs = _build_search_aggregations() + def _build_year_range_filter(self, values: dict) -> dict | None: + year_range_filter = {"range": {"gbl_indexYear_im": {}}} + if "start" in values: + try: + year_range_filter["range"]["gbl_indexYear_im"]["gte"] = int(values["start"]) + except (ValueError, TypeError): + # Keep invalid year bounds permissive; ignore only the bad bound. + logger.debug("Ignoring invalid start year filter value: %r", values["start"]) + if "end" in values: + try: + year_range_filter["range"]["gbl_indexYear_im"]["lte"] = int(values["end"]) + except (ValueError, TypeError): + # Keep invalid year bounds permissive; ignore only the bad bound. + logger.debug("Ignoring invalid end year filter value: %r", values["end"]) - selected_aggs = ( - {k: v for k, v in full_aggs.items() if k in allowed_aggs} if allowed_aggs else full_aggs - ) - selected_agg_names = tuple(selected_aggs.keys()) - facet_cache_status = "disabled" - facet_cache_lookup_ms = 0.0 - facet_cache_store_ms = 0.0 - cached_aggregations = None - if selected_aggs: - facet_cache_key = _build_search_facet_cache_key( - index_name=index_name, - query=search_criteria.get("query"), - search_fields=search_fields, - fq=fq, - include_filters=include_filters, - exclude_filters=exclude_filters, - adv_q=adv_q, - selected_aggs=selected_agg_names, - ) - facet_cache_lookup_start = time.perf_counter() - cached_aggregations = await _get_cached_search_aggregations(facet_cache_key) - facet_cache_lookup_ms = (time.perf_counter() - facet_cache_lookup_start) * 1000 - facet_cache_status = "hit" if cached_aggregations is not None else "miss" - else: - facet_cache_key = None + if year_range_filter["range"]["gbl_indexYear_im"]: + return year_range_filter + return None - # Build the search query - # Support both q and adv_q simultaneously + def _build_query_clauses(self, must_not_clauses: list) -> tuple[list, list, list]: must_clauses = [] should_clauses = [] combined_must_not = list(must_not_clauses) - # Build query from q parameter if provided - query_value = search_criteria.get("query") + query_value = self.search_criteria.get("query") if query_value and query_value.strip(): - query_text = query_value.strip() - is_phrase = ( - len(query_text) >= 2 and query_text.startswith('"') and query_text.endswith('"') - ) - phrase = query_text[1:-1] if is_phrase else query_text - - # If specific fields are requested (and not 'all_fields'), - # use multi_match across provided fields - scoped = bool(search_fields) and search_fields.strip().lower() != "all_fields" - if scoped: - requested_fields = [f.strip() for f in search_fields.split(",") if f.strip()] - # Prefer exact matches via .keyword when available, - # but also search the analyzed field - expanded_fields = [] - for f in requested_fields: - expanded_fields.append(f) - expanded_fields.append(f"{f}.keyword") - - must_clauses.append( - { - "multi_match": { - "query": phrase, - "type": "best_fields" if not is_phrase else "phrase", - "operator": "AND", - "fields": expanded_fields, - } - } - ) - else: - # Default behavior across boosted fields using query_string - must_clauses.append( - { - "query_string": { - "query": _escape_query_string_brackets(query_text), - "fields": [ - "id^5", - "dct_title_s^3", - "dct_description_sm^2", - "summary^2", - "dct_creator_sm^2", - "dct_subject_sm^1.5", - "dcat_keyword_sm^1.5", - "dct_publisher_sm", - "schema_provider_s", - "dct_spatial_sm", - "gbl_displaynote_sm", - ], - "default_operator": "AND", - "analyze_wildcard": True, - "allow_leading_wildcard": True, - } - } - ) + must_clauses.append(self._build_text_query_clause(query_value.strip())) - # Build advanced query clauses if provided - if adv_q: - advanced_query_structure = _build_advanced_query(adv_q) - # Add advanced query AND clauses to must + if self.params.adv_q: + advanced_query_structure = _build_advanced_query(self.params.adv_q) must_clauses.extend(advanced_query_structure["must"]) - # Add advanced query OR clauses to should should_clauses.extend(advanced_query_structure["should"]) - # Add advanced query NOT clauses to must_not combined_must_not.extend(advanced_query_structure["must_not"]) - # Build the bool query combining all clauses - bool_query_dict = {} + return must_clauses, should_clauses, combined_must_not + + def _build_text_query_clause(self, query_text: str) -> dict: + is_phrase = len(query_text) >= 2 and query_text.startswith('"') and query_text.endswith('"') + phrase = query_text[1:-1] if is_phrase else query_text + scoped = ( + bool(self.params.search_fields) + and self.params.search_fields.strip().lower() != "all_fields" + ) + + if scoped: + requested_fields = [ + f.strip() for f in self.params.search_fields.split(",") if f.strip() + ] + expanded_fields = [] + for field_name in requested_fields: + expanded_fields.append(field_name) + expanded_fields.append(f"{field_name}.keyword") + + return { + "multi_match": { + "query": phrase, + "type": "best_fields" if not is_phrase else "phrase", + "operator": "AND", + "fields": expanded_fields, + } + } + + return { + "query_string": { + "query": _escape_query_string_brackets(query_text), + "fields": [ + "id^5", + "dct_title_s^3", + "dct_description_sm^2", + "summary^2", + "dct_creator_sm^2", + "dct_subject_sm^1.5", + "dcat_keyword_sm^1.5", + "dct_publisher_sm", + "schema_provider_s", + "dct_spatial_sm", + "gbl_displaynote_sm", + ], + "default_operator": "AND", + "analyze_wildcard": True, + "allow_leading_wildcard": True, + } + } + + def _build_bool_query( + self, + filter_clauses: list, + must_clauses: list, + should_clauses: list, + combined_must_not: list, + ) -> dict: + bool_query = {} - # Only include filter if there are filter clauses if filter_clauses: - bool_query_dict["filter"] = filter_clauses + bool_query["filter"] = filter_clauses logger.debug("Bool query - filter clauses count: %s", len(filter_clauses)) if filter_clauses: @@ -1321,64 +1463,61 @@ async def search_resources( logger.debug("No filter clauses found - query will return all results") if must_clauses: - bool_query_dict["must"] = must_clauses + bool_query["must"] = must_clauses elif not should_clauses: - # If no must clauses and no should clauses, match all - bool_query_dict["must"] = [{"match_all": {}}] + bool_query["must"] = [{"match_all": {}}] if should_clauses: - bool_query_dict["should"] = should_clauses - bool_query_dict["minimum_should_match"] = 1 + bool_query["should"] = should_clauses + bool_query["minimum_should_match"] = 1 if combined_must_not: - bool_query_dict["must_not"] = combined_must_not + bool_query["must_not"] = combined_must_not - # Base query is a plain bool; we will wrap it in script_score when we have - # bbox info for spatial reranking. - base_query = {"query": {"bool": bool_query_dict}} + return bool_query + + def _build_base_query( + self, + bool_query: dict, + filter_clauses: list, + bbox_filter_info: dict | None, + ) -> tuple[dict, dict | None]: + base_query = {"query": {"bool": bool_query}} overlap_context = None - # Add bbox spatial scoring when bbox filter is present. - # This combines document containment within the query bbox and IoU - # extent similarity using numeric bbox_* fields, and does NOT use - # centroids at all. - if bbox_filter_info: - bbox_bounds = bbox_filter_info["bounds"] - west, east = bbox_bounds["lon_ranges"][0] - - # Query bbox bounds (x = lon, y = lat) - q_minx = west - q_maxx = east - q_miny = bbox_bounds["south"] - q_maxy = bbox_bounds["north"] - - containment_weight, overlap_weight = _normalized_spatial_weights() - - # Persist query bbox bounds so we can later compute concrete bbox - # spatial metrics per hit in Python for the API meta block. - overlap_context = { - "qMinX": q_minx, - "qMaxX": q_maxx, - "qMinY": q_miny, - "qMaxY": q_maxy, - } - min_overlap_ratio = _normalize_min_overlap_ratio( - bbox_filter_info.get("min_overlap_ratio") - ) - filter_clauses.append( - _build_bbox_overlap_filter( - q_minx=q_minx, - q_maxx=q_maxx, - q_miny=q_miny, - q_maxy=q_maxy, - min_overlap_ratio=min_overlap_ratio, - ) + if not bbox_filter_info: + return base_query, overlap_context + + bbox_bounds = bbox_filter_info["bounds"] + west, east = bbox_bounds["lon_ranges"][0] + q_minx = west + q_maxx = east + q_miny = bbox_bounds["south"] + q_maxy = bbox_bounds["north"] + containment_weight, overlap_weight = _normalized_spatial_weights() + + overlap_context = { + "qMinX": q_minx, + "qMaxX": q_maxx, + "qMinY": q_miny, + "qMaxY": q_maxy, + } + min_overlap_ratio = _normalize_min_overlap_ratio(bbox_filter_info.get("min_overlap_ratio")) + filter_clauses.append( + _build_bbox_overlap_filter( + q_minx=q_minx, + q_maxx=q_maxx, + q_miny=q_miny, + q_maxy=q_maxy, + min_overlap_ratio=min_overlap_ratio, ) + ) - base_query = { + return ( + { "query": { "script_score": { - "query": {"bool": bool_query_dict}, + "query": {"bool": bool_query}, "script": { "source": """ // Read document bbox from numeric bbox_* fields @@ -1478,187 +1617,282 @@ async def search_resources( }, } } - } + }, + overlap_context, + ) - search_query = { - **base_query, - "from": skip, - "size": limit, - "sort": sort or [{"_score": "desc"}], - "track_total_hits": True, + def _build_suggest(self) -> dict | None: + query_text = self.search_criteria.get("query") + if not query_text or not query_text.strip(): + return None + + return { + "text": query_text, + "simple_phrase": { + "phrase": { + "field": "dct_title_s", + "size": 1, + "gram_size": 3, + "direct_generator": [ + {"field": "dct_title_s", "suggest_mode": "always"}, + {"field": "dct_description_sm", "suggest_mode": "always"}, + ], + "highlight": {"pre_tag": "", "post_tag": ""}, + } + }, } - if cached_aggregations is None and selected_aggs: - search_query["aggs"] = selected_aggs - - # Add suggestions if q parameter was provided - if search_criteria.get("query") and search_criteria["query"].strip(): - search_query["suggest"] = { - "text": search_criteria["query"], - "simple_phrase": { - "phrase": { - "field": "dct_title_s", - "size": 1, - "gram_size": 3, - "direct_generator": [ - {"field": "dct_title_s", "suggest_mode": "always"}, - {"field": "dct_description_sm", "suggest_mode": "always"}, - ], - "highlight": {"pre_tag": "", "post_tag": ""}, - } - }, - } - # If neither q nor adv_q provided, search_query already has match_all in must - if logger.isEnabledFor(logging.DEBUG): - logger.debug("ES Query: %s", json.dumps(search_query, indent=2)) - es_roundtrip_ms = 0.0 - - async def finalize_response( - response_dict: dict, - *, - source: str, - response_overlap_context: dict | None, - ) -> dict: - nonlocal facet_cache_store_ms - - if cached_aggregations is not None: - response_dict["aggregations"] = cached_aggregations - elif facet_cache_key and selected_aggs: - facet_cache_store_start = time.perf_counter() - await _store_cached_search_aggregations( - facet_cache_key, - response_dict.get("aggregations", {}) or {}, - selected_agg_names, - ) - facet_cache_store_ms = (time.perf_counter() - facet_cache_store_start) * 1000 - - result = await process_search_response( - response_dict, - limit, - skip, - search_criteria, - overlap_context=response_overlap_context, - include_filters=include_filters, - exclude_filters=exclude_filters, - adv_q=adv_q, - hydrate_hits=hydrate_hits, - ) - total_ms = (time.perf_counter() - overall_start) * 1000 - response_hits = response_dict.get("hits", {}) - total_hits_value = (response_hits.get("total") or {}).get("value") - _log_aggregation_timing( - operation="search_resources", - cache_status=facet_cache_status, - total_ms=total_ms, - es_roundtrip_ms=es_roundtrip_ms, - es_took_ms=float(response_dict.get("took", 0) or 0), - cache_lookup_ms=facet_cache_lookup_ms, - cache_store_ms=facet_cache_store_ms, - aggregation_names=selected_agg_names, - total_hits=total_hits_value if total_hits_value is not None else None, - hit_count=len(response_hits.get("hits", []) or []), - source=source, - ) - return result +@dataclass +class SearchExecutionResult: + response_dict: dict + source: str + overlap_context: dict | None + es_roundtrip_ms: float + +class SearchExecutor: + """Runs the Elasticsearch query and handles search-specific fallbacks.""" + + def __init__( + self, + params: SearchParams, + query_plan: SearchQueryPlan, + facet_selection: SearchFacetSelection, + ): + self.params = params + self.query_plan = query_plan + self.facet_selection = facet_selection + + async def execute(self) -> SearchExecutionResult: try: - # Call ES using keyword args so tests can inspect 'query' and 'suggest' - search_kwargs = { - "index": index_name, - "query": search_query["query"], - "from_": skip, - "size": limit, - "sort": sort or [{"_score": "desc"}], - "track_total_hits": True, - "suggest": search_query.get("suggest"), - } - if search_query.get("aggs"): - search_kwargs["aggs"] = search_query["aggs"] - es_roundtrip_start = time.perf_counter() - response = await es.search(**search_kwargs) - es_roundtrip_ms = (time.perf_counter() - es_roundtrip_start) * 1000 - response_dict = response.body if hasattr(response, "body") else response + response_dict, es_roundtrip_ms = await self._run_search(self._primary_kwargs()) + return SearchExecutionResult( + response_dict=response_dict, + source="primary", + overlap_context=self.query_plan.overlap_context, + es_roundtrip_ms=es_roundtrip_ms, + ) except NotFoundError: - # Index missing: return empty result structure instead of 500 - logger.warning(f"Elasticsearch index '{index_name}' not found; returning empty results") - empty_response = { - "hits": {"total": {"value": 0}, "hits": []}, - "took": 0, - "aggregations": {}, - } - return await finalize_response( - empty_response, + logger.warning( + "Elasticsearch index '%s' not found; returning empty results", + self.params.index_name, + ) + return SearchExecutionResult( + response_dict={ + "hits": {"total": {"value": 0}, "hits": []}, + "took": 0, + "aggregations": {}, + }, source="missing_index", - response_overlap_context=None, + overlap_context=None, + es_roundtrip_ms=0.0, ) except Exception as es_error: logger.error(f"Elasticsearch error: {str(es_error)}", exc_info=True) + fallback = await self._maybe_run_script_score_fallback(es_error) + if fallback is not None: + return fallback + raise self._build_elasticsearch_http_error(es_error) from es_error + + def _primary_kwargs(self) -> dict: + search_query = self.query_plan.search_query + search_kwargs = { + "index": self.params.index_name, + "query": search_query["query"], + "from_": self.params.skip, + "size": self.params.limit, + "sort": self.params.sort_clause, + "track_total_hits": True, + "suggest": search_query.get("suggest"), + } + if search_query.get("aggs"): + search_kwargs["aggs"] = search_query["aggs"] + return search_kwargs + + def _fallback_kwargs(self) -> dict: + search_query = self.query_plan.search_query + fallback_kwargs = { + "index": self.params.index_name, + "query": {"bool": self.query_plan.bool_query}, + "from_": self.params.skip, + "size": self.params.limit, + "sort": self.params.sort_clause, + "track_total_hits": True, + "suggest": search_query.get("suggest"), + } + if self.facet_selection.cached_aggregations is None and self.facet_selection.aggregations: + fallback_kwargs["aggs"] = self.facet_selection.aggregations + return fallback_kwargs - # If the failure is due to script_score (e.g. painless compile error), - # fall back to a plain bool query WITHOUT overlap scoring so we still - # return correct filtered results instead of zero. + async def _run_search(self, search_kwargs: dict) -> tuple[dict, float]: + es_roundtrip_start = time.perf_counter() + response = await es.search(**search_kwargs) + es_roundtrip_ms = (time.perf_counter() - es_roundtrip_start) * 1000 + response_dict = response.body if hasattr(response, "body") else response + return response_dict, es_roundtrip_ms + + async def _maybe_run_script_score_fallback( + self, + es_error: Exception, + ) -> SearchExecutionResult | None: + info = getattr(es_error, "info", {}) or {} + error_type = info.get("error", {}).get("root_cause", [{}])[0].get("type", "") + if "script_exception" not in error_type and "script_exception" not in str(es_error): + return None + + logger.warning( + "Script_score query failed (likely painless compile error); " + "falling back to plain bool query without overlap scoring." + ) + try: + fallback_dict, es_roundtrip_ms = await self._run_search(self._fallback_kwargs()) + return SearchExecutionResult( + response_dict=fallback_dict, + source="script_score_fallback", + overlap_context=None, + es_roundtrip_ms=es_roundtrip_ms, + ) + except Exception as fallback_error: + logger.error( + "Fallback bool query after script failure also errored: %s", + fallback_error, + exc_info=True, + ) + return None + + def _build_elasticsearch_http_error(self, es_error: Exception) -> HTTPException: + # Keep upstream query internals out of public 500 responses; the full + # exception is already logged with exc_info in execute(). + error_detail = { + "message": "Elasticsearch query failed", + "code": "elasticsearch_query_failed", + } + if hasattr(es_error, "info"): info = getattr(es_error, "info", {}) or {} - error_type = info.get("error", {}).get("root_cause", [{}])[0].get("type", "") - if "script_exception" in error_type or "script_exception" in str(es_error): - logger.warning( - "Script_score query failed (likely painless compile error); " - "falling back to plain bool query without overlap scoring." - ) - try: - fallback_kwargs = { - "index": index_name, - "query": {"bool": bool_query_dict}, - "from_": skip, - "size": limit, - "sort": sort or [{"_score": "desc"}], - "track_total_hits": True, - "suggest": search_query.get("suggest"), - } - if cached_aggregations is None and selected_aggs: - fallback_kwargs["aggs"] = selected_aggs - es_roundtrip_start = time.perf_counter() - fallback_response = await es.search(**fallback_kwargs) - es_roundtrip_ms = (time.perf_counter() - es_roundtrip_start) * 1000 - fallback_dict = ( - fallback_response.body - if hasattr(fallback_response, "body") - else fallback_response - ) - return await finalize_response( - fallback_dict, - source="script_score_fallback", - response_overlap_context=None, - ) - except Exception as fallback_error: - logger.error( - f"Fallback bool query after script failure also errored: {fallback_error}", - exc_info=True, - ) + upstream_status = info.get("status") if isinstance(info, dict) else None + if isinstance(upstream_status, int): + error_detail["upstream_status_code"] = upstream_status + if hasattr(es_error, "status_code"): + status_code = es_error.status_code + if isinstance(status_code, int): + error_detail["upstream_status_code"] = status_code + return HTTPException(status_code=500, detail=error_detail) + + +class SearchResponseBuilder: + """Turns an ES response into the existing search_resources payload.""" + + def __init__( + self, + params: SearchParams, + search_criteria: dict, + facet_service: FacetService, + facet_selection: SearchFacetSelection, + overall_start: float, + ): + self.params = params + self.search_criteria = search_criteria + self.facet_service = facet_service + self.facet_selection = facet_selection + self.overall_start = overall_start + + async def build(self, execution: SearchExecutionResult) -> dict: + await self.facet_service.apply_to_response( + execution.response_dict, + self.facet_selection, + ) + result = await process_search_response( + execution.response_dict, + self.params.limit, + self.params.skip, + self.search_criteria, + overlap_context=execution.overlap_context, + include_filters=self.params.include_filters, + exclude_filters=self.params.exclude_filters, + adv_q=self.params.adv_q, + hydrate_hits=self.params.hydrate_hits, + ) + self._log_timing(execution) + return result - # If we get here, propagate a public-safe HTTP error. The full query, - # index name, and upstream exception remain in logs via exc_info above. - error_detail = { - "message": "Elasticsearch query failed", - "code": "elasticsearch_query_failed", - } - if hasattr(es_error, "info"): - info = getattr(es_error, "info", {}) or {} - upstream_status = info.get("status") if isinstance(info, dict) else None - if isinstance(upstream_status, int): - error_detail["upstream_status_code"] = upstream_status - if hasattr(es_error, "status_code"): - status_code = es_error.status_code - if isinstance(status_code, int): - error_detail["upstream_status_code"] = status_code - raise HTTPException(status_code=500, detail=error_detail) from es_error - - return await finalize_response( - response_dict, - source="primary", - response_overlap_context=overlap_context, + def _log_timing(self, execution: SearchExecutionResult) -> None: + total_ms = (time.perf_counter() - self.overall_start) * 1000 + response_hits = execution.response_dict.get("hits", {}) + total_hits_value = (response_hits.get("total") or {}).get("value") + _log_aggregation_timing( + operation="search_resources", + cache_status=self.facet_selection.cache_status, + total_ms=total_ms, + es_roundtrip_ms=execution.es_roundtrip_ms, + es_took_ms=float(execution.response_dict.get("took", 0) or 0), + cache_lookup_ms=self.facet_selection.cache_lookup_ms, + cache_store_ms=self.facet_selection.cache_store_ms, + aggregation_names=self.facet_selection.names, + total_hits=total_hits_value if total_hits_value is not None else None, + hit_count=len(response_hits.get("hits", []) or []), + source=execution.source, ) + +async def search_resources( + query: str = None, + fq: dict = None, + skip: int = 0, + limit: int = 20, + sort: list = None, + search_fields: str | None = None, + include_filters: dict | None = None, + exclude_filters: dict | None = None, + facets: Optional[str] = None, + adv_q: Optional[list] = None, + hydrate_hits: bool = True, +): + """Search resources in Elasticsearch with optional filters, sorting, and spelling + suggestions.""" + params = SearchParams.from_inputs( + query=query, + fq=fq, + skip=skip, + limit=limit, + sort=sort, + search_fields=search_fields, + include_filters=include_filters, + exclude_filters=exclude_filters, + facets=facets, + adv_q=adv_q, + hydrate_hits=hydrate_hits, + ) + overall_start = time.perf_counter() + + try: + search_criteria = params.criteria() + logger.debug(f"Search criteria: {search_criteria}") + + facet_service = FacetService() + facet_selection = await facet_service.prepare(params, search_criteria) + query_plan = SearchQueryBuilder( + params, + search_criteria, + facet_selection, + ).build() + + if logger.isEnabledFor(logging.DEBUG): + logger.debug("ES Query: %s", json.dumps(query_plan.search_query, indent=2)) + + execution = await SearchExecutor( + params, + query_plan, + facet_selection, + ).execute() + return await SearchResponseBuilder( + params, + search_criteria, + facet_service, + facet_selection, + overall_start, + ).build(execution) + except Exception as e: logger.error(f"Search documents error: {str(e)}", exc_info=True) raise diff --git a/backend/tests/api/v1/test_resource_thumbnail_endpoints.py b/backend/tests/api/v1/test_resource_thumbnail_endpoints.py index 083f7927..c41b5331 100644 --- a/backend/tests/api/v1/test_resource_thumbnail_endpoints.py +++ b/backend/tests/api/v1/test_resource_thumbnail_endpoints.py @@ -86,31 +86,45 @@ def test_resource_thumbnail_alias_redirect_short_circuits(self, client): resource_id = "test-fast-alias" image_hash = "e7810cca426f65fa9e5e25124ca1b213b6c54deec0901c88805558faa7e25639" - with patch( - "app.api.v1.endpoint_modules.resources.thumbnail._current_hot_thumbnail_hash_for_resource", - new=AsyncMock(return_value=image_hash), - ) as mock_current_hash: + with ( + patch( + "app.api.v1.endpoint_modules.resources.thumbnail._current_hot_thumbnail_hash_for_resource", + new=AsyncMock(return_value=image_hash), + ) as mock_current_hash, + patch( + "app.api.v1.endpoint_modules.resources.thumbnail._thumbnail_hash_has_cached_image", + new=AsyncMock(return_value=True), + ) as mock_hash_cached, + ): response = client.get(f"/resources/{resource_id}/thumbnail", follow_redirects=False) assert response.status_code == 302 assert response.headers["location"] == f"/api/v1/thumbnails/{image_hash}" assert "max-age=3600" in response.headers["cache-control"] mock_current_hash.assert_awaited_once_with(resource_id) + mock_hash_cached.assert_awaited_once_with(image_hash) def test_resource_thumbnail_success_state_rehydrates_alias_redirect(self, client): """The canonical resolver should redirect when the current source is hot.""" resource_id = "test-success-state" image_hash = "e7810cca426f65fa9e5e25124ca1b213b6c54deec0901c88805558faa7e25639" - with patch( - "app.api.v1.endpoint_modules.resources.thumbnail._current_hot_thumbnail_hash_for_resource", - new=AsyncMock(return_value=image_hash), - ) as mock_current_hash: + with ( + patch( + "app.api.v1.endpoint_modules.resources.thumbnail._current_hot_thumbnail_hash_for_resource", + new=AsyncMock(return_value=image_hash), + ) as mock_current_hash, + patch( + "app.api.v1.endpoint_modules.resources.thumbnail._thumbnail_hash_has_cached_image", + new=AsyncMock(return_value=True), + ) as mock_hash_cached, + ): response = client.get(f"/resources/{resource_id}/thumbnail", follow_redirects=False) assert response.status_code == 302 assert response.headers["location"] == f"/api/v1/thumbnails/{image_hash}" mock_current_hash.assert_awaited_once_with(resource_id) + mock_hash_cached.assert_awaited_once_with(image_hash) def test_resource_thumbnail_stale_alias_falls_through_to_resolver(self, client): """Aliases pointing at missing image bytes should not pin placeholders.""" diff --git a/backend/tests/api/v1/test_thumbnail_endpoints.py b/backend/tests/api/v1/test_thumbnail_endpoints.py index 16a50f0f..986de59a 100644 --- a/backend/tests/api/v1/test_thumbnail_endpoints.py +++ b/backend/tests/api/v1/test_thumbnail_endpoints.py @@ -117,31 +117,45 @@ def test_get_thumbnail_resource_alias_redirect(self, client): resource_id = "nyu-2451-34564" image_hash = "e7810cca426f65fa9e5e25124ca1b213b6c54deec0901c88805558faa7e25639" - with patch( - "app.api.v1.endpoint_modules.resources.thumbnail._current_hot_thumbnail_hash_for_resource", - new=AsyncMock(return_value=image_hash), - ) as mock_current_hash: + with ( + patch( + "app.api.v1.endpoint_modules.resources.thumbnail._current_hot_thumbnail_hash_for_resource", + new=AsyncMock(return_value=image_hash), + ) as mock_current_hash, + patch( + "app.api.v1.endpoint_modules.thumbnails._thumbnail_hash_has_cached_image", + new=AsyncMock(return_value=True), + ) as mock_hash_cached, + ): response = client.get(f"/thumbnails/{resource_id}", follow_redirects=False) assert response.status_code == 302 assert response.headers["location"] == f"/api/v1/thumbnails/{image_hash}" assert "max-age=3600" in response.headers["cache-control"] mock_current_hash.assert_awaited_once_with(resource_id) + mock_hash_cached.assert_awaited_once_with(image_hash) def test_get_thumbnail_success_state_rehydrates_alias_redirect(self, client): """Resource-id requests should redirect when the canonical source is hot.""" resource_id = "nyu-2451-34564" image_hash = "e7810cca426f65fa9e5e25124ca1b213b6c54deec0901c88805558faa7e25639" - with patch( - "app.api.v1.endpoint_modules.resources.thumbnail._current_hot_thumbnail_hash_for_resource", - new=AsyncMock(return_value=image_hash), - ) as mock_current_hash: + with ( + patch( + "app.api.v1.endpoint_modules.resources.thumbnail._current_hot_thumbnail_hash_for_resource", + new=AsyncMock(return_value=image_hash), + ) as mock_current_hash, + patch( + "app.api.v1.endpoint_modules.thumbnails._thumbnail_hash_has_cached_image", + new=AsyncMock(return_value=True), + ) as mock_hash_cached, + ): response = client.get(f"/thumbnails/{resource_id}", follow_redirects=False) assert response.status_code == 302 assert response.headers["location"] == f"/api/v1/thumbnails/{image_hash}" mock_current_hash.assert_awaited_once_with(resource_id) + mock_hash_cached.assert_awaited_once_with(image_hash) def test_get_thumbnail_stale_resource_alias_falls_through_to_resolver(self, client): """Resource-id aliases with missing image bytes should not redirect to placeholders."""