diff --git a/spatialagent/tool/databases.py b/spatialagent/tool/databases.py index 377f835..98321e4 100644 --- a/spatialagent/tool/databases.py +++ b/spatialagent/tool/databases.py @@ -161,14 +161,21 @@ def search_panglao( @tool def search_czi_datasets( - query: Annotated[str, Field(description="Query describing tissue, condition, and organism (e.g., 'liver, breast cancer, Homo sapiens')")], + query: Annotated[str, Field(description="Query describing the condition or context (e.g., 'breast cancer', 'normal development'). Do NOT include organism or tissue here — use the dedicated parameters instead.")], n_datasets: Annotated[int, Field(description="Number of top datasets to return")] = 1, - organism: Annotated[str, Field(description="Filter by organism (e.g., 'Mus musculus', 'Homo sapiens'). Optional.")] = None, - tissue: Annotated[str, Field(description="Filter by tissue keyword (e.g., 'lung', 'brain'). Matches against tissue and tissue_general columns. Optional.")] = None, + organism: Annotated[str, Field(description="Filter by organism (e.g., 'Mus musculus', 'Homo sapiens'). Always pass this when the organism is known.")] = None, + tissue: Annotated[str, Field(description="Filter by tissue keyword (e.g., 'lung', 'brain'). Matches against tissue and tissue_general columns. Always pass this when the tissue is known.")] = None, ) -> str: """Search CZI CELLxGENE Census for reference single-cell datasets. - Input is a string containing: tissue, condition, and organism. + When the query mentions an organism or tissue, pass them as the dedicated + ``organism`` and ``tissue`` parameters so they are used as hard pre-filters + before embedding-based ranking. The ``query`` should contain only the + condition or context (e.g., 'breast cancer', 'normal development'). + + If strict filtering returns fewer than ``n_datasets`` results, the function + progressively relaxes filters (drop tissue first, then organism) and + includes a warning in the output. Returns: str: Formatted string with dataset info. Example output: @@ -206,22 +213,55 @@ def search_czi_datasets( df = pd.read_csv(metadata_path) - # Pre-filter by organism and tissue before embedding search - if organism: - mask = df["organism"] == organism - df_filtered = df[mask] - if len(df_filtered) >= n_datasets: - df = df_filtered - - if tissue: - tissue_lower = tissue.lower() - mask = ( - df["tissue"].str.lower().str.contains(tissue_lower, na=False) | - df["tissue_general"].str.lower().str.contains(tissue_lower, na=False) - ) - df_filtered = df[mask] - if len(df_filtered) >= n_datasets: - df = df_filtered + # Pre-filter by organism and tissue with controlled relaxation + filter_warnings = [] + applied_organism = None + applied_tissue = None + + if organism or tissue: + # Try strict filter: both organism and tissue + df_strict = df + if organism: + df_strict = df_strict[df_strict["organism"] == organism] + if tissue: + tissue_lower = tissue.lower() + df_strict = df_strict[ + df_strict["tissue"].str.lower().str.contains(tissue_lower, na=False) + | df_strict["tissue_general"].str.lower().str.contains(tissue_lower, na=False) + ] + + if len(df_strict) >= n_datasets: + df = df_strict + applied_organism = organism + applied_tissue = tissue + elif organism and tissue: + # Relax: drop tissue, keep organism only + df_org_only = df[df["organism"] == organism] + if len(df_org_only) >= n_datasets: + df = df_org_only + applied_organism = organism + filter_warnings.append( + f"Note: Only {len(df_strict)} datasets matched " + f"organism='{organism}' AND tissue='{tissue}' " + f"(fewer than {n_datasets} requested). " + f"Relaxed to organism='{organism}' only " + f"({len(df_org_only)} datasets)." + ) + else: + filter_warnings.append( + f"Note: Only {len(df_org_only)} datasets matched " + f"organism='{organism}' (fewer than {n_datasets} " + f"requested). Using all {len(df)} datasets." + ) + else: + # Single filter with too few results + filter_name = "organism" if organism else "tissue" + filter_val = organism if organism else tissue + filter_warnings.append( + f"Note: Only {len(df_strict)} datasets matched " + f"{filter_name}='{filter_val}' (fewer than {n_datasets} " + f"requested). Using all {len(df)} datasets." + ) df = df.reset_index(drop=True) @@ -239,12 +279,12 @@ def search_czi_datasets( query_embedding = _embed_with_retry(llm_embed_query, [query]) # Check for cached description embeddings (use effective_model for cache key) - # Include filter params in database identifier to avoid cache collisions + # Include actually-applied filters in identifier to avoid cache collisions db_id = "czi_census" - if organism: - db_id += f"_{organism.replace(' ', '_')}" - if tissue: - db_id += f"_{tissue.lower()}" + if applied_organism: + db_id += f"_{applied_organism.replace(' ', '_')}" + if applied_tissue: + db_id += f"_{applied_tissue.lower()}" cache_key = _get_cache_key(db_id, effective_model, len(descriptions)) desc_embeddings = _load_cached_embeddings(cache_key) @@ -276,6 +316,8 @@ def search_czi_datasets( results.append(result_str) output = "\n".join(results) + if filter_warnings: + output = "\n".join(filter_warnings) + "\n\n" + output return output