diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4dc7972f4..8ec90e4b2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,6 +11,8 @@ jobs: # github.repository == 'UXARRAY/uxarray' name: Python (${{ matrix.python-version }}, ${{ matrix.os }}) runs-on: ${{ matrix.os }} + env: + MPLBACKEND: Agg defaults: run: shell: bash -l {0} diff --git a/.github/workflows/yac-optional.yml b/.github/workflows/yac-optional.yml new file mode 100644 index 000000000..ab0f7003b --- /dev/null +++ b/.github/workflows/yac-optional.yml @@ -0,0 +1,142 @@ +name: YAC Optional CI + +on: + pull_request: + paths: + - ".github/workflows/yac-optional.yml" + - "uxarray/remap/**" + - "test/test_remap_yac.py" + workflow_dispatch: + +jobs: + yac-optional: + name: YAC core v3.14.0_p1 (Ubuntu) + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + env: + YAC_VERSION: v3.14.0_p1 + YAXT_VERSION: v0.11.5.1 + MPIEXEC: /usr/bin/mpirun + MPIRUN: /usr/bin/mpirun + MPICC: /usr/bin/mpicc + MPIFC: /usr/bin/mpif90 + MPIF90: /usr/bin/mpif90 + OMPI_ALLOW_RUN_AS_ROOT: 1 + OMPI_ALLOW_RUN_AS_ROOT_CONFIRM: 1 + steps: + - name: checkout + uses: actions/checkout@v4 + with: + token: ${{ github.token }} + + - name: conda_setup + uses: conda-incubator/setup-miniconda@v3 + with: + activate-environment: uxarray_build + channel-priority: strict + python-version: "3.11" + channels: conda-forge + environment-file: ci/environment.yml + miniforge-variant: Miniforge3 + miniforge-version: latest + + - name: Install build dependencies (apt) + run: | + sudo apt-get update + sudo apt-get install -y \ + autoconf \ + automake \ + gawk \ + gfortran \ + libopenmpi-dev \ + libtool \ + make \ + openmpi-bin \ + pkg-config + - name: Verify MPI tools + run: | + which mpirun + which mpicc + which mpif90 + mpirun --version + mpicc --version + mpif90 --version + - name: Install Python build dependencies + run: | + python -m pip install --upgrade pip + python -m pip install cython wheel + - name: Build and install YAXT + run: | + set -euxo pipefail + YAC_PREFIX="${GITHUB_WORKSPACE}/yac_prefix" + echo "YAC_PREFIX=${YAC_PREFIX}" >> "${GITHUB_ENV}" + git clone --depth 1 --branch "${YAXT_VERSION}" https://gitlab.dkrz.de/dkrz-sw/yaxt.git + if [ ! -x yaxt/configure ]; then + if [ -x yaxt/autogen.sh ]; then + (cd yaxt && ./autogen.sh) + else + (cd yaxt && autoreconf -i) + fi + fi + mkdir -p yaxt/build + cd yaxt/build + ../configure \ + --prefix="${YAC_PREFIX}" \ + --without-regard-for-quality \ + CC="${MPICC}" \ + FC="${MPIF90}" + make -j2 + make install + - name: Build and install YAC + run: | + set -euxo pipefail + git clone --depth 1 --branch "${YAC_VERSION}" https://gitlab.dkrz.de/dkrz-sw/yac.git + if [ ! -x yac/configure ]; then + if [ -x yac/autogen.sh ]; then + (cd yac && ./autogen.sh) + else + (cd yac && autoreconf -i) + fi + fi + mkdir -p yac/build + cd yac/build + ../configure \ + --prefix="${YAC_PREFIX}" \ + --with-yaxt-root="${YAC_PREFIX}" \ + --disable-mci \ + --disable-utils \ + --disable-examples \ + --disable-tools \ + --disable-netcdf \ + --enable-python-bindings \ + CC="${MPICC}" \ + FC="${MPIF90}" + make -j2 + make install + - name: Configure YAC runtime paths + run: | + set -euxo pipefail + PY_VER="$(python -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')" + echo "LD_LIBRARY_PATH=${YAC_PREFIX}/lib:${LD_LIBRARY_PATH:-}" >> "${GITHUB_ENV}" + echo "PYTHONPATH=${YAC_PREFIX}/lib/python${PY_VER}/site-packages:${YAC_PREFIX}/lib/python${PY_VER}/dist-packages:${PYTHONPATH:-}" >> "${GITHUB_ENV}" + - name: Verify YAC core Python bindings + run: | + python - <<'PY' + from pathlib import Path + import sys + candidates = [] + for entry in sys.path: + pkg = Path(entry) / "yac" + candidates.extend(pkg.glob("core*.so")) + candidates.extend(pkg.glob("core*.pyd")) + assert candidates, "yac.core extension not found on sys.path" + print("Found yac.core extension:", candidates[0]) + PY + - name: Install uxarray + run: | + python -m pip install . --no-deps + - name: Run tests (uxarray with YAC) + run: | + python -m pytest test/test_remap_yac.py diff --git a/docs/user-guide/remapping.ipynb b/docs/user-guide/remapping.ipynb index cd187b2e5..1a103ece6 100644 --- a/docs/user-guide/remapping.ipynb +++ b/docs/user-guide/remapping.ipynb @@ -15,7 +15,11 @@ "\n", "- **Nearest Neighbor** \n", "- **Inverse Distance Weighted**\n", - "- **Bilinear**\n" + "- **Bilinear**\n", + "\n", + "UXarray uses its native remapping backend by default. For `.remap(...)`, `.remap.nearest_neighbor(...)`, and `.remap.bilinear(...)`, you can also set `backend=\"yac\"` to route the operation through YAC when `yac.core` is installed.\n", + "\n", + "When `backend=\"yac\"`, the `yac_method` parameter selects the YAC interpolation method. Supported values are `nnn`, `average`, and `conservative`. `inverse_distance_weighted()` remains UXarray-only, and `bilinear(..., backend=\"yac\")` is a convenience wrapper for `yac_method=\"average\"`.\n" ] }, { @@ -132,6 +136,14 @@ "- **remap_to** \n", " The grid element where values should be placed, one of `faces`, `edges`, or `nodes`.\n", "\n", + "- **backend** \n", + " The remapping backend to use. The default is `\"uxarray\"`; set `backend=\"yac\"` to route the remap through YAC.\n", + "\n", + "- **yac_method** \n", + " Required only when `backend=\"yac\"`. Supported values are `nnn`, `average`, and `conservative`; `nearest_neighbor()` defaults to `nnn`.\n", + "\n", + "- **yac_options** \n", + " Optional dictionary of method-specific keyword arguments forwarded to the selected YAC interpolation routine.\n", "\n", "```{warning}\n", "Nearest-neighbor remapping is fast and simple, but it does **not** conserve integrated quantities\n", @@ -480,7 +492,9 @@ "id": "6bec26ce-67b6-4300-a310-63cbac2b289a", "metadata": {}, "source": [ - "Bilinear remapping breaks down the grid into triangles, and then finds the triangle that contains each point on the destinitation grid. It then uses the values stored at each vertex to bilinearly find a value for the point, depending on it's postion inside the triangle." + "Bilinear remapping breaks down the grid into triangles, and then finds the triangle that contains each point on the destinitation grid. It then uses the values stored at each vertex to bilinearly find a value for the point, depending on it's postion inside the triangle.\n", + "\n", + "When `backend=\"yac\"`, `remap.bilinear()` delegates to YAC's `average` method. This is the only YAC method exposed through the bilinear convenience accessor; use `.remap(..., backend=\"yac\", yac_method=...)` if you need to select another YAC method explicitly." ] }, { diff --git a/test/test_remap_yac.py b/test/test_remap_yac.py new file mode 100644 index 000000000..7d71e2656 --- /dev/null +++ b/test/test_remap_yac.py @@ -0,0 +1,258 @@ +import numpy as np +import pytest + +import uxarray as ux +from uxarray.remap.yac import YacNotAvailableError, _import_yac + + +try: + _import_yac() +except YacNotAvailableError: + pytest.skip("yac.core is not available", allow_module_level=True) + + +def test_yac_nnn_node_remap(gridpath, datasetpath): + grid_path = gridpath("ugrid", "geoflow-small", "grid.nc") + uxds = ux.open_dataset(grid_path, datasetpath("ugrid", "geoflow-small", "v1.nc")) + dest = ux.open_grid(grid_path) + + out = uxds["v1"].remap.nearest_neighbor( + destination_grid=dest, + remap_to="nodes", + backend="yac", + yac_method="nnn", + yac_options={"n": 1}, + ) + assert out.size > 0 + assert "n_node" in out.dims + + +def test_yac_conservative_face_remap(gridpath): + mesh_path = gridpath("mpas", "QU", "mesh.QU.1920km.151026.nc") + uxds = ux.open_dataset(mesh_path, mesh_path) + dest = ux.open_grid(mesh_path) + + out = uxds["latCell"].remap( + destination_grid=dest, + remap_to="faces", + backend="yac", + yac_method="conservative", + yac_options={"order": 1}, + ) + assert out.size == dest.n_face + + +def test_yac_matches_uxarray_nearest_neighbor(): + verts = np.array([(0.0, 90.0), (-180.0, 0.0), (0.0, -90.0)]) + grid = ux.open_grid(verts) + da = ux.UxDataArray( + np.asarray([1.0, 2.0, 3.0]), + dims=["n_node"], + coords={"n_node": [0, 1, 2]}, + uxgrid=grid, + ) + + ux_out = da.remap.nearest_neighbor( + destination_grid=grid, + remap_to="nodes", + backend="uxarray", + ) + yac_out = da.remap.nearest_neighbor( + destination_grid=grid, + remap_to="nodes", + backend="yac", + yac_method="nnn", + yac_options={"n": 1}, + ) + assert ux_out.shape == yac_out.shape + assert (ux_out.values == yac_out.values).all() + + +def test_yac_call_defaults_to_nnn(): + verts = np.array([(0.0, 90.0), (-180.0, 0.0), (0.0, -90.0)]) + grid = ux.open_grid(verts) + da = ux.UxDataArray( + np.asarray([1.0, 2.0, 3.0]), + dims=["n_node"], + coords={"n_node": [0, 1, 2]}, + uxgrid=grid, + ) + + out = da.remap( + destination_grid=grid, + remap_to="nodes", + backend="yac", + ) + + assert out.shape == da.shape + np.testing.assert_array_equal(out.values, da.values) + + +def test_yac_invalid_backend_raises(): + verts = np.array([(0.0, 90.0), (-180.0, 0.0), (0.0, -90.0)]) + grid = ux.open_grid(verts) + da = ux.UxDataArray( + np.asarray([1.0, 2.0, 3.0]), + dims=["n_node"], + coords={"n_node": [0, 1, 2]}, + uxgrid=grid, + ) + + with pytest.raises(ValueError, match="Invalid backend"): + da.remap.nearest_neighbor( + destination_grid=grid, + remap_to="nodes", + backend="bogus", + ) + + +def test_yac_idw_not_implemented(): + verts = np.array([(0.0, 90.0), (-180.0, 0.0), (0.0, -90.0)]) + grid = ux.open_grid(verts) + da = ux.UxDataArray( + np.asarray([1.0, 2.0, 3.0]), + dims=["n_node"], + coords={"n_node": [0, 1, 2]}, + uxgrid=grid, + ) + + with pytest.raises(NotImplementedError, match="inverse_distance_weighted"): + da.remap.inverse_distance_weighted( + destination_grid=grid, + remap_to="nodes", + backend="yac", + yac_method="nnn", + yac_options={"n": 1}, + ) + + +def test_yac_bilinear_face_remap(gridpath): + mesh_path = gridpath("mpas", "QU", "mesh.QU.1920km.151026.nc") + uxds = ux.open_dataset(mesh_path, mesh_path) + dest = ux.open_grid(mesh_path) + + out = uxds["latCell"].remap.bilinear( + destination_grid=dest, + remap_to="faces", + backend="yac", + ) + + assert out.size == dest.n_face + + +def test_yac_bilinear_rejects_non_average_method(gridpath): + mesh_path = gridpath("mpas", "QU", "mesh.QU.1920km.151026.nc") + uxds = ux.open_dataset(mesh_path, mesh_path) + dest = ux.open_grid(mesh_path) + + with pytest.raises(ValueError, match="only supports yac_method='average'"): + uxds["latCell"].remap.bilinear( + destination_grid=dest, + remap_to="faces", + backend="yac", + yac_method="conservative", + ) + + +def test_yac_conservative_rejects_non_face_data(): + verts = np.array([(0.0, 90.0), (-180.0, 0.0), (0.0, -90.0)]) + grid = ux.open_grid(verts) + da = ux.UxDataArray( + np.asarray([1.0, 2.0, 3.0]), + dims=["n_node"], + coords={"n_node": [0, 1, 2]}, + uxgrid=grid, + ) + + with pytest.raises(ValueError, match="face-centered"): + da.remap.nearest_neighbor( + destination_grid=grid, + remap_to="nodes", + backend="yac", + yac_method="conservative", + yac_options={"order": 1}, + ) + + +def test_yac_preserves_spatial_coordinate_remap(): + verts = np.array([(0.0, 90.0), (-180.0, 0.0), (0.0, -90.0)]) + grid = ux.open_grid(verts) + da = ux.UxDataArray( + np.asarray([1.0, 2.0, 3.0]), + dims=["n_node"], + coords={ + "n_node": [0, 1, 2], + "node_lon": ( + "n_node", + np.array([0.0, -180.0, 0.0]), + {"standard_name": "longitude", "units": "degrees_east"}, + ), + "node_lat": ( + "n_node", + np.array([90.0, 0.0, -90.0]), + {"standard_name": "latitude", "units": "degrees_north"}, + ), + }, + uxgrid=grid, + ) + + out = da.remap.nearest_neighbor( + destination_grid=grid, + remap_to="nodes", + backend="yac", + yac_method="nnn", + yac_options={"n": 1}, + ) + + np.testing.assert_array_equal(out.values, da.values) + assert "node_lon" in out.coords + assert "node_lat" in out.coords + + +def test_yac_batched_remap_with_extra_dimension(): + verts = np.array([(0.0, 90.0), (-180.0, 0.0), (0.0, -90.0)]) + grid = ux.open_grid(verts) + da = ux.UxDataArray( + np.asarray([[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]), + dims=["time", "n_node"], + coords={"time": [0, 1], "n_node": [0, 1, 2]}, + uxgrid=grid, + ) + + out = da.remap.nearest_neighbor( + destination_grid=grid, + remap_to="nodes", + backend="yac", + yac_method="nnn", + yac_options={"n": 1}, + ) + + assert out.shape == da.shape + np.testing.assert_array_equal(out.values, da.values) + + +def test_yac_batched_remap_with_fractional_mask(): + verts = np.array([(0.0, 90.0), (-180.0, 0.0), (0.0, -90.0)]) + grid = ux.open_grid(verts) + da = ux.UxDataArray( + np.asarray([[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]), + dims=["time", "n_node"], + coords={"time": [0, 1], "n_node": [0, 1, 2]}, + uxgrid=grid, + ) + frac_mask = np.ones_like(da.values, dtype=np.float64) + + out = da.remap.nearest_neighbor( + destination_grid=grid, + remap_to="nodes", + backend="yac", + yac_method="nnn", + yac_options={ + "n": 1, + "frac_mask_fallback_value": 0.0, + "frac_mask": frac_mask, + }, + ) + + assert out.shape == da.shape + np.testing.assert_array_equal(out.values, da.values) diff --git a/uxarray/remap/accessor.py b/uxarray/remap/accessor.py index ebf74ffa4..d39f21048 100644 --- a/uxarray/remap/accessor.py +++ b/uxarray/remap/accessor.py @@ -11,6 +11,15 @@ from uxarray.remap.inverse_distance_weighted import _inverse_distance_weighted_remap from uxarray.remap.nearest_neighbor import _nearest_neighbor_remap +_VALID_BACKENDS = ("uxarray", "yac") + + +def _validate_backend(backend: str) -> None: + if backend not in _VALID_BACKENDS: + raise ValueError( + f"Invalid backend '{backend}'. Expected one of {_VALID_BACKENDS}." + ) + class RemapAccessor: """Expose remapping methods on UxDataArray and UxDataset objects.""" @@ -27,17 +36,36 @@ def __repr__(self) -> str: + " • inverse_distance_weighted(destination_grid, remap_to='faces', power=2, k=8)\n" ) - def __call__(self, *args, **kwargs) -> UxDataArray | UxDataset: + def __call__( + self, + *args, + backend: str = "uxarray", + yac_method: str | None = None, + yac_options: dict | None = None, + **kwargs, + ) -> UxDataArray | UxDataset: """ Shortcut for nearest-neighbor remapping. Calling `.remap(...)` with no explicit method will invoke `nearest_neighbor(...)`. + + When ``backend="yac"``, this generic entrypoint can also be used to + select a YAC-specific interpolation method through ``yac_method``. """ - return self.nearest_neighbor(*args, **kwargs) + nn_kwargs: dict = {"backend": backend, "yac_options": yac_options} + if yac_method is not None: + nn_kwargs["yac_method"] = yac_method + return self.nearest_neighbor(*args, **nn_kwargs, **kwargs) def nearest_neighbor( - self, destination_grid: Grid, remap_to: str = "faces", **kwargs + self, + destination_grid: Grid, + remap_to: str = "faces", + backend: str = "uxarray", + yac_method: str | None = "nnn", + yac_options: dict | None = None, + **kwargs, ) -> UxDataArray | UxDataset: """ Perform nearest-neighbor remapping. @@ -51,16 +79,40 @@ def nearest_neighbor( remap_to : {'nodes', 'edges', 'faces'}, default='faces' Which grid element receives the remapped values. + backend : {'uxarray', 'yac'}, default='uxarray' + Remapping backend to use. When set to 'yac', requires YAC to be + available on PYTHONPATH. + yac_method : {'nnn', 'conservative'}, optional + YAC interpolation method. Defaults to 'nnn' when backend='yac'. + yac_options : dict, optional + YAC interpolation configuration options. + Returns ------- UxDataArray or UxDataset A new object with data mapped onto `destination_grid`. """ + _validate_backend(backend) + if backend == "yac": + from uxarray.remap.yac import _yac_remap + + yac_kwargs = yac_options or {} + return _yac_remap( + self.ux_obj, destination_grid, remap_to, yac_method, yac_kwargs + ) return _nearest_neighbor_remap(self.ux_obj, destination_grid, remap_to) def inverse_distance_weighted( - self, destination_grid: Grid, remap_to: str = "faces", power=2, k=8, **kwargs + self, + destination_grid: Grid, + remap_to: str = "faces", + power=2, + k=8, + backend: str = "uxarray", + yac_method: str | None = None, + yac_options: dict | None = None, + **kwargs, ) -> UxDataArray | UxDataset: """ Perform inverse-distance-weighted (IDW) remapping. @@ -80,18 +132,40 @@ def inverse_distance_weighted( k : int, default=8 Number of nearest source points to include in the weighted average. + backend : {'uxarray', 'yac'}, default='uxarray' + Remapping backend to use. When set to 'yac', requires YAC to be + available on PYTHONPATH. + yac_method : {'nnn', 'conservative'}, optional + YAC interpolation method. Required when backend='yac'. + yac_options : dict, optional + YAC interpolation configuration options. + Returns ------- UxDataArray or UxDataset A new object with data mapped onto `destination_grid`. """ + _validate_backend(backend) + if backend == "yac": + raise NotImplementedError( + "inverse_distance_weighted with backend='yac' is not currently " + "exposed through the UXarray YAC accessor. " + "Use backend='uxarray' for IDW, or use the YAC backend through " + ".remap(..., backend='yac', yac_method=..., yac_options=...)." + ) return _inverse_distance_weighted_remap( self.ux_obj, destination_grid, remap_to, power, k ) def bilinear( - self, destination_grid: Grid, remap_to: str = "faces", **kwargs + self, + destination_grid: Grid, + remap_to: str = "faces", + backend: str = "uxarray", + yac_method: str | None = "average", + yac_options: dict | None = None, + **kwargs, ) -> UxDataArray | UxDataset: """ Perform bilinear remapping. @@ -103,10 +177,36 @@ def bilinear( remap_to : {'nodes', 'edges', 'faces'}, default='faces' Which grid element receives the remapped values. + backend : {'uxarray', 'yac'}, default='uxarray' + Remapping backend to use. When set to 'yac', bilinear remapping is + routed through YAC's average interpolation. + yac_method : {'average'}, optional + YAC interpolation method for the bilinear convenience wrapper. + Only ``'average'`` is supported here. + yac_options : dict, optional + YAC interpolation configuration options for the average method. + Returns ------- UxDataArray or UxDataset A new object with data mapped onto `destination_grid`. """ + _validate_backend(backend) + if backend == "yac": + from uxarray.remap.yac import _yac_remap + + if yac_method not in (None, "average"): + raise ValueError( + "bilinear with backend='yac' only supports yac_method='average'. " + "Use .remap(..., backend='yac', yac_method=...) for other YAC methods." + ) + yac_kwargs = yac_options or {} + return _yac_remap( + self.ux_obj, + destination_grid, + remap_to, + yac_method or "average", + yac_kwargs, + ) return _bilinear(self.ux_obj, destination_grid, remap_to) diff --git a/uxarray/remap/yac.py b/uxarray/remap/yac.py new file mode 100644 index 000000000..4bbe3899b --- /dev/null +++ b/uxarray/remap/yac.py @@ -0,0 +1,407 @@ +from __future__ import annotations + +import importlib +import importlib.util +import sys +from dataclasses import dataclass +from pathlib import Path +from types import ModuleType +from typing import Any +from uuid import uuid4 + +import numpy as np + +import uxarray.core.dataarray +from uxarray.remap.utils import ( + LABEL_TO_COORD, + _assert_dimension, + _construct_remapped_ds, + _get_remap_dims, + _to_dataset, +) + + +class YacNotAvailableError(RuntimeError): + """Raised when the YAC backend is requested but unavailable.""" + + +@dataclass +class _YacOptions: + method: str + kwargs: dict[str, Any] + + +def _load_yac_core_from_file() -> ModuleType | None: + if "yac.core" in sys.modules: + return sys.modules["yac.core"] + + for path_entry in sys.path: + pkg_dir = Path(path_entry) / "yac" + if not pkg_dir.is_dir(): + continue + + matches = sorted(pkg_dir.glob("core*.so")) + if not matches: + matches = sorted(pkg_dir.glob("core*.pyd")) + if not matches: + continue + + pkg = sys.modules.get("yac") + if pkg is None: + pkg = ModuleType("yac") + sys.modules["yac"] = pkg + pkg.__path__ = [str(pkg_dir)] + + spec = importlib.util.spec_from_file_location("yac.core", matches[0]) + if spec is None or spec.loader is None: + continue + + module = importlib.util.module_from_spec(spec) + sys.modules["yac.core"] = module + spec.loader.exec_module(module) + setattr(pkg, "core", module) + return module + + return None + + +def _import_yac(): + module = _load_yac_core_from_file() + if module is not None: + return module + + try: + return importlib.import_module("yac.core") + except Exception as exc: # pragma: no cover - fallback depends on local install + raise YacNotAvailableError( + "YAC backend requested but 'yac.core' is not available. " + "Build YAC with Python bindings and ensure its site-packages and " + "shared libraries are discoverable." + ) from exc + + +def _normalize_yac_method(yac_method: str | None) -> _YacOptions: + if not yac_method: + raise ValueError( + "backend='yac' requires yac_method to be set to 'nnn', 'average', or 'conservative'." + ) + method = yac_method.lower() + if method not in {"nnn", "average", "conservative"}: + raise ValueError(f"Unsupported YAC method: {yac_method!r}") + return _YacOptions(method=method, kwargs={}) + + +def _get_location(yac_core, dim: str): + mapping = { + "n_face": yac_core.yac_location.YAC_LOC_CELL, + "n_node": yac_core.yac_location.YAC_LOC_CORNER, + "n_edge": yac_core.yac_location.YAC_LOC_EDGE, + } + try: + return mapping[dim] + except KeyError as exc: + raise ValueError(f"Unsupported remap dimension for YAC: {dim!r}") from exc + + +def _get_lon_lat(grid, dim: str) -> tuple[np.ndarray, np.ndarray]: + attr_map = { + "n_face": ("face_lon", "face_lat"), + "n_node": ("node_lon", "node_lat"), + "n_edge": ("edge_lon", "edge_lat"), + } + try: + lon_attr, lat_attr = attr_map[dim] + except KeyError as exc: + raise ValueError(f"Unsupported remap dimension for YAC: {dim!r}") from exc + + lon = getattr(grid, lon_attr, None) + lat = getattr(grid, lat_attr, None) + if lon is None or lat is None: + raise ValueError( + f"Grid does not provide {lon_attr}/{lat_attr} required for YAC remapping." + ) + return np.deg2rad(np.asarray(lon.values, dtype=np.float64)), np.deg2rad( + np.asarray(lat.values, dtype=np.float64) + ) + + +def _coerce_enum(enum_type, value: Any): + if not isinstance(value, str): + return value + + normalized = value.upper() + for member in enum_type: + if member.name == normalized or member.name.endswith(f"_{normalized}"): + return member + + raise ValueError(f"Unsupported value {value!r} for enum {enum_type.__name__}.") + + +class _YacRemapper: + """Build and reuse YAC interpolation weights for one source dimension. + + Each instance owns the YAC source/target field registration for a single + source location type (faces, nodes, or edges) and one requested YAC method. + The resulting weights can then be applied repeatedly to batches of values + that share the same source dimension. + """ + + def __init__( + self, + src_grid, + tgt_grid, + src_dim: str, + tgt_dim: str, + yac_method: str, + yac_kwargs: dict[str, Any], + ): + yac_core = _import_yac() + self._frac_mask_fallback_value = yac_kwargs.get("frac_mask_fallback_value") + self._src_location = _get_location(yac_core, src_dim) + self._tgt_location = _get_location(yac_core, tgt_dim) + + define_edges = "n_edge" in (src_dim, tgt_dim) + unique = uuid4().hex + self._src_grid = yac_core.BasicGrid.from_uxgrid( + f"uxarray_src_{unique}", + src_grid, + def_edges=define_edges, + ) + self._tgt_grid = yac_core.BasicGrid.from_uxgrid( + f"uxarray_tgt_{unique}", + tgt_grid, + def_edges=define_edges, + ) + src_lon, src_lat = _get_lon_lat(src_grid, src_dim) + tgt_lon, tgt_lat = _get_lon_lat(tgt_grid, tgt_dim) + + self._src_field = yac_core.InterpField( + self._src_grid.add_coordinates(self._src_location, src_lon, src_lat) + ) + self._tgt_field = yac_core.InterpField( + self._tgt_grid.add_coordinates(self._tgt_location, tgt_lon, tgt_lat) + ) + + stack = yac_core.InterpolationStack() + if yac_method == "nnn": + weight_type = _coerce_enum( + yac_core.yac_interp_nnn_weight_type, + yac_kwargs.get("reduction_type", yac_kwargs.get("nnn_type")), + ) + if weight_type is None: + weight_type = yac_core.yac_interp_nnn_weight_type.YAC_INTERP_NNN_AVG + stack.add_nnn( + nnn_type=weight_type, + n=yac_kwargs.get("n", 1), + max_search_distance=yac_kwargs.get("max_search_distance", 0.0), + scale=yac_kwargs.get("scale", 1.0), + ) + elif yac_method == "average": + reduction_type = _coerce_enum( + yac_core.yac_interp_avg_weight_type, + yac_kwargs.get("reduction_type", yac_kwargs.get("weight_type")), + ) + if reduction_type is None: + reduction_type = ( + yac_core.yac_interp_avg_weight_type.YAC_INTERP_AVG_ARITHMETIC + ) + stack.add_average( + reduction_type=reduction_type, + partial_coverage=yac_kwargs.get("partial_coverage", False), + ) + elif yac_method == "conservative": + normalisation = _coerce_enum( + yac_core.yac_interp_method_conserv_normalisation, + yac_kwargs.get("normalisation"), + ) + if normalisation is None: + normalisation = yac_core.yac_interp_method_conserv_normalisation.YAC_INTERP_CONSERV_DESTAREA + stack.add_conservative( + order=yac_kwargs.get("order", 1), + enforced_conserv=yac_kwargs.get("enforced_conserv", False), + partial_coverage=yac_kwargs.get("partial_coverage", False), + normalisation=normalisation, + ) + fixed_value = yac_kwargs.get("fixed_value", 0.0) + if fixed_value is not None: + stack.add_fixed(float(fixed_value)) + + self._weights = yac_core.compute_weights( + stack, + self._src_field, + self._tgt_field, + ) + self._interpolations: dict[int, Any] = {} + self._src_size = self._src_grid.get_data_size(self._src_location) + self._tgt_size = self._tgt_grid.get_data_size(self._tgt_location) + + def apply( + self, values: np.ndarray, frac_mask: np.ndarray | None = None + ) -> np.ndarray: + """Apply the pre-computed interpolation weights to *values*. + + The interpolation method (NNN or conservative) is determined by + *yac_method* passed to the constructor and is fixed for the lifetime of + this remapper instance. This method simply executes the weight + application; it does not select or alter the interpolation algorithm. + + Parameters + ---------- + values : np.ndarray + 1-D or 2-D array of source-grid values. The trailing dimension must + equal the number of source points registered with YAC + (``self._src_size``). When 2-D, the leading dimension is treated as + the YAC collection size and is remapped in one batched call. + frac_mask : np.ndarray, optional + Optional fractional source mask with the same shape as ``values``. + When provided, it is forwarded to YAC's interpolation call. + + Returns + ------- + np.ndarray + Array of remapped values on the destination grid with the same + number of leading collections as the input. + """ + values = np.ascontiguousarray(values, dtype=np.float64) + if values.ndim == 1: + values = values.reshape(1, -1) + elif values.ndim != 2: + raise ValueError( + f"YAC remap expects a 1-D or 2-D array, got {values.ndim}-D input." + ) + if values.shape[1] != self._src_size: + raise ValueError( + f"YAC remap expects {self._src_size} values, got {values.shape[1]}." + ) + + if frac_mask is not None: + frac_mask = np.ascontiguousarray(frac_mask, dtype=np.float64) + if frac_mask.ndim == 1: + frac_mask = frac_mask.reshape(1, -1) + elif frac_mask.ndim != 2: + raise ValueError( + "YAC fractional mask expects a 1-D or 2-D array, " + f"got {frac_mask.ndim}-D input." + ) + if frac_mask.shape != values.shape: + raise ValueError( + "YAC fractional mask must match remap input shape. " + f"Got mask shape {frac_mask.shape} and value shape {values.shape}." + ) + + collection_size = values.shape[0] + interpolation = self._interpolations.get(collection_size) + if interpolation is None: + interpolation = self._weights.get_interpolation( + collection_size=collection_size, + frac_mask_fallback_value=self._frac_mask_fallback_value, + ) + self._interpolations[collection_size] = interpolation + + out = ( + interpolation(values, frac_mask=frac_mask) + if frac_mask is not None + else interpolation(values) + ) + return np.asarray(out, dtype=np.float64) + + +def _prepare_frac_mask(frac_mask, da_t, src_values, src_dim: str) -> np.ndarray: + """Normalize a fractional mask to the flattened shape expected by YAC.""" + if hasattr(frac_mask, "dims"): + other_dims = [d for d in da_t.dims if d != src_dim] + frac_mask_values = np.asarray(frac_mask.transpose(*other_dims, src_dim).values) + else: + frac_mask_values = np.asarray(frac_mask) + + if frac_mask_values.shape != src_values.shape: + raise ValueError( + "YAC fractional mask must match the remapped source variable shape. " + f"Got mask shape {frac_mask_values.shape} and source shape {src_values.shape}." + ) + return frac_mask_values.reshape(-1, frac_mask_values.shape[-1]) + + +def _yac_remap(source, destination_grid, remap_to: str, yac_method: str, yac_kwargs): + """Remap a UXarray object through YAC and reconstruct the UXarray result. + + This is the main integration boundary between the public UXarray remap + accessor and the lower-level ``yac.core`` bindings. It normalizes the + requested YAC method, validates method-specific constraints, batches each + remapped variable by its source dimension, and returns a remapped + ``UxDataArray`` or ``UxDataset`` with UXarray metadata preserved. + """ + _assert_dimension(remap_to) + destination_dim = LABEL_TO_COORD[remap_to] + options = _normalize_yac_method(yac_method) + options.kwargs.update(yac_kwargs or {}) + ds, is_da, name = _to_dataset(source) + dims_to_remap = _get_remap_dims(ds) + + if options.method == "conservative": + if destination_dim != "n_face": + raise ValueError( + "YAC conservative remapping requires the destination to be " + "face-centered (remap_to='faces'). " + f"Got remap_to={remap_to!r} which maps to dimension {destination_dim!r}." + ) + non_face_src = dims_to_remap - {"n_face"} + if non_face_src: + raise ValueError( + "YAC conservative remapping requires all source data to be " + f"face-centered (dimension 'n_face'). " + f"Found non-face source dimension(s): {non_face_src}. " + "Use yac_method='nnn' for node- or edge-centered data." + ) + remappers: dict[str, _YacRemapper] = {} + remapped_vars = {} + + for src_dim in dims_to_remap: + remappers[src_dim] = _YacRemapper( + ds.uxgrid, + destination_grid, + src_dim, + destination_dim, + options.method, + options.kwargs, + ) + + for var_name, da in ds.data_vars.items(): + src_dim = next((d for d in da.dims if d in dims_to_remap), None) + if src_dim is None: + remapped_vars[var_name] = da + continue + + other_dims = [d for d in da.dims if d != src_dim] + da_t = da.transpose(*other_dims, src_dim) + src_values = np.asarray(da_t.values) + flat_src = src_values.reshape(-1, src_values.shape[-1]) + frac_masks = yac_kwargs.get("frac_masks") + frac_mask = ( + frac_masks.get(var_name) + if isinstance(frac_masks, dict) and var_name in frac_masks + else yac_kwargs.get("frac_mask") + ) + flat_frac_mask = None + if frac_mask is not None: + flat_frac_mask = _prepare_frac_mask(frac_mask, da_t, src_values, src_dim) + remapper = remappers[src_dim] + out_flat = remapper.apply(flat_src, frac_mask=flat_frac_mask) + + out_shape = src_values.shape[:-1] + (remapper._tgt_size,) + out_values = out_flat.reshape(out_shape) + coords = {dim: da.coords[dim] for dim in other_dims if dim in da.coords} + da_out = uxarray.core.dataarray.UxDataArray( + out_values, + dims=other_dims + [destination_dim], + coords=coords, + name=da.name, + attrs=da.attrs, + uxgrid=destination_grid, + ) + remapped_vars[var_name] = da_out + + ds_remapped = _construct_remapped_ds( + source, remapped_vars, destination_grid, remap_to + ) + return ds_remapped[name] if is_da else ds_remapped