Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 67 additions & 25 deletions spatialagent/tool/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,21 @@ def search_panglao(

@tool
def search_czi_datasets(
query: Annotated[str, Field(description="Query describing tissue, condition, and organism (e.g., 'liver, breast cancer, Homo sapiens')")],
query: Annotated[str, Field(description="Query describing the condition or context (e.g., 'breast cancer', 'normal development'). Do NOT include organism or tissue here — use the dedicated parameters instead.")],
n_datasets: Annotated[int, Field(description="Number of top datasets to return")] = 1,
organism: Annotated[str, Field(description="Filter by organism (e.g., 'Mus musculus', 'Homo sapiens'). Optional.")] = None,
tissue: Annotated[str, Field(description="Filter by tissue keyword (e.g., 'lung', 'brain'). Matches against tissue and tissue_general columns. Optional.")] = None,
organism: Annotated[str, Field(description="Filter by organism (e.g., 'Mus musculus', 'Homo sapiens'). Always pass this when the organism is known.")] = None,
tissue: Annotated[str, Field(description="Filter by tissue keyword (e.g., 'lung', 'brain'). Matches against tissue and tissue_general columns. Always pass this when the tissue is known.")] = None,
) -> str:
"""Search CZI CELLxGENE Census for reference single-cell datasets.

Input is a string containing: tissue, condition, and organism.
When the query mentions an organism or tissue, pass them as the dedicated
``organism`` and ``tissue`` parameters so they are used as hard pre-filters
before embedding-based ranking. The ``query`` should contain only the
condition or context (e.g., 'breast cancer', 'normal development').

If strict filtering returns fewer than ``n_datasets`` results, the function
progressively relaxes filters (drop tissue first, then organism) and
includes a warning in the output.

Returns:
str: Formatted string with dataset info. Example output:
Expand Down Expand Up @@ -206,22 +213,55 @@ def search_czi_datasets(

df = pd.read_csv(metadata_path)

# Pre-filter by organism and tissue before embedding search
if organism:
mask = df["organism"] == organism
df_filtered = df[mask]
if len(df_filtered) >= n_datasets:
df = df_filtered

if tissue:
tissue_lower = tissue.lower()
mask = (
df["tissue"].str.lower().str.contains(tissue_lower, na=False) |
df["tissue_general"].str.lower().str.contains(tissue_lower, na=False)
)
df_filtered = df[mask]
if len(df_filtered) >= n_datasets:
df = df_filtered
# Pre-filter by organism and tissue with controlled relaxation
filter_warnings = []
applied_organism = None
applied_tissue = None

if organism or tissue:
# Try strict filter: both organism and tissue
df_strict = df
if organism:
df_strict = df_strict[df_strict["organism"] == organism]
if tissue:
tissue_lower = tissue.lower()
df_strict = df_strict[
df_strict["tissue"].str.lower().str.contains(tissue_lower, na=False)
| df_strict["tissue_general"].str.lower().str.contains(tissue_lower, na=False)
]

if len(df_strict) >= n_datasets:
df = df_strict
applied_organism = organism
applied_tissue = tissue
elif organism and tissue:
# Relax: drop tissue, keep organism only
df_org_only = df[df["organism"] == organism]
if len(df_org_only) >= n_datasets:
df = df_org_only
applied_organism = organism
filter_warnings.append(
f"Note: Only {len(df_strict)} datasets matched "
f"organism='{organism}' AND tissue='{tissue}' "
f"(fewer than {n_datasets} requested). "
f"Relaxed to organism='{organism}' only "
f"({len(df_org_only)} datasets)."
)
else:
filter_warnings.append(
f"Note: Only {len(df_org_only)} datasets matched "
f"organism='{organism}' (fewer than {n_datasets} "
f"requested). Using all {len(df)} datasets."
)
else:
# Single filter with too few results
filter_name = "organism" if organism else "tissue"
filter_val = organism if organism else tissue
filter_warnings.append(
f"Note: Only {len(df_strict)} datasets matched "
f"{filter_name}='{filter_val}' (fewer than {n_datasets} "
f"requested). Using all {len(df)} datasets."
)

df = df.reset_index(drop=True)

Expand All @@ -239,12 +279,12 @@ def search_czi_datasets(
query_embedding = _embed_with_retry(llm_embed_query, [query])

# Check for cached description embeddings (use effective_model for cache key)
# Include filter params in database identifier to avoid cache collisions
# Include actually-applied filters in identifier to avoid cache collisions
db_id = "czi_census"
if organism:
db_id += f"_{organism.replace(' ', '_')}"
if tissue:
db_id += f"_{tissue.lower()}"
if applied_organism:
db_id += f"_{applied_organism.replace(' ', '_')}"
if applied_tissue:
db_id += f"_{applied_tissue.lower()}"
cache_key = _get_cache_key(db_id, effective_model, len(descriptions))
desc_embeddings = _load_cached_embeddings(cache_key)

Expand Down Expand Up @@ -276,6 +316,8 @@ def search_czi_datasets(
results.append(result_str)

output = "\n".join(results)
if filter_warnings:
output = "\n".join(filter_warnings) + "\n\n" + output
return output


Expand Down