From 4d383e5279820cb3696303f09fbe671f7a9318e6 Mon Sep 17 00:00:00 2001 From: Liudeng Zhang Date: Wed, 18 Mar 2026 16:43:33 -0500 Subject: [PATCH] fix: add structured metadata filtering to CZI dataset search (fixes #5) Update search_czi_datasets docstring and parameter descriptions to instruct the agent to pass organism and tissue as dedicated parameters instead of packing them into the query string. Replace silent filter drop with controlled relaxation (organism+tissue -> organism-only -> unfiltered) that includes a warning in the output when filters are relaxed. --- spatialagent/tool/databases.py | 92 +++++++++++++++++++++++++--------- 1 file changed, 67 insertions(+), 25 deletions(-) diff --git a/spatialagent/tool/databases.py b/spatialagent/tool/databases.py index 377f835..98321e4 100644 --- a/spatialagent/tool/databases.py +++ b/spatialagent/tool/databases.py @@ -161,14 +161,21 @@ def search_panglao( @tool def search_czi_datasets( - query: Annotated[str, Field(description="Query describing tissue, condition, and organism (e.g., 'liver, breast cancer, Homo sapiens')")], + query: Annotated[str, Field(description="Query describing the condition or context (e.g., 'breast cancer', 'normal development'). Do NOT include organism or tissue here — use the dedicated parameters instead.")], n_datasets: Annotated[int, Field(description="Number of top datasets to return")] = 1, - organism: Annotated[str, Field(description="Filter by organism (e.g., 'Mus musculus', 'Homo sapiens'). Optional.")] = None, - tissue: Annotated[str, Field(description="Filter by tissue keyword (e.g., 'lung', 'brain'). Matches against tissue and tissue_general columns. Optional.")] = None, + organism: Annotated[str, Field(description="Filter by organism (e.g., 'Mus musculus', 'Homo sapiens'). Always pass this when the organism is known.")] = None, + tissue: Annotated[str, Field(description="Filter by tissue keyword (e.g., 'lung', 'brain'). Matches against tissue and tissue_general columns. Always pass this when the tissue is known.")] = None, ) -> str: """Search CZI CELLxGENE Census for reference single-cell datasets. - Input is a string containing: tissue, condition, and organism. + When the query mentions an organism or tissue, pass them as the dedicated + ``organism`` and ``tissue`` parameters so they are used as hard pre-filters + before embedding-based ranking. The ``query`` should contain only the + condition or context (e.g., 'breast cancer', 'normal development'). + + If strict filtering returns fewer than ``n_datasets`` results, the function + progressively relaxes filters (drop tissue first, then organism) and + includes a warning in the output. Returns: str: Formatted string with dataset info. Example output: @@ -206,22 +213,55 @@ def search_czi_datasets( df = pd.read_csv(metadata_path) - # Pre-filter by organism and tissue before embedding search - if organism: - mask = df["organism"] == organism - df_filtered = df[mask] - if len(df_filtered) >= n_datasets: - df = df_filtered - - if tissue: - tissue_lower = tissue.lower() - mask = ( - df["tissue"].str.lower().str.contains(tissue_lower, na=False) | - df["tissue_general"].str.lower().str.contains(tissue_lower, na=False) - ) - df_filtered = df[mask] - if len(df_filtered) >= n_datasets: - df = df_filtered + # Pre-filter by organism and tissue with controlled relaxation + filter_warnings = [] + applied_organism = None + applied_tissue = None + + if organism or tissue: + # Try strict filter: both organism and tissue + df_strict = df + if organism: + df_strict = df_strict[df_strict["organism"] == organism] + if tissue: + tissue_lower = tissue.lower() + df_strict = df_strict[ + df_strict["tissue"].str.lower().str.contains(tissue_lower, na=False) + | df_strict["tissue_general"].str.lower().str.contains(tissue_lower, na=False) + ] + + if len(df_strict) >= n_datasets: + df = df_strict + applied_organism = organism + applied_tissue = tissue + elif organism and tissue: + # Relax: drop tissue, keep organism only + df_org_only = df[df["organism"] == organism] + if len(df_org_only) >= n_datasets: + df = df_org_only + applied_organism = organism + filter_warnings.append( + f"Note: Only {len(df_strict)} datasets matched " + f"organism='{organism}' AND tissue='{tissue}' " + f"(fewer than {n_datasets} requested). " + f"Relaxed to organism='{organism}' only " + f"({len(df_org_only)} datasets)." + ) + else: + filter_warnings.append( + f"Note: Only {len(df_org_only)} datasets matched " + f"organism='{organism}' (fewer than {n_datasets} " + f"requested). Using all {len(df)} datasets." + ) + else: + # Single filter with too few results + filter_name = "organism" if organism else "tissue" + filter_val = organism if organism else tissue + filter_warnings.append( + f"Note: Only {len(df_strict)} datasets matched " + f"{filter_name}='{filter_val}' (fewer than {n_datasets} " + f"requested). Using all {len(df)} datasets." + ) df = df.reset_index(drop=True) @@ -239,12 +279,12 @@ def search_czi_datasets( query_embedding = _embed_with_retry(llm_embed_query, [query]) # Check for cached description embeddings (use effective_model for cache key) - # Include filter params in database identifier to avoid cache collisions + # Include actually-applied filters in identifier to avoid cache collisions db_id = "czi_census" - if organism: - db_id += f"_{organism.replace(' ', '_')}" - if tissue: - db_id += f"_{tissue.lower()}" + if applied_organism: + db_id += f"_{applied_organism.replace(' ', '_')}" + if applied_tissue: + db_id += f"_{applied_tissue.lower()}" cache_key = _get_cache_key(db_id, effective_model, len(descriptions)) desc_embeddings = _load_cached_embeddings(cache_key) @@ -276,6 +316,8 @@ def search_czi_datasets( results.append(result_str) output = "\n".join(results) + if filter_warnings: + output = "\n".join(filter_warnings) + "\n\n" + output return output