Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions malariagen_data/anoph/ld.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from typing import Optional

import numpy as np

import allel # type: ignore
import xarray as xr
from numpydoc_decorator import doc # type: ignore

from ..util import _check_types, _dask_compress_dataset
from . import base_params, ld_params, pca_params
from .snp_data import AnophelesSnpData


class AnophelesLdAnalysis(
    AnophelesSnpData,
):
    def __init__(
        self,
        **kwargs,
    ):
        # This class participates in cooperative multiple inheritance,
        # so all unconsumed keyword arguments must flow through to the
        # next class in the MRO.
        super().__init__(**kwargs)

    @_check_types
    @doc(
        summary="""
            Access biallelic SNP calls after LD pruning.
        """,
        extended_summary="""
            This function obtains biallelic SNP calls, then performs LD pruning
            using scikit-allel's `locate_unlinked` function. The resulting dataset
            can be used as input to ADMIXTURE workflows or exported to PLINK format.

            LD pruning is controlled by three parameters:

            - `ld_window_size`: number of SNPs in the sliding window used to
              compute pairwise r-squared.
            - `ld_window_step`: number of SNPs to advance the window each
              iteration.
            - `ld_threshold`: maximum r-squared value; SNP pairs above this
              are considered linked and one will be removed.

            Note that `n_snps` is required to control memory usage. Without
            pre-thinning, LD pruning could attempt to materialise millions of
            variants and run out of memory.
        """,
        returns="""
            A dataset of LD-pruned biallelic SNP calls with the same structure as
            the output of `biallelic_snp_calls`.
        """,
    )
    def biallelic_snp_calls_ld_pruned(
        self,
        region: base_params.regions,
        n_snps: base_params.n_snps,
        ld_window_size: ld_params.ld_window_size = ld_params.ld_window_size_default,
        ld_window_step: ld_params.ld_window_step = ld_params.ld_window_step_default,
        ld_threshold: ld_params.ld_threshold = ld_params.ld_threshold_default,
        thin_offset: base_params.thin_offset = 0,
        sample_sets: Optional[base_params.sample_sets] = None,
        sample_query: Optional[base_params.sample_query] = None,
        sample_query_options: Optional[base_params.sample_query_options] = None,
        sample_indices: Optional[base_params.sample_indices] = None,
        site_mask: Optional[base_params.site_mask] = base_params.DEFAULT,
        min_minor_ac: Optional[
            base_params.min_minor_ac
        ] = pca_params.min_minor_ac_default,
        max_missing_an: Optional[
            base_params.max_missing_an
        ] = pca_params.max_missing_an_default,
        random_seed: base_params.random_seed = 42,
        inline_array: base_params.inline_array = base_params.inline_array_default,
        chunks: base_params.chunks = base_params.native_chunks,
    ) -> xr.Dataset:
        # Only one of sample_query / sample_indices may be given.
        base_params._validate_sample_selection_params(
            sample_query=sample_query, sample_indices=sample_indices
        )

        # Sanity-check the LD pruning parameters up front, before any
        # (potentially expensive) data access.
        for param_name, param_value in (
            ("ld_window_size", ld_window_size),
            ("ld_window_step", ld_window_step),
        ):
            if param_value <= 0:
                raise ValueError(f"{param_name} must be > 0, got {param_value}")
        if not (0 < ld_threshold <= 1):
            raise ValueError(f"ld_threshold must be in (0, 1], got {ld_threshold}")

        # Start from thinned biallelic SNP calls. Thinning via n_snps
        # bounds the amount of data materialised below.
        ds_bi = self.biallelic_snp_calls(
            region=region,
            sample_sets=sample_sets,
            sample_query=sample_query,
            sample_query_options=sample_query_options,
            sample_indices=sample_indices,
            site_mask=site_mask,
            min_minor_ac=min_minor_ac,
            max_missing_an=max_missing_an,
            n_snps=n_snps,
            thin_offset=thin_offset,
            random_seed=random_seed,
            inline_array=inline_array,
            chunks=chunks,
        )

        # Convert genotypes to per-call counts of the reference allele,
        # materialised into memory (missing calls filled with -127).
        with self._dask_progress(desc="Computing genotype ref counts"):
            gn_ref = (
                allel.GenotypeDaskArray(ds_bi["call_genotype"].data)
                .to_n_ref(fill=-127)
                .compute()
            )

        # Locate variants that survive LD pruning.
        with self._spinner(desc="LD pruning"):
            unlinked_mask = allel.locate_unlinked(
                gn_ref,
                size=ld_window_size,
                step=ld_window_step,
                threshold=ld_threshold,
            )

        # Refuse to return an empty dataset.
        if not np.any(unlinked_mask):
            raise ValueError(
                "LD pruning removed all variants. Consider using a less "
                "stringent ld_threshold or providing more variants via n_snps."
            )

        # Compress the dataset along the variants dimension, keeping only
        # the unlinked variants.
        return _dask_compress_dataset(ds_bi, indexer=unlinked_mask, dim="variants")
25 changes: 25 additions & 0 deletions malariagen_data/anoph/ld_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Parameters for LD pruning functions."""

from typing_extensions import Annotated, TypeAlias

ld_window_size: TypeAlias = Annotated[
int,
"Window size in number of SNPs for LD pruning.",
]

ld_window_size_default: ld_window_size = 500

ld_window_step: TypeAlias = Annotated[
int,
"Step size in number of SNPs for LD pruning.",
]

ld_window_step_default: ld_window_step = 200

ld_threshold: TypeAlias = Annotated[
float,
"r-squared threshold for LD pruning. SNP pairs with r-squared above "
"this threshold will be considered linked.",
]

ld_threshold_default: ld_threshold = 0.1
2 changes: 2 additions & 0 deletions malariagen_data/anopheles.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .anoph.sample_metadata import AnophelesSampleMetadata
from .anoph.snp_data import AnophelesSnpData
from .anoph.to_plink import PlinkConverter
from .anoph.ld import AnophelesLdAnalysis
from .anoph.to_vcf import SnpVcfExporter
from .anoph.g123 import AnophelesG123Analysis
from .anoph.fst import AnophelesFstAnalysis
Expand Down Expand Up @@ -88,6 +89,7 @@ class AnophelesDataResource(
AnophelesDistanceAnalysis,
AnophelesPca,
PlinkConverter,
AnophelesLdAnalysis,
SnpVcfExporter,
AnophelesIgv,
AnophelesKaryotypeAnalysis,
Expand Down
216 changes: 216 additions & 0 deletions notebooks/ld_pruning.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
"metadata": {},
"source": [
"# LD-pruned biallelic SNP calls\n",
"\n",
"LD pruning removes redundant SNPs that are in linkage disequilibrium, reducing correlation between variants. This is important for analyses like PCA and ADMIXTURE where independent markers are assumed.\n",
"\n",
"`biallelic_snp_calls_ld_pruned()` wraps `biallelic_snp_calls()` and applies LD pruning via `scikit-allel`'s `locate_unlinked()`. The output retains the same dataset structure and is compatible with existing downstream methods."
]
},
{
"cell_type": "markdown",
"id": "b2c3d4e5-f6a7-8901-bcde-f12345678901",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3d4e5f6-a7b8-9012-cdef-123456789012",
"metadata": {},
"outputs": [],
"source": [
"import malariagen_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4e5f6a7-b8c9-0123-defa-234567890123",
"metadata": {},
"outputs": [],
"source": [
"ag3 = malariagen_data.Ag3(\n",
" \"simplecache::gs://vo_agam_release_master_us_central1\",\n",
" simplecache=dict(cache_storage=\"../gcs_cache\"),\n",
" results_cache=\"results_cache\",\n",
")\n",
"ag3"
]
},
{
"cell_type": "markdown",
"id": "e5f6a7b8-c9d0-1234-efab-345678901234",
"metadata": {},
"source": [
"## Basic usage"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6a7b8c9-d0e1-2345-fabc-456789012345",
"metadata": {},
"outputs": [],
"source": [
"ds_pruned = ag3.biallelic_snp_calls_ld_pruned(\n",
" region=\"3L\",\n",
" n_snps=100_000,\n",
" sample_sets=\"AG1000G-BF-A\",\n",
")\n",
"ds_pruned"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a7b8c9d0-e1f2-3456-abcd-567890123456",
"metadata": {},
"outputs": [],
"source": [
"# Inspect dimensions.\n",
"print(f\"variants: {ds_pruned.sizes['variants']}\")\n",
"print(f\"samples: {ds_pruned.sizes['samples']}\")\n",
"print(f\"alleles: {ds_pruned.sizes['alleles']}\")"
]
},
{
"cell_type": "markdown",
"id": "b8c9d0e1-f2a3-4567-bcde-678901234567",
"metadata": {},
"source": [
"## Effect of LD parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9d0e1f2-a3b4-5678-cdef-789012345678",
"metadata": {},
"outputs": [],
"source": [
"# Stricter threshold (0.1, default) removes more correlated SNPs.\n",
"ds_strict = ag3.biallelic_snp_calls_ld_pruned(\n",
" region=\"3L\",\n",
" n_snps=100_000,\n",
" sample_sets=\"AG1000G-BF-A\",\n",
" ld_threshold=0.1,\n",
")\n",
"\n",
"# More lenient threshold retains more SNPs.\n",
"ds_lenient = ag3.biallelic_snp_calls_ld_pruned(\n",
" region=\"3L\",\n",
" n_snps=100_000,\n",
" sample_sets=\"AG1000G-BF-A\",\n",
" ld_threshold=0.5,\n",
")\n",
"\n",
"print(f\"strict (threshold=0.1): {ds_strict.sizes['variants']} variants\")\n",
"print(f\"lenient (threshold=0.5): {ds_lenient.sizes['variants']} variants\")"
]
},
{
"cell_type": "markdown",
"id": "d0e1f2a3-b4c5-6789-defa-890123456789",
"metadata": {},
"source": [
"## Before vs after pruning"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1f2a3b4-c5d6-7890-efab-901234567890",
"metadata": {},
"outputs": [],
"source": [
"# Get the thinned (but not LD-pruned) SNPs for comparison.\n",
"ds_before = ag3.biallelic_snp_calls(\n",
" region=\"3L\",\n",
" n_snps=100_000,\n",
" sample_sets=\"AG1000G-BF-A\",\n",
")\n",
"\n",
"print(f\"before LD pruning: {ds_before.sizes['variants']} variants\")\n",
"print(f\"after LD pruning: {ds_pruned.sizes['variants']} variants\")\n",
"print(f\"removed: {ds_before.sizes['variants'] - ds_pruned.sizes['variants']} variants\")"
]
},
{
"cell_type": "markdown",
"id": "f2a3b4c5-d6e7-8901-fabc-012345678901",
"metadata": {},
"source": [
"## Downstream compatibility\n",
"\n",
"The pruned dataset retains the same structure as `biallelic_snp_calls()` output, so it can be passed directly into existing workflows like PCA."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a3b4c5d6-e7f8-9012-abcd-123456789abc",
"metadata": {},
"outputs": [],
"source": [
"# Verify the pruned dataset has all variables expected by downstream methods.\n",
"assert \"call_genotype\" in ds_pruned\n",
"assert \"variant_allele\" in ds_pruned\n",
"assert \"variant_contig\" in ds_pruned.coords\n",
"assert \"variant_position\" in ds_pruned.coords\n",
"assert \"sample_id\" in ds_pruned.coords\n",
"\n",
"# Shape sanity check.\n",
"n_variants = ds_pruned.sizes[\"variants\"]\n",
"n_samples = ds_pruned.sizes[\"samples\"]\n",
"assert ds_pruned[\"call_genotype\"].shape == (n_variants, n_samples, 2)\n",
"assert ds_pruned[\"variant_allele\"].shape == (n_variants, 2)\n",
"\n",
"print(f\"Dataset is valid: {n_variants} variants × {n_samples} samples\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4c5d6e7-f8a9-0123-bcde-234567890bcd",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading
Loading