Add HealDA dataloader protocols and init recipe#1555
Add HealDA dataloader protocols and init recipe #1555 — pzharrington wants to merge 9 commits into NVIDIA:main from
Conversation
|
Testing script comparing against reference loader: Click to expand code#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compare outputs between the ported data loader and the reference implementation.
Modes:
smoketest — 3 sample indices, fast sanity check (~1 min)
full — 50 indices spread across train split (~10 min)
Usage:
# From examples/weather/healda/
python scripts/compare_loaders.py smoketest
python scripts/compare_loaders.py full
python scripts/compare_loaders.py full --indices 0 100 500 1000
Requires:
- The reference codebase importable (healda-reference/src on PYTHONPATH, or
the healda package installed).
- Environment variables from .env (ERA5_74VAR, UFS_OBS_PATH, etc.).
"""
from __future__ import annotations
import argparse
import os
import sys
import time
import numpy as np
import pandas as pd
import torch
# ---------------------------------------------------------------------------
# Path setup
# ---------------------------------------------------------------------------
# Recipe root is the parent of scripts/ (this file lives in scripts/).
RECIPE_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# The reference codebase is expected as a sibling directory of the recipe.
REFERENCE_ROOT = os.path.join(os.path.dirname(RECIPE_ROOT), "healda-reference")
# Load .env (provides ERA5_74VAR, UFS_OBS_PATH, etc. — see module docstring).
from dotenv import load_dotenv
load_dotenv(os.path.join(RECIPE_ROOT, ".env"))
# ============================================================================
# Ported loader construction
# ============================================================================
def build_ported_dataset(split="train", sensors=None):
    """Construct ObsERA5Dataset from the ported data/ package.

    Args:
        split: Dataset split name (default "train").
        sensors: Optional sensor list; defaults to atms/mhs/amsua/amsub.

    Returns:
        A configured ``ObsERA5Dataset`` backed by the ported loaders.
    """
    import xarray

    from physicsnemo.experimental.datapipes.healda.configs.variable_configs import VARIABLE_CONFIGS
    from physicsnemo.experimental.datapipes.healda.dataset import ObsERA5Dataset
    from physicsnemo.experimental.datapipes.healda.loaders.ufs_obs import UFSUnifiedLoader
    from physicsnemo.experimental.datapipes.healda.transforms.era5_obs import ERA5ObsTransform

    if sensors is None:
        sensors = ["atms", "mhs", "amsua", "amsub"]

    # Observation loader reads the unified UFS parquet store; the context
    # window matches the reference configuration (-21h .. +3h).
    obs_loader = UFSUnifiedLoader(
        data_path=os.environ["UFS_OBS_PATH"],
        sensors=sensors,
        normalization="zscore",
        obs_context_hours=(-21, 3),
    )

    variable_config = VARIABLE_CONFIGS["era5"]
    transform = ERA5ObsTransform(variable_config=variable_config, sensors=sensors)

    # chunks=None loads without dask chunking, matching the reference access.
    era5_data = xarray.open_zarr(os.environ["ERA5_74VAR"], chunks=None)["data"]

    return ObsERA5Dataset(
        era5_data=era5_data,
        obs_loader=obs_loader,
        transform=transform,
        variable_config=variable_config,
        split=split,
    )
# ============================================================================
# Reference loader construction
# ============================================================================
def build_reference_dataset(split="train", sensors=None):
    """Construct ObsERA5Dataset from the reference healda-reference codebase.

    Requires healda-reference/src and healda-reference/ on sys.path; both
    are inserted here if missing.
    """
    # Make the reference package importable before touching its modules.
    for path in (os.path.join(REFERENCE_ROOT, "src"), REFERENCE_ROOT):
        if path not in sys.path:
            sys.path.insert(0, path)

    import dotenv as _dotenv

    _dotenv.load_dotenv(os.path.join(RECIPE_ROOT, ".env"))

    from healda.config.models import ObsConfig
    from private.fcn3_dataset import ObsERA5Dataset as RefObsERA5Dataset

    # Build obs_config matching default sensors; conventional obs are only
    # enabled when explicitly requested via the sensors list.
    obs_config = ObsConfig(
        use_obs=True,
        innovation_type="none",
        context_start=-21,
        context_end=3,
        use_conv=sensors is not None and "conv" in sensors,
    )
    return RefObsERA5Dataset(
        split=split,
        time_length=1,
        frame_step=1,
        model_rank=0,
        model_world_size=1,
        obs_config=obs_config,
    )
# ============================================================================
# Comparison logic
# ============================================================================
def compare_single_sample(ported_ds, ref_ds, idx: int, verbose: bool = True):
    """Compare a single sample between ported and reference datasets.

    Fetches raw (pre-transform) output from both datasets via ``.get(idx)``
    and checks timestamps, state arrays, and observation tables.

    Returns a dict of comparison results.
    """
    results = {"idx": idx, "pass": True, "errors": []}
    # --- Raw data comparison (before transform) ---
    try:
        t0 = time.time()
        ported_times, ported_objs = ported_ds.get(idx)
        ported_elapsed = time.time() - t0
        t0 = time.time()
        ref_times, ref_objs = ref_ds.get(idx)
        ref_elapsed = time.time() - t0
        results["ported_time_s"] = ported_elapsed
        results["ref_time_s"] = ref_elapsed
    except Exception as e:
        # A loading failure is terminal for this index: report and bail out.
        results["pass"] = False
        results["errors"].append(f"Loading failed: {e}")
        return results
    # Compare timestamps (string comparison tolerates differing time types).
    for i, (pt, rt) in enumerate(zip(ported_times, ref_times)):
        if str(pt) != str(rt):
            results["pass"] = False
            results["errors"].append(f"Time mismatch at frame {i}: {pt} vs {rt}")
    # Compare state arrays
    for i, (po, ro) in enumerate(zip(ported_objs, ref_objs)):
        p_state = po["state"]
        r_state = ro["state"]
        if p_state.shape != r_state.shape:
            results["pass"] = False
            results["errors"].append(
                f"State shape mismatch at frame {i}: {p_state.shape} vs {r_state.shape}"
            )
            continue
        max_diff = np.max(np.abs(p_state - r_state))
        results[f"state_frame{i}_maxdiff"] = float(max_diff)
        if max_diff > 1e-6:
            results["pass"] = False
            results["errors"].append(
                f"State value mismatch at frame {i}: max_diff={max_diff:.2e}"
            )
    # Compare observation tables
    # Note: ported and reference may produce rows in different order (due to
    # platform grouping within parquet row-groups). This is benign — the
    # downstream transform processes all obs in a window together. We sort
    # both tables by a canonical key before value comparison.
    import pyarrow.compute as pc

    def _sort_obs_table(table):
        """Sort by (Global_Channel_ID, Latitude, Longitude, Absolute_Obs_Time)
        to produce a deterministic row order for comparison."""
        sort_keys = [
            ("Global_Channel_ID", "ascending"),
            ("Latitude", "ascending"),
            ("Longitude", "ascending"),
            ("Absolute_Obs_Time", "ascending"),
        ]
        indices = pc.sort_indices(table, sort_keys=sort_keys)
        return table.take(indices)

    for i, (po, ro) in enumerate(zip(ported_objs, ref_objs)):
        p_obs = po.get("obs")
        # NOTE(review): `or` falls through to "obs_v2" not only when "obs" is
        # absent but also when it is falsy (e.g. a zero-row table) — confirm
        # that is intended.
        r_obs = ro.get("obs") or ro.get("obs_v2")  # reference uses legacy key
        if p_obs is None and r_obs is None:
            continue
        if (p_obs is None) != (r_obs is None):
            results["pass"] = False
            results["errors"].append(f"Obs presence mismatch at frame {i}")
            continue
        p_nrows = p_obs.num_rows
        r_nrows = r_obs.num_rows
        results[f"obs_frame{i}_nrows_ported"] = p_nrows
        results[f"obs_frame{i}_nrows_ref"] = r_nrows
        if p_nrows != r_nrows:
            results["pass"] = False
            results["errors"].append(
                f"Obs row count mismatch at frame {i}: {p_nrows} vs {r_nrows}"
            )
            continue
        if p_nrows > 0:
            # Compare schemas (column name sets) before touching values.
            p_cols = set(p_obs.schema.names)
            r_cols = set(r_obs.schema.names)
            if p_cols != r_cols:
                results["pass"] = False
                results["errors"].append(
                    f"Obs schema mismatch at frame {i}: "
                    f"ported_only={p_cols - r_cols}, ref_only={r_cols - p_cols}"
                )
                continue
            # Sort both tables to canonical order before comparison
            p_sorted = _sort_obs_table(p_obs)
            r_sorted = _sort_obs_table(r_obs)
            # Compare observation values (nanmax tolerates NaN entries).
            p_vals = p_sorted["Observation"].to_numpy()
            r_vals = r_sorted["Observation"].to_numpy()
            obs_max_diff = np.nanmax(np.abs(p_vals - r_vals))
            results[f"obs_frame{i}_val_maxdiff"] = float(obs_max_diff)
            if obs_max_diff > 1e-5:
                results["pass"] = False
                results["errors"].append(
                    f"Obs value mismatch at frame {i}: max_diff={obs_max_diff:.2e}"
                )
            # Also verify Global_Channel_ID sets match
            p_gcids = set(p_obs["Global_Channel_ID"].to_pylist())
            r_gcids = set(r_obs["Global_Channel_ID"].to_pylist())
            if p_gcids != r_gcids:
                results["pass"] = False
                results["errors"].append(
                    f"Obs GCID set mismatch at frame {i}: "
                    f"ported_only={p_gcids - r_gcids}, ref_only={r_gcids - p_gcids}"
                )
    if verbose:
        status = "PASS" if results["pass"] else "FAIL"
        timing = (
            f"ported={results.get('ported_time_s', 0):.2f}s "
            f"ref={results.get('ref_time_s', 0):.2f}s"
        )
        print(f" [{status}] idx={idx:6d} {timing}")
        for err in results["errors"]:
            print(f" {err}")
    return results
def compare_transformed_sample(ported_ds, ref_ds, idx: int, verbose: bool = True):
    """Compare transformed (batched) output between ported and reference.

    Uses __getitems__ to exercise the full transform pipeline.

    Returns a dict of comparison results.
    """
    results = {"idx": idx, "pass": True, "errors": []}
    try:
        t0 = time.time()
        ported_batch = ported_ds.__getitems__([idx])
        ported_elapsed = time.time() - t0
        t0 = time.time()
        ref_batch = ref_ds.__getitems__([idx])
        ref_elapsed = time.time() - t0
        results["ported_transform_s"] = ported_elapsed
        results["ref_transform_s"] = ref_elapsed
    except Exception as e:
        results["pass"] = False
        results["errors"].append(f"Transform failed: {e}")
        if verbose:
            print(f" [FAIL] idx={idx:6d} Transform error: {e}")
        return results
    # Compare batch dict keys
    p_keys = set(ported_batch.keys())
    r_keys = set(ref_batch.keys())
    if p_keys != r_keys:
        results["errors"].append(
            f"Batch key mismatch: ported_only={p_keys - r_keys}, ref_only={r_keys - p_keys}"
        )
        # Don't fail — extra/missing keys may be intentional
    # Compare tensor fields (only keys present in both batches).
    for key in sorted(p_keys & r_keys):
        pv = ported_batch[key]
        rv = ref_batch[key]
        if isinstance(pv, torch.Tensor) and isinstance(rv, torch.Tensor):
            if pv.shape != rv.shape:
                results["pass"] = False
                results["errors"].append(
                    f"Shape mismatch for '{key}': {pv.shape} vs {rv.shape}"
                )
                continue
            if pv.numel() == 0:
                continue
            max_diff = (pv.float() - rv.float()).abs().max().item()
            results[f"{key}_maxdiff"] = max_diff
            # Use loose tolerance for float transforms; exact match for ints.
            tol = 1e-4 if pv.is_floating_point() else 0
            if max_diff > tol:
                results["pass"] = False
                results["errors"].append(
                    f"Value mismatch for '{key}': max_diff={max_diff:.2e}"
                )
        elif isinstance(pv, tuple) and isinstance(rv, tuple):
            # unified_obs is a tuple (obs_tensors, lengths_3d)
            # Row ordering may differ between ported and reference (benign —
            # within each sensor group, platforms can appear in different order
            # depending on parquet row-group layout). We sort both by
            # (global_channel_id, latitude, longitude) before comparing values.
            if len(pv) != len(rv):
                results["pass"] = False
                results["errors"].append(
                    f"Tuple length mismatch for '{key}': {len(pv)} vs {len(rv)}"
                )
                continue
            if isinstance(pv[0], dict) and isinstance(rv[0], dict):
                p_obs_keys = set(pv[0].keys())
                r_obs_keys = set(rv[0].keys())
                if p_obs_keys != r_obs_keys:
                    results["errors"].append(
                        f"Obs tensor key mismatch: "
                        f"ported_only={p_obs_keys - r_obs_keys}, "
                        f"ref_only={r_obs_keys - p_obs_keys}"
                    )
                # Build a stable sort index using torch.lexsort-style
                # multi-key sorting: (gcid, abs_time, lat, lon, observation)
                def _sort_idx(obs_dict):
                    # Returns a LongTensor permutation, or None when there is
                    # nothing to sort (missing/empty gcid column).
                    gcid = obs_dict.get("global_channel_id")
                    lat = obs_dict.get("latitude")
                    lon = obs_dict.get("longitude")
                    obs_time = obs_dict.get("absolute_obs_time")
                    obs_val = obs_dict.get("observation")
                    if gcid is None or gcid.numel() == 0:
                        return None
                    # Stack columns as (N, K) float64 for lexicographic sort.
                    # torch.lexsort isn't available, so we use numpy.
                    cols = [gcid.double().cpu().numpy()]
                    if obs_time is not None:
                        cols.append(obs_time.double().cpu().numpy())
                    if lat is not None:
                        cols.append(lat.double().cpu().numpy())
                    if lon is not None:
                        cols.append(lon.double().cpu().numpy())
                    if obs_val is not None:
                        cols.append(obs_val.double().cpu().numpy())
                    # np.lexsort sorts by last key first, so reverse
                    order = np.lexsort(cols[::-1])
                    return torch.from_numpy(order).long()

                p_order = _sort_idx(pv[0])
                r_order = _sort_idx(rv[0])
                for obs_key in sorted(p_obs_keys & r_obs_keys):
                    pt = pv[0][obs_key]
                    rt = rv[0][obs_key]
                    if pt.shape != rt.shape:
                        results["pass"] = False
                        results["errors"].append(
                            f"Obs tensor shape mismatch for '{obs_key}': "
                            f"{pt.shape} vs {rt.shape}"
                        )
                    elif pt.numel() > 0:
                        # Apply sort order before comparison
                        ps = pt[p_order] if p_order is not None else pt
                        rs = rt[r_order] if r_order is not None else rt
                        d = (ps.float() - rs.float()).abs().max().item()
                        results[f"obs_{obs_key}_maxdiff"] = d
                        if d > 1e-4:
                            results["pass"] = False
                            results["errors"].append(
                                f"Obs tensor mismatch for '{obs_key}': "
                                f"max_diff={d:.2e}"
                            )
            # Compare lengths_3d (sensor, batch, time) — these count obs per
            # sensor/window and are order-independent as long as sensor_id
            # mapping matches.
            for ti, name in [(1, "lengths")]:
                if ti < len(pv) and ti < len(rv):
                    pt, rt = pv[ti], rv[ti]
                    if isinstance(pt, torch.Tensor) and isinstance(rt, torch.Tensor):
                        if pt.shape != rt.shape:
                            results["pass"] = False
                            results["errors"].append(
                                f"{name} shape mismatch: {pt.shape} vs {rt.shape}"
                            )
                        elif not torch.equal(pt, rt):
                            results["pass"] = False
                            results["errors"].append(f"{name} value mismatch")
    if verbose:
        status = "PASS" if results["pass"] else "FAIL"
        timing = (
            f"ported={results.get('ported_transform_s', 0):.2f}s "
            f"ref={results.get('ref_transform_s', 0):.2f}s"
        )
        print(f" [{status}] idx={idx:6d} {timing}")
        for err in results["errors"]:
            print(f" {err}")
    return results
# ============================================================================
# Main driver
# ============================================================================
def get_indices(mode: str, ds_len: int, custom_indices=None):
    """Return sample indices to compare, based on *mode*.

    Args:
        mode: "smoketest" (3 spread indices) or "full" (up to 50 indices
            spread evenly across the dataset).
        ds_len: Length of the dataset; all returned indices are in
            ``[0, ds_len)``.
        custom_indices: Optional explicit index list; overrides *mode*
            (out-of-bounds entries are dropped).

    Returns:
        A list of valid, in-bounds sample indices.

    Raises:
        ValueError: If *mode* is not recognized.
    """
    if mode not in ("smoketest", "full"):
        raise ValueError(f"Unknown mode: {mode}")
    if custom_indices:
        # Drop out-of-bounds entries, including negatives (which would
        # otherwise wrap around via Python indexing in the datasets).
        return [i for i in custom_indices if 0 <= i < ds_len]
    if ds_len <= 0:
        # Empty dataset: nothing to compare (previously this produced
        # invalid indices like [0, 0, -1]).
        return []
    if mode == "smoketest":
        # 3 indices: start, middle, near end — deduplicated so tiny
        # datasets don't yield repeated indices.
        return sorted({0, ds_len // 2, ds_len - 1})
    # mode == "full": up to 50 indices spread across the dataset.
    n = min(50, ds_len)
    step = max(1, ds_len // n)
    return list(range(0, ds_len, step))[:n]
def main():
    """CLI driver: build both datasets, compare samples, print a summary.

    Exits with status 1 if any comparison check fails.
    """
    parser = argparse.ArgumentParser(
        description="Compare ported vs reference data loader outputs."
    )
    parser.add_argument(
        "mode",
        choices=["smoketest", "full"],
        help="smoketest: 3 indices, fast. full: 50 indices.",
    )
    parser.add_argument(
        "--indices",
        type=int,
        nargs="*",
        default=None,
        help="Override indices to compare.",
    )
    parser.add_argument(
        "--split", default="train", help="Dataset split (default: train)."
    )
    parser.add_argument(
        "--sensors",
        nargs="*",
        default=None,
        help="Sensor list (default: atms mhs amsua amsub).",
    )
    parser.add_argument(
        "--transform",
        action="store_true",
        help="Also compare transformed (__getitems__) output.",
    )
    parser.add_argument(
        "--no-raw",
        action="store_true",
        help="Skip raw (get) comparison, only do transform.",
    )
    args = parser.parse_args()
    print("=" * 70)
    print(f"Loader comparison — mode={args.mode}, split={args.split}")
    print("=" * 70)
    # Build datasets
    print("\nBuilding ported dataset...")
    t0 = time.time()
    ported_ds = build_ported_dataset(split=args.split, sensors=args.sensors)
    print(f" Done in {time.time() - t0:.1f}s (len={len(ported_ds)})")
    print("Building reference dataset...")
    t0 = time.time()
    ref_ds = build_reference_dataset(split=args.split, sensors=args.sensors)
    print(f" Done in {time.time() - t0:.1f}s (len={len(ref_ds)})")
    # Verify lengths match (warn only — indices below are clamped to the
    # shorter dataset so every index is valid for both).
    if len(ported_ds) != len(ref_ds):
        print(
            f"\nWARNING: Dataset lengths differ! "
            f"ported={len(ported_ds)} vs ref={len(ref_ds)}"
        )
    indices = get_indices(
        args.mode, min(len(ported_ds), len(ref_ds)), args.indices
    )
    print(f"\nComparing {len(indices)} samples: {indices[:10]}{'...' if len(indices) > 10 else ''}")
    # --- Raw comparison ---
    if not args.no_raw:
        print(f"\n--- Raw comparison (get) ---")
        raw_results = []
        for idx in indices:
            r = compare_single_sample(ported_ds, ref_ds, idx)
            raw_results.append(r)
        n_pass = sum(1 for r in raw_results if r["pass"])
        n_fail = len(raw_results) - n_pass
        print(f"\nRaw: {n_pass}/{len(raw_results)} passed, {n_fail} failed")
    # --- Transform comparison ---
    if args.transform:
        print(f"\n--- Transform comparison (__getitems__) ---")
        xform_results = []
        for idx in indices:
            r = compare_transformed_sample(ported_ds, ref_ds, idx)
            xform_results.append(r)
        n_pass = sum(1 for r in xform_results if r["pass"])
        n_fail = len(xform_results) - n_pass
        print(f"\nTransform: {n_pass}/{len(xform_results)} passed, {n_fail} failed")
    # --- Summary ---
    # raw_results / xform_results are only referenced under the same flags
    # that defined them above, so no NameError is possible here.
    print("\n" + "=" * 70)
    all_results = []
    if not args.no_raw:
        all_results.extend(raw_results)
    if args.transform:
        all_results.extend(xform_results)
    n_total = len(all_results)
    n_pass = sum(1 for r in all_results if r["pass"])
    if n_total == n_pass:
        print(f"ALL {n_total} CHECKS PASSED")
    else:
        print(f"{n_total - n_pass}/{n_total} CHECKS FAILED")
        sys.exit(1)


if __name__ == "__main__":
    main()
aayushg55
left a comment
There was a problem hiding this comment.
Took a final pass. Looks good, I should be able to extend the obs transform easily. Just need to update the requirements and rename the platform field.
Greptile SummaryThis PR introduces the
Important Files Changed
Reviews (1): Last reviewed commit: "Revert precommit change" | Re-trigger Greptile |
| segments = [] | ||
| start = 0 | ||
| for i in range(1, x.size): |
There was a problem hiding this comment.
IndexError for single-element time array
d = x[1] - x[0] is evaluated before the loop for every non-empty array, so passing an array with exactly one element (x.size == 1) raises IndexError: index 1 is out of bounds for axis 0 with size 1. Only x.size == 0 is guarded. Any dataset constructed with a single timestamp (common in unit tests or small experiments) will crash here.
| segments = [] | |
| start = 0 | |
| for i in range(1, x.size): | |
| if x.size <= 1: | |
| return [x] if x.size == 1 else [] | |
| d = x[1] - x[0] |
|
|
||
| this_window = None | ||
| for w in target_windows: | ||
| if row_group_lo <= w <= row_group_hi: | ||
| this_window = w | ||
|
|
||
| if this_window is None: | ||
| continue | ||
|
|
||
| table = parquet.read_row_group( | ||
| row_group_idx, columns=self._read_columns | ||
| ) | ||
|
|
||
| if row_group_lo != row_group_hi: | ||
| mask = pc.is_in( | ||
| table["DA_window"], pa.array(list(target_windows)) | ||
| ) | ||
| table = table.filter(mask) | ||
|
|
||
| if table.num_rows == 0: | ||
| continue | ||
|
|
||
| yield this_window, table |
There was a problem hiding this comment.
Multi-window row groups silently merge under wrong key
When row_group_lo != row_group_hi, multiple target windows may satisfy row_group_lo <= w <= row_group_hi. The inner loop overwrites this_window on each match, so only the last matching window is used. The table is filtered to keep rows from all matching windows, then yielded once under that last-window label.
The downstream tables.setdefault(interval_time, []).append(table) in sel_time therefore stores a multi-window table under only one key; the other matching windows get [] when looked up in process(t). Observations from earlier DA windows end up attributed to the wrong time step and are dropped entirely for their actual slot.
A fix is to yield one sub-table per matching window:
for w in target_windows:
if row_group_lo <= w <= row_group_hi:
if row_group_lo != row_group_hi:
mask = pc.equal(table["DA_window"], w)
sub = table.filter(mask)
else:
sub = table
if sub.num_rows > 0:
yield w, sub| offset = 0 | ||
| for i in range(len(sensor_id)): | ||
| if sensor_id[i] != sensor_id[i - 1]: | ||
| offset = i | ||
| local_channel_ids.append(i - offset) |
There was a problem hiding this comment.
Off-by-one via Python negative index at
i = 0
At the first loop iteration (i = 0), sensor_id[i - 1] evaluates to sensor_id[-1] — the last element of the array — due to Python's wrap-around indexing. If the first and last rows share the same sensor_id, offset is never reset at i = 0, so the first sensor block's local_channel_ids would be computed against a stale offset instead of starting from 0.
The intent is to assign a 0-based local ID within each contiguous sensor block. A clearer, bug-free approach:
| offset = 0 | |
| for i in range(len(sensor_id)): | |
| if sensor_id[i] != sensor_id[i - 1]: | |
| offset = i | |
| local_channel_ids.append(i - offset) | |
| for i in range(len(sensor_id)): | |
| if i == 0 or sensor_id[i] != sensor_id[i - 1]: | |
| offset = i | |
| local_channel_ids.append(i - offset) |
There was a problem hiding this comment.
Technically, this can never lead to a bug since the offset is initialized as 0, so in either case, the offset is 0 at i=0. But agree this would better reflect the intent
| def _normalize_observations(self, table: pa.Table) -> pa.Table: | ||
| if self.normalization == "fixed_range": | ||
| normalized = pc.divide(pc.subtract(table["Observation"], 0), 400 - 0) | ||
| elif self.normalization == "zscore": |
There was a problem hiding this comment.
fixed_range normalization ignores per-channel min_valid/max_valid
The fixed_range branch hard-codes 0 and 400 as global min/max:
normalized = pc.divide(pc.subtract(table["Observation"], 0), 400 - 0)The channel table already supplies min_valid and max_valid per channel (loaded in _add_channel_metadata), but they are not used here. This likely produces incorrectly scaled observations for any channel whose valid range differs from [0, 400], which is almost every instrument (e.g. AMSUA channel 1 mean ≈ 208 K, see amsua_normalizations.csv).
There was a problem hiding this comment.
We only use zscore normalization, I don't think it would ever make sense to use anything else. Probably best to remove the deprecated code.
There was a problem hiding this comment.
Agreed. Here's what to clean up:
- Remove the
normalizationparameter from__init__(and its docstring entry) - Drop
self.normalizationassignment - Simplify
_normalize_observationsto always use zscore - Update
get_channel_tableto drop the now-unnecessarynormalization="zscore"kwarg
The simplified method would look like:
| def _normalize_observations(self, table: pa.Table) -> pa.Table: | |
| if self.normalization == "fixed_range": | |
| normalized = pc.divide(pc.subtract(table["Observation"], 0), 400 - 0) | |
| elif self.normalization == "zscore": | |
| def _normalize_observations(self, table: pa.Table) -> pa.Table: | |
| normalized = pc.divide( | |
| pc.subtract(table["Observation"], table["mean"]), | |
| table["stddev"], | |
| ) |
And remove this line from the __init__ signature:
normalization: Literal["fixed_range", "zscore"] = "fixed_range",Note: the default was "fixed_range", so any existing callers not explicitly passing normalization="zscore" were silently using the broken path. Removing the param entirely is safe since zscore is the only correct behavior.
Tip: You can customize Greptile's behavior for this repo with .greptile/rules and .greptile/config.
| self.iteration = 0 | ||
| rng = torch.Generator().manual_seed(self.seed + self.epoch + self.rank) | ||
| permutation = torch.randperm(self.len, generator=rng) |
There was a problem hiding this comment.
Rank-specific seeds break uniform dataset coverage
Each rank seeds its RNG with self.seed + self.epoch + self.rank, so rank 0 and rank 1 generate different random permutations. Rank 0 then takes indices [0, num_replicas, 2*num_replicas, …] of its permutation, while rank 1 does the same from a different permutation. The combined set of indices seen by all ranks in one epoch is not a partition of [0, dataset_size) — samples may be duplicated across ranks or missed entirely.
The standard approach (PyTorch's DistributedSampler) uses the same seed on all ranks so the permutation is identical, then partitions it by rank. This PR intentionally diverges from that pattern (as evidenced by the test_multi_replica_independent test), which means per-epoch sample coverage is not guaranteed. This is an intentional design choice but should be documented explicitly in the class docstring so users understand the trade-off.
There was a problem hiding this comment.
Good catch. Agree that given we assign tasks per rank based on self.permutation = permutation[self.rank :: self.num_replicas], with the intention that all have the same permutation to begin with. So we should drop the +self.rank.
However, the interpretation of test_multi_replica_independent is wrong; it is simply covering that different ranks receive different tasks.
| ) | ||
|
|
||
| segment_idx = 0 | ||
| for i, cum_size in enumerate(self.cumulative_valid_sizes[1:], 1): | ||
| if logical_idx < cum_size: | ||
| segment_idx = i - 1 | ||
| break | ||
|
|
There was a problem hiding this comment.
Bounds check uses
total_samples instead of valid_length
if logical_idx >= self.total_samples:
raise IndexError(...)total_samples is the sum of all segment sizes (including frames that can't start a valid window), while logical_idx is always expected to be in [0, valid_length). A value in [valid_length, total_samples) passes this guard silently and maps to a physical frame index that lies inside the last segment but beyond the last valid window, potentially yielding out-of-bounds xarray .isel() coordinates.
| ) | |
| segment_idx = 0 | |
| for i, cum_size in enumerate(self.cumulative_valid_sizes[1:], 1): | |
| if logical_idx < cum_size: | |
| segment_idx = i - 1 | |
| break | |
| if logical_idx >= self.valid_length: | |
| raise IndexError( | |
| f"Sample index {logical_idx} out of bounds " | |
| f"for {self.valid_length} valid samples" | |
| ) |
There was a problem hiding this comment.
Yes, we should replace the bounds check to be against self.valid_length instead
PhysicsNeMo Pull Request
Description
Adds the HealDA data loader system to
physicsnemo.experimental.datapipes. Focused initially on reproducibility, preserving performance features NVR developed, and establishing clear interfaces for users interested in extending with custom data.Also brings in the unit testing for these components, currently living in the recipe folder.
Documentation is mostly in the recipe readme, some copied here for reference:
The
physicsnemo.experimental.datapipes.healdapackage provides a composable data loading pipeline with clear extension points. The architecture separates components into loaders, transforms, datasets, and sampling infrastructure.Architecture
Key Protocols
Custom data sources and transforms plug in via these protocols
(see
physicsnemo.experimental.datapipes.healda.protocols):ObsLoader— the observation loading interface:Transform/DeviceTransform— two-stage batchprocessing:
Provided Implementations
ObsERA5DatasetdatasetUFSUnifiedLoaderloaders.ufs_obsERA5Loaderloaders.era5ERA5ObsTransformtransforms.era5_obsRestartableDistributedSamplersamplersprefetch_mapprefetchAll modules above are under
physicsnemo.experimental.datapipes.healda.Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.