
Add HealDA dataloader protocols and init recipe #1555

Open
pzharrington wants to merge 9 commits into NVIDIA:main from pzharrington:healda-data

Conversation

@pzharrington
Collaborator

@pzharrington pzharrington commented Apr 9, 2026

PhysicsNeMo Pull Request

Description

Adds the HealDA data loader system to physicsnemo.experimental.datapipes. The initial focus is on reproducibility, preserving the performance features NVR developed, and establishing clear interfaces for users who want to extend the pipeline with custom data.

Also brings in unit tests for these components, currently living in the recipe folder.

Documentation lives mostly in the recipe README; some of it is copied here for reference:

The physicsnemo.experimental.datapipes.healda package provides a composable data loading pipeline with clear extension points. The architecture separates components into loaders, transforms, datasets, and sampling infrastructure.

Architecture

ObsERA5Dataset(era5_data, obs_loader, transform)
  |  Temporal windowing via FrameIndexGenerator
  |  __getitems__ -> get() per index -> transform.transform()
  v
RestartableDistributedSampler (stateful distributed sampling with checkpointing)
  |
DataLoader (pin_memory, persistent_workers)
  |
prefetch_map(loader, transform.device_transform)
  |
Training loop (GPU-ready batch)
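The staged flow above can be mimicked with plain-Python stand-ins. Everything below is a toy sketch of the data flow only; the functions are placeholders for the real components named in the diagram, not their actual APIs:

```python
# Toy stand-ins for the pipeline stages (illustrative names, not real APIs).
def get(idx):
    """Stand-in for dataset.get(): load one raw frame for a sample index."""
    return {"state": idx * 10}

def transform(frames):
    """Stand-in for the CPU-side Transform.transform() run in DataLoader workers."""
    return [{"state": f["state"] + 1} for f in frames]

def device_transform(batch):
    """Stand-in for the GPU-side DeviceTransform (would move tensors to device)."""
    return [{"state": f["state"], "on_device": True} for f in batch]

def prefetch_map(batches, fn):
    """Minimal stand-in for CUDA-stream prefetching: apply fn to each batch."""
    for b in batches:
        yield fn(b)

# Two batches of two sample indices each, pushed through all stages.
indices = [[0, 1], [2, 3]]
batches = (transform([get(i) for i in idxs]) for idxs in indices)
out = [b for b in prefetch_map(batches, device_transform)]
```

In the real pipeline the middle stages are a torch DataLoader (pin_memory, persistent_workers) and a CUDA-stream prefetcher; the sketch only shows how a sample flows through the stages.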

Key Protocols

Custom data sources and transforms plug in via these protocols
(see physicsnemo.experimental.datapipes.healda.protocols):

ObsLoader — the observation loading interface:

class MyObsLoader:
    async def sel_time(self, times):
        """Return {"obs": [pa.Table, ...]}"""
        ...

Transform / DeviceTransform — two-stage batch
processing:

class MyTransform:
    def transform(self, times, frames):
        """CPU-side: normalize, encode obs, time features."""
        ...

    def device_transform(self, batch, device):
        """GPU-side: move to device, compute obs features."""
        ...
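Because these are structural protocols, user classes need no inheritance; matching method names suffice. A self-contained sketch of the pattern (the real definitions in physicsnemo.experimental.datapipes.healda.protocols may carry more methods and precise type hints than shown here):

```python
from typing import Protocol, runtime_checkable

# Simplified protocol definitions, mirroring the shape described above.
@runtime_checkable
class ObsLoader(Protocol):
    async def sel_time(self, times): ...

@runtime_checkable
class Transform(Protocol):
    def transform(self, times, frames): ...

class MyObsLoader:
    """No inheritance needed: a matching sel_time method is enough."""
    async def sel_time(self, times):
        return {"obs": []}

ok = isinstance(MyObsLoader(), ObsLoader)  # structural match, evaluates True
```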

Provided Implementations

| Component | Module | Description |
| --- | --- | --- |
| ObsERA5Dataset | dataset | ERA5 state + observations |
| UFSUnifiedLoader | loaders.ufs_obs | Parquet obs loader |
| ERA5Loader | loaders.era5 | Async ERA5 zarr loader |
| ERA5ObsTransform | transforms.era5_obs | Two-stage transform |
| RestartableDistributedSampler | samplers | Stateful distributed sampler |
| prefetch_map | prefetch | CUDA stream prefetching |

All modules above are under
physicsnemo.experimental.datapipes.healda.

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@pzharrington pzharrington self-assigned this Apr 9, 2026
@pzharrington
Collaborator Author

pzharrington commented Apr 9, 2026

Testing script comparing against reference loader:

#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compare outputs between the ported data loader and the reference implementation.

Modes:
    smoketest  — 3 sample indices, fast sanity check (~1 min)
    full       — 50 indices spread across train split (~10 min)

Usage:
    # From examples/weather/healda/
    python scripts/compare_loaders.py smoketest
    python scripts/compare_loaders.py full
    python scripts/compare_loaders.py full --indices 0 100 500 1000

Requires:
    - The reference codebase importable (healda-reference/src on PYTHONPATH, or
      the healda package installed).
    - Environment variables from .env (ERA5_74VAR, UFS_OBS_PATH, etc.).
"""

from __future__ import annotations

import argparse
import os
import sys
import time

import numpy as np
import torch

# ---------------------------------------------------------------------------
# Path setup
# ---------------------------------------------------------------------------
RECIPE_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
REFERENCE_ROOT = os.path.join(os.path.dirname(RECIPE_ROOT), "healda-reference")

# Load .env
from dotenv import load_dotenv

load_dotenv(os.path.join(RECIPE_ROOT, ".env"))


# ============================================================================
# Ported loader construction
# ============================================================================


def build_ported_dataset(split="train", sensors=None):
    """Construct ObsERA5Dataset from the ported data/ package."""
    from physicsnemo.experimental.datapipes.healda.configs.variable_configs import VARIABLE_CONFIGS
    from physicsnemo.experimental.datapipes.healda.dataset import ObsERA5Dataset
    from physicsnemo.experimental.datapipes.healda.loaders.ufs_obs import UFSUnifiedLoader
    from physicsnemo.experimental.datapipes.healda.transforms.era5_obs import ERA5ObsTransform

    variable_config = VARIABLE_CONFIGS["era5"]

    if sensors is None:
        sensors = ["atms", "mhs", "amsua", "amsub"]

    obs_path = os.environ["UFS_OBS_PATH"]
    obs_loader = UFSUnifiedLoader(
        data_path=obs_path,
        sensors=sensors,
        normalization="zscore",
        obs_context_hours=(-21, 3),
    )

    transform = ERA5ObsTransform(variable_config=variable_config, sensors=sensors)

    import xarray

    era5_path = os.environ["ERA5_74VAR"]
    era5_ds = xarray.open_zarr(era5_path, chunks=None)
    era5_data = era5_ds["data"]

    dataset = ObsERA5Dataset(
        era5_data=era5_data,
        obs_loader=obs_loader,
        transform=transform,
        variable_config=variable_config,
        split=split,
    )
    return dataset


# ============================================================================
# Reference loader construction
# ============================================================================


def build_reference_dataset(split="train", sensors=None):
    """Construct ObsERA5Dataset from the reference healda-reference codebase.

    Requires healda-reference/src and healda-reference/ on sys.path.
    """
    # Add reference paths
    ref_src = os.path.join(REFERENCE_ROOT, "src")
    ref_private = REFERENCE_ROOT
    for p in [ref_src, ref_private]:
        if p not in sys.path:
            sys.path.insert(0, p)

    import dotenv as _dotenv

    _dotenv.load_dotenv(os.path.join(RECIPE_ROOT, ".env"))

    from healda.config.models import ObsConfig
    from private.fcn3_dataset import ObsERA5Dataset as RefObsERA5Dataset

    # Build obs_config matching default sensors
    use_conv = sensors is not None and "conv" in sensors
    obs_config = ObsConfig(
        use_obs=True,
        innovation_type="none",
        context_start=-21,
        context_end=3,
        use_conv=use_conv,
    )

    dataset = RefObsERA5Dataset(
        split=split,
        time_length=1,
        frame_step=1,
        model_rank=0,
        model_world_size=1,
        obs_config=obs_config,
    )
    return dataset


# ============================================================================
# Comparison logic
# ============================================================================


def compare_single_sample(ported_ds, ref_ds, idx: int, verbose: bool = True):
    """Compare a single sample between ported and reference datasets.

    Returns a dict of comparison results.
    """
    results = {"idx": idx, "pass": True, "errors": []}

    # --- Raw data comparison (before transform) ---
    try:
        t0 = time.time()
        ported_times, ported_objs = ported_ds.get(idx)
        ported_elapsed = time.time() - t0

        t0 = time.time()
        ref_times, ref_objs = ref_ds.get(idx)
        ref_elapsed = time.time() - t0

        results["ported_time_s"] = ported_elapsed
        results["ref_time_s"] = ref_elapsed

    except Exception as e:
        results["pass"] = False
        results["errors"].append(f"Loading failed: {e}")
        return results

    # Compare timestamps
    for i, (pt, rt) in enumerate(zip(ported_times, ref_times)):
        if str(pt) != str(rt):
            results["pass"] = False
            results["errors"].append(f"Time mismatch at frame {i}: {pt} vs {rt}")

    # Compare state arrays
    for i, (po, ro) in enumerate(zip(ported_objs, ref_objs)):
        p_state = po["state"]
        r_state = ro["state"]

        if p_state.shape != r_state.shape:
            results["pass"] = False
            results["errors"].append(
                f"State shape mismatch at frame {i}: {p_state.shape} vs {r_state.shape}"
            )
            continue

        max_diff = np.max(np.abs(p_state - r_state))
        results[f"state_frame{i}_maxdiff"] = float(max_diff)

        if max_diff > 1e-6:
            results["pass"] = False
            results["errors"].append(
                f"State value mismatch at frame {i}: max_diff={max_diff:.2e}"
            )

    # Compare observation tables
    # Note: ported and reference may produce rows in different order (due to
    # platform grouping within parquet row-groups).  This is benign — the
    # downstream transform processes all obs in a window together.  We sort
    # both tables by a canonical key before value comparison.
    import pyarrow.compute as pc

    def _sort_obs_table(table):
        """Sort by (Global_Channel_ID, Latitude, Longitude, Absolute_Obs_Time)
        to produce a deterministic row order for comparison."""
        sort_keys = [
            ("Global_Channel_ID", "ascending"),
            ("Latitude", "ascending"),
            ("Longitude", "ascending"),
            ("Absolute_Obs_Time", "ascending"),
        ]
        indices = pc.sort_indices(table, sort_keys=sort_keys)
        return table.take(indices)

    for i, (po, ro) in enumerate(zip(ported_objs, ref_objs)):
        p_obs = po.get("obs")
        r_obs = ro.get("obs") or ro.get("obs_v2")  # reference uses legacy key

        if p_obs is None and r_obs is None:
            continue
        if (p_obs is None) != (r_obs is None):
            results["pass"] = False
            results["errors"].append(f"Obs presence mismatch at frame {i}")
            continue

        p_nrows = p_obs.num_rows
        r_nrows = r_obs.num_rows
        results[f"obs_frame{i}_nrows_ported"] = p_nrows
        results[f"obs_frame{i}_nrows_ref"] = r_nrows

        if p_nrows != r_nrows:
            results["pass"] = False
            results["errors"].append(
                f"Obs row count mismatch at frame {i}: {p_nrows} vs {r_nrows}"
            )
            continue

        if p_nrows > 0:
            # Compare schemas
            p_cols = set(p_obs.schema.names)
            r_cols = set(r_obs.schema.names)
            if p_cols != r_cols:
                results["pass"] = False
                results["errors"].append(
                    f"Obs schema mismatch at frame {i}: "
                    f"ported_only={p_cols - r_cols}, ref_only={r_cols - p_cols}"
                )
                continue

            # Sort both tables to canonical order before comparison
            p_sorted = _sort_obs_table(p_obs)
            r_sorted = _sort_obs_table(r_obs)

            # Compare observation values
            p_vals = p_sorted["Observation"].to_numpy()
            r_vals = r_sorted["Observation"].to_numpy()
            obs_max_diff = np.nanmax(np.abs(p_vals - r_vals))
            results[f"obs_frame{i}_val_maxdiff"] = float(obs_max_diff)
            if obs_max_diff > 1e-5:
                results["pass"] = False
                results["errors"].append(
                    f"Obs value mismatch at frame {i}: max_diff={obs_max_diff:.2e}"
                )

            # Also verify Global_Channel_ID sets match
            p_gcids = set(p_obs["Global_Channel_ID"].to_pylist())
            r_gcids = set(r_obs["Global_Channel_ID"].to_pylist())
            if p_gcids != r_gcids:
                results["pass"] = False
                results["errors"].append(
                    f"Obs GCID set mismatch at frame {i}: "
                    f"ported_only={p_gcids - r_gcids}, ref_only={r_gcids - p_gcids}"
                )

    if verbose:
        status = "PASS" if results["pass"] else "FAIL"
        timing = (
            f"ported={results.get('ported_time_s', 0):.2f}s "
            f"ref={results.get('ref_time_s', 0):.2f}s"
        )
        print(f"  [{status}] idx={idx:6d}  {timing}")
        for err in results["errors"]:
            print(f"         {err}")

    return results


def compare_transformed_sample(ported_ds, ref_ds, idx: int, verbose: bool = True):
    """Compare transformed (batched) output between ported and reference.

    Uses __getitems__ to exercise the full transform pipeline.
    """
    results = {"idx": idx, "pass": True, "errors": []}

    try:
        t0 = time.time()
        ported_batch = ported_ds.__getitems__([idx])
        ported_elapsed = time.time() - t0

        t0 = time.time()
        ref_batch = ref_ds.__getitems__([idx])
        ref_elapsed = time.time() - t0

        results["ported_transform_s"] = ported_elapsed
        results["ref_transform_s"] = ref_elapsed

    except Exception as e:
        results["pass"] = False
        results["errors"].append(f"Transform failed: {e}")
        if verbose:
            print(f"  [FAIL] idx={idx:6d} Transform error: {e}")
        return results

    # Compare batch dict keys
    p_keys = set(ported_batch.keys())
    r_keys = set(ref_batch.keys())
    if p_keys != r_keys:
        results["errors"].append(
            f"Batch key mismatch: ported_only={p_keys - r_keys}, ref_only={r_keys - p_keys}"
        )
        # Don't fail — extra/missing keys may be intentional

    # Compare tensor fields
    for key in sorted(p_keys & r_keys):
        pv = ported_batch[key]
        rv = ref_batch[key]

        if isinstance(pv, torch.Tensor) and isinstance(rv, torch.Tensor):
            if pv.shape != rv.shape:
                results["pass"] = False
                results["errors"].append(
                    f"Shape mismatch for '{key}': {pv.shape} vs {rv.shape}"
                )
                continue

            if pv.numel() == 0:
                continue
            max_diff = (pv.float() - rv.float()).abs().max().item()
            results[f"{key}_maxdiff"] = max_diff

            # Use loose tolerance for float transforms
            tol = 1e-4 if pv.is_floating_point() else 0
            if max_diff > tol:
                results["pass"] = False
                results["errors"].append(
                    f"Value mismatch for '{key}': max_diff={max_diff:.2e}"
                )

        elif isinstance(pv, tuple) and isinstance(rv, tuple):
            # unified_obs is a tuple (obs_tensors, lengths_3d)
            # Row ordering may differ between ported and reference (benign —
            # within each sensor group, platforms can appear in different order
            # depending on parquet row-group layout).  We sort both by
            # (global_channel_id, latitude, longitude) before comparing values.
            if len(pv) != len(rv):
                results["pass"] = False
                results["errors"].append(
                    f"Tuple length mismatch for '{key}': {len(pv)} vs {len(rv)}"
                )
                continue

            if isinstance(pv[0], dict) and isinstance(rv[0], dict):
                p_obs_keys = set(pv[0].keys())
                r_obs_keys = set(rv[0].keys())
                if p_obs_keys != r_obs_keys:
                    results["errors"].append(
                        f"Obs tensor key mismatch: "
                        f"ported_only={p_obs_keys - r_obs_keys}, "
                        f"ref_only={r_obs_keys - p_obs_keys}"
                    )

                # Build a stable sort index using torch.lexsort-style
                # multi-key sorting: (gcid, abs_time, lat, lon, observation)
                def _sort_idx(obs_dict):
                    gcid = obs_dict.get("global_channel_id")
                    lat = obs_dict.get("latitude")
                    lon = obs_dict.get("longitude")
                    obs_time = obs_dict.get("absolute_obs_time")
                    obs_val = obs_dict.get("observation")
                    if gcid is None or gcid.numel() == 0:
                        return None
                    # Stack columns as (N, K) float64 for lexicographic sort.
                    # torch.lexsort isn't available, so we use numpy.
                    cols = [gcid.double().cpu().numpy()]
                    if obs_time is not None:
                        cols.append(obs_time.double().cpu().numpy())
                    if lat is not None:
                        cols.append(lat.double().cpu().numpy())
                    if lon is not None:
                        cols.append(lon.double().cpu().numpy())
                    if obs_val is not None:
                        cols.append(obs_val.double().cpu().numpy())
                    # np.lexsort sorts by last key first, so reverse
                    order = np.lexsort(cols[::-1])
                    return torch.from_numpy(order).long()

                p_order = _sort_idx(pv[0])
                r_order = _sort_idx(rv[0])

                for obs_key in sorted(p_obs_keys & r_obs_keys):
                    pt = pv[0][obs_key]
                    rt = rv[0][obs_key]
                    if pt.shape != rt.shape:
                        results["pass"] = False
                        results["errors"].append(
                            f"Obs tensor shape mismatch for '{obs_key}': "
                            f"{pt.shape} vs {rt.shape}"
                        )
                    elif pt.numel() > 0:
                        # Apply sort order before comparison
                        ps = pt[p_order] if p_order is not None else pt
                        rs = rt[r_order] if r_order is not None else rt
                        d = (ps.float() - rs.float()).abs().max().item()
                        results[f"obs_{obs_key}_maxdiff"] = d
                        if d > 1e-4:
                            results["pass"] = False
                            results["errors"].append(
                                f"Obs tensor mismatch for '{obs_key}': "
                                f"max_diff={d:.2e}"
                            )

            # Compare lengths_3d (sensor, batch, time) — these count obs per
            # sensor/window and are order-independent as long as sensor_id
            # mapping matches.
            for ti, name in [(1, "lengths")]:
                if ti < len(pv) and ti < len(rv):
                    pt, rt = pv[ti], rv[ti]
                    if isinstance(pt, torch.Tensor) and isinstance(rt, torch.Tensor):
                        if pt.shape != rt.shape:
                            results["pass"] = False
                            results["errors"].append(
                                f"{name} shape mismatch: {pt.shape} vs {rt.shape}"
                            )
                        elif not torch.equal(pt, rt):
                            results["pass"] = False
                            results["errors"].append(f"{name} value mismatch")

    if verbose:
        status = "PASS" if results["pass"] else "FAIL"
        timing = (
            f"ported={results.get('ported_transform_s', 0):.2f}s "
            f"ref={results.get('ref_transform_s', 0):.2f}s"
        )
        print(f"  [{status}] idx={idx:6d}  {timing}")
        for err in results["errors"]:
            print(f"         {err}")

    return results


# ============================================================================
# Main driver
# ============================================================================


def get_indices(mode: str, ds_len: int, custom_indices=None):
    """Return sample indices based on mode."""
    if custom_indices:
        return [i for i in custom_indices if i < ds_len]

    if mode == "smoketest":
        # 3 indices: start, middle, near end
        return [0, ds_len // 2, ds_len - 1]

    elif mode == "full":
        # 50 indices spread across the dataset
        n = min(50, ds_len)
        step = max(1, ds_len // n)
        return list(range(0, ds_len, step))[:n]

    else:
        raise ValueError(f"Unknown mode: {mode}")


def main():
    parser = argparse.ArgumentParser(
        description="Compare ported vs reference data loader outputs."
    )
    parser.add_argument(
        "mode",
        choices=["smoketest", "full"],
        help="smoketest: 3 indices, fast. full: 50 indices.",
    )
    parser.add_argument(
        "--indices",
        type=int,
        nargs="*",
        default=None,
        help="Override indices to compare.",
    )
    parser.add_argument(
        "--split", default="train", help="Dataset split (default: train)."
    )
    parser.add_argument(
        "--sensors",
        nargs="*",
        default=None,
        help="Sensor list (default: atms mhs amsua amsub).",
    )
    parser.add_argument(
        "--transform",
        action="store_true",
        help="Also compare transformed (__getitems__) output.",
    )
    parser.add_argument(
        "--no-raw",
        action="store_true",
        help="Skip raw (get) comparison, only do transform.",
    )
    args = parser.parse_args()

    print("=" * 70)
    print(f"Loader comparison — mode={args.mode}, split={args.split}")
    print("=" * 70)

    # Build datasets
    print("\nBuilding ported dataset...")
    t0 = time.time()
    ported_ds = build_ported_dataset(split=args.split, sensors=args.sensors)
    print(f"  Done in {time.time() - t0:.1f}s  (len={len(ported_ds)})")

    print("Building reference dataset...")
    t0 = time.time()
    ref_ds = build_reference_dataset(split=args.split, sensors=args.sensors)
    print(f"  Done in {time.time() - t0:.1f}s  (len={len(ref_ds)})")

    # Verify lengths match
    if len(ported_ds) != len(ref_ds):
        print(
            f"\nWARNING: Dataset lengths differ! "
            f"ported={len(ported_ds)} vs ref={len(ref_ds)}"
        )

    indices = get_indices(
        args.mode, min(len(ported_ds), len(ref_ds)), args.indices
    )
    print(f"\nComparing {len(indices)} samples: {indices[:10]}{'...' if len(indices) > 10 else ''}")

    # --- Raw comparison ---
    if not args.no_raw:
        print(f"\n--- Raw comparison (get) ---")
        raw_results = []
        for idx in indices:
            r = compare_single_sample(ported_ds, ref_ds, idx)
            raw_results.append(r)

        n_pass = sum(1 for r in raw_results if r["pass"])
        n_fail = len(raw_results) - n_pass
        print(f"\nRaw: {n_pass}/{len(raw_results)} passed, {n_fail} failed")

    # --- Transform comparison ---
    if args.transform:
        print(f"\n--- Transform comparison (__getitems__) ---")
        xform_results = []
        for idx in indices:
            r = compare_transformed_sample(ported_ds, ref_ds, idx)
            xform_results.append(r)

        n_pass = sum(1 for r in xform_results if r["pass"])
        n_fail = len(xform_results) - n_pass
        print(f"\nTransform: {n_pass}/{len(xform_results)} passed, {n_fail} failed")

    # --- Summary ---
    print("\n" + "=" * 70)
    all_results = []
    if not args.no_raw:
        all_results.extend(raw_results)
    if args.transform:
        all_results.extend(xform_results)

    n_total = len(all_results)
    n_pass = sum(1 for r in all_results if r["pass"])
    if n_total == n_pass:
        print(f"ALL {n_total} CHECKS PASSED")
    else:
        print(f"{n_total - n_pass}/{n_total} CHECKS FAILED")
        sys.exit(1)


if __name__ == "__main__":
    main()

Comment thread physicsnemo/experimental/datapipes/healda/prefetch.py
Comment thread physicsnemo/experimental/datapipes/healda/configs/variable_configs.py Outdated
Comment thread examples/weather/healda/requirements.txt Outdated
Contributor

@aayushg55 aayushg55 left a comment


Took a final pass. Looks good; I should be able to extend the obs transform easily. Just need to update the requirements and rename the platform field.

Comment thread physicsnemo/experimental/datapipes/healda/types.py Outdated
@pzharrington pzharrington marked this pull request as ready for review April 17, 2026 00:12
@greptile-apps
Contributor

greptile-apps bot commented Apr 17, 2026

Greptile Summary

This PR introduces the physicsnemo.experimental.datapipes.healda package — a composable ERA5 + satellite/conventional observation data pipeline for HealDA training, including async zarr/parquet loaders, two-stage CPU/GPU transforms, a stateful distributed sampler, and CUDA-stream prefetching.

  • P1 — split_array_contiguous crashes on single-element arrays (indexing.py:44): d = x[1] - x[0] raises IndexError whenever the time array has exactly one element; only size == 0 is guarded.
  • P1 — Multi-window row-group data loss in _iterate_parquet_da_windows (ufs_obs.py:207–229): when a parquet row group spans multiple DA windows, this_window is overwritten for each match and only the last one is yielded; observations from earlier windows end up stored under the wrong time key or silently dropped.

Important Files Changed

| Filename | Overview |
| --- | --- |
| physicsnemo/experimental/datapipes/healda/indexing.py | New temporal indexing module. Two bugs: split_array_contiguous crashes with IndexError on single-element arrays; _map_logical_to_physical uses total_samples instead of valid_length for the bounds check. |
| physicsnemo/experimental/datapipes/healda/loaders/ufs_obs.py | New observation loader. Three issues: (1) _iterate_parquet_da_windows silently drops/misattributes observations when a parquet row group spans multiple DA windows; (2) channel_table local_channel_id computation uses Python wrap-around at i=0; (3) fixed_range normalization hardcodes [0, 400] instead of per-channel min_valid/max_valid. |
| physicsnemo/experimental/datapipes/healda/samplers.py | New stateful distributed sampler. Rank-specific RNG seeds produce independent (not partitioned) permutations per rank, so uniform dataset coverage per epoch is not guaranteed; this is intentional but undocumented. |
| physicsnemo/experimental/datapipes/healda/dataset.py | New map-style dataset combining ERA5 state with observations; clean implementation with good docstrings and test coverage. |
| physicsnemo/experimental/datapipes/healda/prefetch.py | New background CUDA prefetch iterator; well-structured with proper error propagation. Minor: _stop() may leave the worker thread stuck on queue.put() when the queue is full and the consumer is gone, but daemon=True limits the blast radius. |
| physicsnemo/experimental/datapipes/healda/protocols.py | Defines ObsLoader, Transform, and DeviceTransform protocols; well-documented and correctly uses runtime_checkable. |
| physicsnemo/experimental/datapipes/healda/loaders/era5.py | ERA5 zarr loader with normalization stats; adds renamed keys to data without removing originals (harmless since _collect_fields uses index-based lookup), but slightly wasteful. |
| physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py | Two-stage CPU+GPU transform; correctly separates DataLoader-worker CPU processing from CUDA-stream GPU featurization. |

Reviews (1): Last reviewed commit: "Revert precommit change"

Comment on lines +43 to +45
    segments = []
    start = 0
    for i in range(1, x.size):
Contributor


P1 IndexError for single-element time array

d = x[1] - x[0] is evaluated before the loop for every non-empty array, so passing an array with exactly one element (x.size == 1) raises IndexError: index 1 is out of bounds for axis 0 with size 1. Only x.size == 0 is guarded. Any dataset constructed with a single timestamp (common in unit tests or small experiments) will crash here.

Suggested change:

-    segments = []
-    start = 0
-    for i in range(1, x.size):
+    if x.size <= 1:
+        return [x] if x.size == 1 else []
+    d = x[1] - x[0]
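As a self-contained sketch, here is the guarded function with its behavior reconstructed from the review context (the real signature and body in indexing.py may differ):

```python
import numpy as np

def split_array_contiguous(x: np.ndarray) -> list:
    """Split a sorted 1-D array into runs of uniform spacing d = x[1] - x[0].

    Illustrative reconstruction; the guard below makes empty and
    single-element inputs return early instead of raising IndexError.
    """
    if x.size <= 1:
        return [x] if x.size == 1 else []
    d = x[1] - x[0]
    segments = []
    start = 0
    for i in range(1, x.size):
        if x[i] - x[i - 1] != d:  # spacing breaks: close the current run
            segments.append(x[start:i])
            start = i
    segments.append(x[start:])  # close the final run
    return segments
```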

Comment on lines +207 to +229

this_window = None
for w in target_windows:
    if row_group_lo <= w <= row_group_hi:
        this_window = w

if this_window is None:
    continue

table = parquet.read_row_group(
    row_group_idx, columns=self._read_columns
)

if row_group_lo != row_group_hi:
    mask = pc.is_in(
        table["DA_window"], pa.array(list(target_windows))
    )
    table = table.filter(mask)

if table.num_rows == 0:
    continue

yield this_window, table
Contributor


P1 Multi-window row groups silently merge under wrong key

When row_group_lo != row_group_hi, multiple target windows may satisfy row_group_lo <= w <= row_group_hi. The inner loop overwrites this_window on each match, so only the last matching window is used. The table is filtered to keep rows from all matching windows, then yielded once under that last-window label.

The downstream tables.setdefault(interval_time, []).append(table) in sel_time therefore stores a multi-window table under only one key; the other matching windows get [] when looked up in process(t). Observations from earlier DA windows end up attributed to the wrong time step and are dropped entirely for their actual slot.

A fix is to yield one sub-table per matching window:

for w in target_windows:
    if row_group_lo <= w <= row_group_hi:
        if row_group_lo != row_group_hi:
            mask = pc.equal(table["DA_window"], w)
            sub = table.filter(mask)
        else:
            sub = table
        if sub.num_rows > 0:
            yield w, sub

Comment on lines +163 to +167
offset = 0
for i in range(len(sensor_id)):
    if sensor_id[i] != sensor_id[i - 1]:
        offset = i
    local_channel_ids.append(i - offset)
Contributor


P2 Off-by-one via Python negative index at i = 0

At the first loop iteration (i = 0), sensor_id[i - 1] evaluates to sensor_id[-1] — the last element of the array — due to Python's wrap-around indexing. If the first and last rows share the same sensor_id, offset is never reset, and the first channel gets a local_channel_id offset from 0 instead of 0.

The intent is to assign a 0-based local ID within each contiguous sensor block. A clearer, bug-free approach:

Suggested change
offset = 0
for i in range(len(sensor_id)):
    if sensor_id[i] != sensor_id[i - 1]:
        offset = i
    local_channel_ids.append(i - offset)
for i in range(len(sensor_id)):
    if i == 0 or sensor_id[i] != sensor_id[i - 1]:
        offset = i
    local_channel_ids.append(i - offset)
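To illustrate the intent, here is the suggested loop run on a toy sensor_id array (the values are hypothetical; any contiguous blocks work the same way):

```python
# Contiguous sensor blocks should each get 0-based local channel IDs.
sensor_id = ["amsua", "amsua", "amsua", "atms", "atms"]

local_channel_ids = []
offset = 0
for i in range(len(sensor_id)):
    # Reset the block offset at the start of the array and at every sensor change.
    if i == 0 or sensor_id[i] != sensor_id[i - 1]:
        offset = i
    local_channel_ids.append(i - offset)

# local_channel_ids == [0, 1, 2, 0, 1]
```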


Technically this can never lead to a bug, since offset is initialized to 0, so either way offset is 0 at i = 0. But I agree this would better reflect the intent.

Comment on lines +241 to +244
def _normalize_observations(self, table: pa.Table) -> pa.Table:
    if self.normalization == "fixed_range":
        normalized = pc.divide(pc.subtract(table["Observation"], 0), 400 - 0)
    elif self.normalization == "zscore":

P2 fixed_range normalization ignores per-channel min_valid/max_valid

The fixed_range branch hard-codes 0 and 400 as global min/max:

normalized = pc.divide(pc.subtract(table["Observation"], 0), 400 - 0)

The channel table already supplies min_valid and max_valid per channel (loaded in _add_channel_metadata), but they are not used here. This likely produces incorrectly scaled observations for any channel whose valid range differs from [0, 400], which is almost every instrument (e.g. AMSUA channel 1 mean ≈ 208 K, see amsua_normalizations.csv).


We only use zscore normalization; I don't think it would ever make sense to use anything else. Probably best to remove the deprecated code.


Agreed. Here's what to clean up:

  1. Remove the normalization parameter from __init__ (and its docstring entry)
  2. Drop self.normalization assignment
  3. Simplify _normalize_observations to always use zscore
  4. Update get_channel_table to drop the now-unnecessary normalization="zscore" kwarg

The simplified method would look like:

Suggested change
def _normalize_observations(self, table: pa.Table) -> pa.Table:
    if self.normalization == "fixed_range":
        normalized = pc.divide(pc.subtract(table["Observation"], 0), 400 - 0)
    elif self.normalization == "zscore":
def _normalize_observations(self, table: pa.Table) -> pa.Table:
    normalized = pc.divide(
        pc.subtract(table["Observation"], table["mean"]),
        table["stddev"],
    )

And remove this line from the __init__ signature:

        normalization: Literal["fixed_range", "zscore"] = "fixed_range",

Note: the default was "fixed_range", so any existing callers not explicitly passing normalization="zscore" were silently using the broken path. Removing the param entirely is safe since zscore is the only correct behavior.


Comment on lines +66 to +68
self.iteration = 0
rng = torch.Generator().manual_seed(self.seed + self.epoch + self.rank)
permutation = torch.randperm(self.len, generator=rng)

P2 Rank-specific seeds break uniform dataset coverage

Each rank seeds its RNG with self.seed + self.epoch + self.rank, so rank 0 and rank 1 generate different random permutations. Rank 0 then takes indices [0, num_replicas, 2*num_replicas, …] of its permutation, while rank 1 does the same from a different permutation. The combined set of indices seen by all ranks in one epoch is not a partition of [0, dataset_size) — samples may be duplicated across ranks or missed entirely.

The standard approach (PyTorch's DistributedSampler) uses the same seed on all ranks so the permutation is identical, then partitions it by rank. This PR intentionally diverges from that pattern (as evidenced by the test_multi_replica_independent test), which means per-epoch sample coverage is not guaranteed. This is an intentional design choice but should be documented explicitly in the class docstring so users understand the trade-off.


Good catch. We assign tasks per rank via self.permutation = permutation[self.rank :: self.num_replicas], with the intention that all ranks start from the same permutation. So we should drop the + self.rank.

However, the interpretation of test_multi_replica_independent is wrong; it simply covers that different ranks receive different tasks.
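The shared-seed pattern can be sketched as follows (a toy stand-in, not the sampler in this PR; note the seed omits the rank so every rank builds the identical permutation before partitioning):

```python
import torch

def rank_indices(seed, epoch, rank, num_replicas, n):
    # Same seed on every rank -> identical permutation on every rank.
    rng = torch.Generator().manual_seed(seed + epoch)
    perm = torch.randperm(n, generator=rng)
    # Partition the shared permutation by rank.
    return perm[rank::num_replicas].tolist()

n, num_replicas = 10, 2
parts = [rank_indices(0, 0, r, num_replicas, n) for r in range(num_replicas)]

# The ranks' index sets are disjoint and together cover [0, n): no sample
# is duplicated across ranks or missed in an epoch.
assert sorted(parts[0] + parts[1]) == list(range(n))
```

With the current per-rank seed (seed + epoch + rank), each rank slices a different permutation and this coverage guarantee no longer holds.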

Comment on lines +142 to +149
)

segment_idx = 0
for i, cum_size in enumerate(self.cumulative_valid_sizes[1:], 1):
    if logical_idx < cum_size:
        segment_idx = i - 1
        break


P2 Bounds check uses total_samples instead of valid_length

if logical_idx >= self.total_samples:
    raise IndexError(...)

total_samples is the sum of all segment sizes (including frames that can't start a valid window), while logical_idx is always expected to be in [0, valid_length). A value in [valid_length, total_samples) passes this guard silently and maps to a physical frame index that lies inside the last segment but beyond the last valid window, potentially yielding out-of-bounds xarray .isel() coordinates.

Suggested change
)

segment_idx = 0
for i, cum_size in enumerate(self.cumulative_valid_sizes[1:], 1):
    if logical_idx < cum_size:
        segment_idx = i - 1
        break
if logical_idx >= self.valid_length:
    raise IndexError(
        f"Sample index {logical_idx} out of bounds "
        f"for {self.valid_length} valid samples"
    )


Yes, we should change the bounds check to compare against self.valid_length instead.
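The corrected lookup can be sketched end to end with toy segment sizes (names mirror the snippet above, but the numbers are hypothetical; bisect replaces the linear scan purely for brevity):

```python
import bisect

# Toy FrameIndexGenerator-style bookkeeping: valid window starts per segment.
valid_sizes = [5, 3, 4]

# Prefix sums: cumulative[i] is the first logical index of segment i.
cumulative = [0]
for s in valid_sizes:
    cumulative.append(cumulative[-1] + s)
valid_length = cumulative[-1]  # 12 valid samples in total

def locate(logical_idx):
    # Guard against indices in [valid_length, total_samples), which the
    # original total_samples check would have let through silently.
    if logical_idx >= valid_length:
        raise IndexError(
            f"Sample index {logical_idx} out of bounds "
            f"for {valid_length} valid samples"
        )
    segment_idx = bisect.bisect_right(cumulative, logical_idx) - 1
    return segment_idx, logical_idx - cumulative[segment_idx]

# locate(6) -> (1, 1): second segment, local offset 1.
# locate(12) raises IndexError instead of mapping past the last valid window.
```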
