From fdbd29f40098064559186698207c0b024b575a6a Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 17 Mar 2026 11:36:23 +0100 Subject: [PATCH 1/7] fix: legend visibility, style persistence, and axis ranges in combined figures Combined figures (overlay, add_secondary_y) had three issues: 1. Legends disappeared because Plotly Express sets showlegend=False on single-trace figures. Now unnamed traces get names derived from the source figure's y-axis title, and showlegend is fixed per legendgroup. 2. Colors and styles were lost during animation because frame traces carried PX defaults. Now marker, line, opacity and legend properties are propagated from fig.data into all animation frame traces. 3. Axis ranges were computed from fig.data only, so frames with different data ranges went off-screen during animation. Now global min/max is computed across all frames and set explicitly on the layout. Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_figures.py | 139 +++++++++++++++++++++++++ xarray_plotly/figures.py | 212 +++++++++++++++++++++++++++++++++++---- 2 files changed, 332 insertions(+), 19 deletions(-) diff --git a/tests/test_figures.py b/tests/test_figures.py index dacf19d..dc97699 100644 --- a/tests/test_figures.py +++ b/tests/test_figures.py @@ -616,3 +616,142 @@ def test_secondary_not_modified(self) -> None: # Secondary traces should still use original yaxis assert secondary.data[0].yaxis == original_yaxis + + +class TestLegendVisibility: + """Tests that combined figures preserve legend visibility.""" + + def test_overlay_single_trace_figures_with_names(self) -> None: + """Overlay of named single-trace figures shows legend.""" + da1 = xr.DataArray([1, 2, 3], dims=["x"], name="a") + da2 = xr.DataArray([4, 5, 6], dims=["x"], name="b") + + fig1 = xpx(da1).line() + fig1.update_traces(name="Series A") + fig2 = xpx(da2).line() + fig2.update_traces(name="Series B") + + combined = overlay(fig1, fig2) + + assert combined.data[0].showlegend is True + assert combined.data[1].showlegend is True + + def test_overlay_unnamed_traces_get_yaxis_title(self) -> None: + """Overlay of unnamed traces derives names from y-axis titles.""" + da1 = xr.DataArray([1, 2, 3], dims=["x"], name="Temperature") + da2 = xr.DataArray([4, 5, 6], dims=["x"], name="Pressure") + + fig1 = xpx(da1).line() + fig2 = xpx(da2).line() + + combined = overlay(fig1, fig2) + + # Names derived from y-axis titles (DataArray names) + assert combined.data[0].name == "Temperature" + assert combined.data[1].name == "Pressure" + assert combined.data[0].showlegend is True + assert combined.data[1].showlegend is True + + def test_overlay_same_name_disambiguated(self) -> None: + """Overlay of figures with same y-axis title gets numeric suffix.""" + da1 = xr.DataArray([1, 2, 3], dims=["x"], name="value") + da2 = xr.DataArray([4, 5, 6], dims=["x"], name="value") + + fig1 = xpx(da1).line() + fig2 = xpx(da2).line() + + combined = overlay(fig1, fig2) + + assert combined.data[0].name == "value (1)" + assert combined.data[1].name == "value (2)" + + def test_overlay_multi_trace_deduplicates_legend(self) -> None: + """Overlay of multi-trace figures deduplicates shared legendgroups.""" + da = xr.DataArray( + np.random.rand(10, 3), + dims=["x", "cat"], + coords={"cat": ["A", "B", "C"]}, + ) + fig1 = xpx(da).area() + fig2 = xpx(da).line() + + combined = overlay(fig1, fig2) + + # First occurrence of each legendgroup should show, duplicates hidden + from collections import defaultdict + + groups: dict[str, list[bool]] = defaultdict(list) + for trace in combined.data: + lg = trace.legendgroup + groups[lg].append(trace.showlegend is True) + + for lg, flags in groups.items(): + assert flags.count(True) == 1, f"legendgroup {lg!r} has {flags.count(True)} visible" + + def test_add_secondary_y_single_trace_with_names(self) -> None: + """add_secondary_y of named single-trace figures shows legend.""" + da1 = xr.DataArray([1, 2, 3], dims=["x"], name="temp") + da2 = xr.DataArray([100, 200, 300], dims=["x"], name="precip") + + fig1 = xpx(da1).line() + fig1.update_traces(name="Temperature") + fig2 = xpx(da2).bar() + fig2.update_traces(name="Precipitation") + + combined = add_secondary_y(fig1, fig2) + + assert combined.data[0].showlegend is True + assert combined.data[1].showlegend is True + + def test_overlay_faceted_legendgroup_dedup(self) -> None: + """Faceted overlay keeps only one showlegend=True per legendgroup.""" + da = xr.DataArray( + np.random.rand(10, 2, 2), + dims=["x", "cat", "facet"], + coords={"cat": ["A", "B"], "facet": ["left", "right"]}, + ) + fig1 = xpx(da).area(facet_col="facet") + fig2 = xpx(da).line(facet_col="facet") + + combined = overlay(fig1, fig2) + + # Check each legendgroup has at least one showlegend=True + from collections import defaultdict + + groups: dict[str, list[bool]] = defaultdict(list) + for trace in combined.data: + lg = trace.legendgroup or "" + if lg: + groups[lg].append(trace.showlegend is True) + + for lg, flags in groups.items(): + assert any(flags), f"legendgroup {lg!r} has no showlegend=True trace" + + def test_overlay_animation_frames_preserve_style(self) -> None: + """Animation frame traces keep legend and color from fig.data.""" + da = xr.DataArray( + np.random.rand(10, 3), + dims=["x", "time"], + coords={"time": [0, 1, 2]}, + name="Population", + ) + da_smooth = da.rolling(x=3, center=True).mean() + da_smooth.name = "Smoothed" + + fig1 = xpx(da).bar(animation_frame="time") + fig1.update_traces(marker={"color": "steelblue"}) + fig2 = xpx(da_smooth).line(animation_frame="time") + fig2.update_traces(line={"color": "red"}) + + combined = overlay(fig1, fig2) + + for frame in combined.frames: + for i, ft in enumerate(frame.data): + src = combined.data[i] + assert ft.name == src.name + assert ft.showlegend == src.showlegend + assert ft.legendgroup == src.legendgroup + # Bar trace should keep steelblue + assert frame.data[0].marker.color == "steelblue" + # Line trace should keep red + assert frame.data[1].line.color == "red" diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py index 6ed743d..06ec19d 100644 --- a/xarray_plotly/figures.py +++ b/xarray_plotly/figures.py @@ -13,6 +13,171 @@ import plotly.graph_objects as go +def _get_yaxis_title(fig: go.Figure) -> str: + """Extract the primary y-axis title text from a figure. + + Args: + fig: A Plotly figure. + + Returns: + The y-axis title text, or empty string if not set. + """ + try: + return fig.layout.yaxis.title.text or "" + except AttributeError: + return "" + + +def _ensure_legend_visibility( + combined: go.Figure, + source_figs: list[go.Figure], + trace_slices: list[slice], +) -> None: + """Fix legend visibility on a combined figure. + + Handles three problems that arise when combining Plotly Express figures: + + 1. **Unnamed traces** — PX sets ``name=""`` on single-trace (no color) + figures. We derive a name from each source figure's y-axis title. + 2. **Hidden named traces** — PX sets ``showlegend=False`` on single-trace + figures. We ensure at least one trace per ``legendgroup`` (or each + ungrouped named trace) has ``showlegend=True``. + 3. **Duplicate legend entries** — when two source figures share the same + ``legendgroup`` names, we deduplicate so only the first trace per + group shows in the legend. + + Args: + combined: The combined Plotly figure (mutated in place). + source_figs: The original source figures, in trace order. + trace_slices: Slices into ``combined.data`` for each source figure. + """ + from collections import defaultdict + + # --- Step 1: label unnamed traces from source y-axis titles ----------- + labels = [_get_yaxis_title(f) for f in source_figs] + + # If all labels are the same, disambiguate + unique_labels = {lb for lb in labels if lb} + if len(unique_labels) == 1: + labels = [f"{labels[0]} ({i + 1})" for i in range(len(labels))] + + for label, sl in zip(labels, trace_slices, strict=False): + if not label: + continue + for trace in combined.data[sl]: + if not getattr(trace, "name", None): + trace.name = label + trace.legendgroup = label + + # --- Step 2 & 3: fix showlegend per legendgroup ----------------------- + grouped: dict[str, list[Any]] = defaultdict(list) + ungrouped: list[Any] = [] + + for trace in combined.data: + lg = getattr(trace, "legendgroup", None) or "" + if lg: + grouped[lg].append(trace) + else: + ungrouped.append(trace) + + for traces in grouped.values(): + has_visible = False + for t in traces: + if has_visible: + # Deduplicate: only first keeps showlegend + t.showlegend = False + elif getattr(t, "name", None): + t.showlegend = True + has_visible = True + + # Ungrouped traces with a name should show in the legend + for trace in ungrouped: + if getattr(trace, "name", None): + trace.showlegend = True + + # --- Step 4: propagate style properties to animation frame traces ------ + # When Plotly animates, frame trace data overwrites fig.data properties. + # PX frame traces carry name="", showlegend=False and default colors, + # discarding any styling the user applied via update_traces() before + # combining. Propagate display properties from fig.data into every frame. + _STYLE_ATTRS = ("name", "legendgroup", "showlegend", "marker", "line", "opacity") + for frame in combined.frames or []: + for i, frame_trace in enumerate(frame.data): + if i < len(combined.data): + src = combined.data[i] + for attr in _STYLE_ATTRS: + src_val = getattr(src, attr, None) + if src_val is not None: + setattr(frame_trace, attr, src_val) + + +def _fix_animation_axis_ranges(fig: go.Figure) -> None: + """Set axis ranges to encompass data across all animation frames. + + Plotly.js computes autorange from ``fig.data`` only and does not + recalculate during animation. When different frames have very different + data ranges (e.g. population of Brazil vs China), values can go off-screen. + This function computes the global min/max for each axis across all frames + and sets explicit ranges on the layout. + + Only numeric axes are handled; categorical/date axes are left to autorange. + + Args: + fig: A Plotly figure with animation frames (mutated in place). + """ + import numpy as np + + if not fig.frames: + return + + from collections import defaultdict + + # Collect numeric y-values per axis across all traces (fig.data + frames) + y_by_axis: dict[str, list[float]] = defaultdict(list) + x_by_axis: dict[str, list[float]] = defaultdict(list) + + for trace in _iter_all_traces(fig): + yaxis = getattr(trace, "yaxis", None) or "y" + xaxis = getattr(trace, "xaxis", None) or "x" + + y = getattr(trace, "y", None) + if y is not None: + try: + arr = np.asarray(y, dtype=float) + finite = arr[np.isfinite(arr)] + if len(finite): + y_by_axis[yaxis].extend(finite.tolist()) + except (ValueError, TypeError): + pass # Non-numeric (categorical) — skip + + x = getattr(trace, "x", None) + if x is not None: + try: + arr = np.asarray(x, dtype=float) + finite = arr[np.isfinite(arr)] + if len(finite): + x_by_axis[xaxis].extend(finite.tolist()) + except (ValueError, TypeError): + pass + + # Apply ranges to layout + for axis_ref, values in y_by_axis.items(): + if not values: + continue + lo, hi = min(values), max(values) + pad = (hi - lo) * 0.05 or 1 # 5% padding + layout_prop = "yaxis" if axis_ref == "y" else f"yaxis{axis_ref[1:]}" + fig.layout[layout_prop].range = [lo - pad, hi + pad] + + for axis_ref, values in x_by_axis.items(): + if not values: + continue + lo, hi = min(values), max(values) + pad = (hi - lo) * 0.05 or 1 + layout_prop = "xaxis" if axis_ref == "x" else f"xaxis{axis_ref[1:]}" + fig.layout[layout_prop].range = [lo - pad, hi + pad] + + def _iter_all_traces(fig: go.Figure) -> Iterator[Any]: """Iterate over all traces in a figure, including animation frames. @@ -194,17 +359,11 @@ def overlay(base: go.Figure, *overlays: go.Figure) -> go.Figure: _validate_compatible_structure(base, overlay) _validate_animation_compatibility(base, overlay) - # Create new figure with base's layout - combined = go.Figure(layout=copy.deepcopy(base.layout)) - - # Add all traces from base - for trace in base.data: - combined.add_trace(copy.deepcopy(trace)) - - # Add all traces from overlays + # Create new figure with base's layout and all traces + all_traces = [copy.deepcopy(t) for t in base.data] for overlay in overlays: - for trace in overlay.data: - combined.add_trace(copy.deepcopy(trace)) + all_traces.extend(copy.deepcopy(t) for t in overlay.data) + combined = go.Figure(data=all_traces, layout=copy.deepcopy(base.layout)) # Handle animation frames if base.frames: @@ -213,6 +372,17 @@ def overlay(base: go.Figure, *overlays: go.Figure) -> go.Figure: merged_frames = _merge_frames(base, list(overlays), base_trace_count, overlay_trace_counts) combined.frames = merged_frames + # Build trace slices for legend fix + source_figs = [base, *overlays] + slices: list[slice] = [] + offset = 0 + for fig in source_figs: + n = len(fig.data) + slices.append(slice(offset, offset + n)) + offset += n + + _ensure_legend_visibility(combined, source_figs, slices) + _fix_animation_axis_ranges(combined) return combined @@ -315,19 +485,15 @@ def add_secondary_y( rightmost_x = max(x_for_y.values(), key=lambda x: int(x[1:]) if x != "x" else 1) rightmost_primary_y = next(y for y, x in x_for_y.items() if x == rightmost_x) - # Create new figure with base's layout - combined = go.Figure(layout=copy.deepcopy(base.layout)) - - # Add all traces from base (primary y-axis) - for trace in base.data: - combined.add_trace(copy.deepcopy(trace)) - - # Add all traces from secondary, remapped to secondary y-axes + # Build all traces: base (primary) + secondary (remapped to secondary y-axes) + all_traces = [copy.deepcopy(t) for t in base.data] for trace in secondary.data: trace_copy = copy.deepcopy(trace) original_yaxis = getattr(trace_copy, "yaxis", None) or "y" trace_copy.yaxis = y_mapping[original_yaxis] - combined.add_trace(trace_copy) + all_traces.append(trace_copy) + + combined = go.Figure(data=all_traces, layout=copy.deepcopy(base.layout)) # Get the rightmost secondary y-axis name for linking rightmost_secondary_y = y_mapping[rightmost_primary_y] @@ -368,6 +534,14 @@ def add_secondary_y( merged_frames = _merge_secondary_y_frames(base, secondary, y_mapping) combined.frames = merged_frames + base_n = len(base.data) + sec_n = len(secondary.data) + _ensure_legend_visibility( + combined, + [base, secondary], + [slice(0, base_n), slice(base_n, base_n + sec_n)], + ) + _fix_animation_axis_ranges(combined) return combined From c1d985cc55d9f0880c8930ce1caf93639f6e3453 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 17 Mar 2026 12:26:20 +0100 Subject: [PATCH 2/7] feat: add subplots() for composing figures into a grid New function to arrange independent figures in a subplot grid: grid = subplots(fig1, fig2, fig3, cols=2) - Subplot titles auto-derived from figure title or y-axis label - Axis config (titles, tick format, type) copied from source figures - Validates: rejects faceted or animated figures (not supported) - Empty cells via go.Figure() Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_figures.py | 131 +++++++++++++++++++++++++++++++++- xarray_plotly/__init__.py | 2 + xarray_plotly/figures.py | 145 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 276 insertions(+), 2 deletions(-) diff --git a/tests/test_figures.py b/tests/test_figures.py index dc97699..3e9b2b3 100644 --- a/tests/test_figures.py +++ b/tests/test_figures.py @@ -1,4 +1,4 @@ -"""Tests for the figures module (overlay, add_secondary_y).""" +"""Tests for the figures module (overlay, add_secondary_y, subplots).""" from __future__ import annotations @@ -9,7 +9,7 @@ import pytest import xarray as xr -from xarray_plotly import add_secondary_y, overlay, xpx +from xarray_plotly import add_secondary_y, overlay, subplots, xpx class TestOverlayBasic: @@ -755,3 +755,130 @@ def test_overlay_animation_frames_preserve_style(self) -> None: assert frame.data[0].marker.color == "steelblue" # Line trace should keep red assert frame.data[1].line.color == "red" + + +class TestSubplotsBasic: + """Basic tests for subplots function.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + self.da1 = xr.DataArray([1, 2, 3], dims=["x"], name="Temperature") + self.da2 = xr.DataArray([10, 20, 30], dims=["x"], name="Rainfall") + self.da3 = xr.DataArray([100, 200, 300], dims=["x"], name="Wind") + + def test_single_figure(self) -> None: + fig = xpx(self.da1).line() + grid = subplots(fig) + assert len(grid.data) == 1 + + def test_two_figures_one_column(self) -> None: + fig1 = xpx(self.da1).line() + fig2 = xpx(self.da2).bar() + grid = subplots(fig1, fig2, cols=1) + assert len(grid.data) == 2 + # Should be on different y-axes (different rows) + assert grid.data[0].yaxis != grid.data[1].yaxis + + def test_two_figures_two_columns(self) -> None: + fig1 = xpx(self.da1).line() + fig2 = xpx(self.da2).bar() + grid = subplots(fig1, fig2, cols=2) + assert len(grid.data) == 2 + assert grid.data[0].xaxis != grid.data[1].xaxis + + def test_three_figures_two_columns(self) -> None: + fig1 = xpx(self.da1).line() + fig2 = xpx(self.da2).bar() + fig3 = xpx(self.da3).scatter() + grid = subplots(fig1, fig2, fig3, cols=2) + assert len(grid.data) == 3 + + def test_trace_count_preserved(self) -> None: + da_multi = xr.DataArray( + np.random.rand(10, 3), + dims=["x", "cat"], + coords={"cat": ["A", "B", "C"]}, + ) + fig1 = xpx(da_multi).line() # 3 traces + fig2 = xpx(self.da1).bar() # 1 trace + grid = subplots(fig1, fig2, cols=2) + assert len(grid.data) == len(fig1.data) + len(fig2.data) + + def test_with_empty_figure(self) -> None: + fig1 = xpx(self.da1).line() + fig2 = go.Figure() + grid = subplots(fig1, fig2, cols=2) + assert len(grid.data) == 1 + + +class TestSubplotsTitles: + """Tests for subplot title derivation.""" + + def test_titles_from_figure_title(self) -> None: + da = xr.DataArray([1, 2, 3], dims=["x"], name="val") + fig1 = xpx(da).line(title="My Title") + fig2 = xpx(da).bar(title="Other Title") + grid = subplots(fig1, fig2, cols=2) + titles = [ann.text for ann in grid.layout.annotations] + assert titles == ["My Title", "Other Title"] + + def test_titles_from_yaxis_label(self) -> None: + da1 = xr.DataArray([1, 2, 3], dims=["x"], name="Temperature") + da2 = xr.DataArray([4, 5, 6], dims=["x"], name="Pressure") + fig1 = xpx(da1).line() + fig2 = xpx(da2).line() + grid = subplots(fig1, fig2, cols=2) + titles = [ann.text for ann in grid.layout.annotations] + assert titles == ["Temperature", "Pressure"] + + def test_titles_fallback_empty(self) -> None: + grid = subplots(go.Figure(), go.Figure(), cols=2) + # make_subplots omits annotations for empty titles + titles = [ann.text for ann in grid.layout.annotations] + assert titles == [] + + +class TestSubplotsAxisConfig: + """Tests for axis configuration copying.""" + + def test_axis_titles_copied(self) -> None: + da = xr.DataArray([1, 2, 3], dims=["time"], name="Temperature") + fig = xpx(da).line() + grid = subplots(fig) + assert grid.layout.yaxis.title.text == "Temperature" + assert grid.layout.xaxis.title.text == "time" + + +class TestSubplotsValidation: + """Tests for subplots input validation.""" + + def test_empty_raises(self) -> None: + with pytest.raises(ValueError, match="At least one figure"): + subplots() + + def test_invalid_cols_raises(self) -> None: + with pytest.raises(ValueError, match="cols must be >= 1"): + subplots(go.Figure(), cols=0) + + def test_faceted_figure_raises(self) -> None: + da = xr.DataArray( + np.random.rand(10, 3), + dims=["x", "facet"], + coords={"facet": ["A", "B", "C"]}, + ) + fig = xpx(da).line(facet_col="facet") + with pytest.raises(ValueError, match="internal subplots"): + subplots(fig) + + def test_animated_figure_raises(self) -> None: + da = xr.DataArray(np.random.rand(10, 3), dims=["x", "time"]) + fig = xpx(da).line(animation_frame="time") + with pytest.raises(ValueError, match="animation frames"): + subplots(fig) + + def test_source_not_modified(self) -> None: + da = xr.DataArray([1, 2, 3], dims=["x"], name="val") + fig = xpx(da).line() + original_count = len(fig.data) + _ = subplots(fig, fig, cols=2) + assert len(fig.data) == original_count diff --git a/xarray_plotly/__init__.py b/xarray_plotly/__init__.py index 7cefbb5..c526173 100644 --- a/xarray_plotly/__init__.py +++ b/xarray_plotly/__init__.py @@ -56,6 +56,7 @@ from xarray_plotly.figures import ( add_secondary_y, overlay, + subplots, update_traces, ) @@ -66,6 +67,7 @@ "auto", "config", "overlay", + "subplots", "update_traces", "xpx", ] diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py index 06ec19d..b18c4c6 100644 --- a/xarray_plotly/figures.py +++ b/xarray_plotly/figures.py @@ -600,6 +600,151 @@ def _merge_secondary_y_frames( return merged_frames +def _get_figure_title(fig: go.Figure) -> str: + """Extract a display title from a figure for use as a subplot title. + + Checks, in order: the figure's title, then the y-axis title. + + Args: + fig: A Plotly figure. + + Returns: + A title string, or empty string if nothing is set. + """ + try: + title = fig.layout.title.text + if title: + return title + except AttributeError: + pass + return _get_yaxis_title(fig) + + +def subplots(*figs: go.Figure, cols: int = 1) -> go.Figure: + """Arrange multiple figures into a subplot grid. + + Creates a new figure with each input figure placed in its own cell. + Subplot titles are derived from each figure's title or y-axis label. + + Args: + *figs: One or more Plotly figures to arrange. + cols: Number of columns in the grid. Rows are computed automatically. + + Returns: + A new figure with subplot grid. + + Raises: + ValueError: If no figures are provided, cols < 1, or a figure has + internal subplots (facets) or animation frames. + + Example: + >>> import numpy as np + >>> import xarray as xr + >>> from xarray_plotly import xpx, subplots + >>> + >>> temp = xr.DataArray([20, 22, 25], dims=["time"], name="Temperature") + >>> rain = xr.DataArray([0, 5, 12], dims=["time"], name="Rainfall") + >>> fig1 = xpx(temp).line() + >>> fig2 = xpx(rain).bar() + >>> grid = subplots(fig1, fig2, cols=2) + """ + import math + + from plotly.subplots import make_subplots + + if not figs: + raise ValueError("At least one figure is required.") + if cols < 1: + raise ValueError(f"cols must be >= 1, got {cols}.") + + # Validate inputs + for i, fig in enumerate(figs): + axes = _get_subplot_axes(fig) + if len(axes) > 1: + raise ValueError( + f"Figure at position {i} has internal subplots (facets). " + "Use single-panel figures with subplots()." + ) + if fig.frames: + raise ValueError( + f"Figure at position {i} has animation frames. " + "Animated figures are not supported in subplots()." + ) + + rows = math.ceil(len(figs) / cols) + + # Derive subplot titles + titles = [_get_figure_title(f) for f in figs] + # Pad for empty trailing cells + titles.extend("" for _ in range(rows * cols - len(figs))) + + grid = make_subplots(rows=rows, cols=cols, subplot_titles=titles) + + # Add traces from each figure to the correct cell + for i, fig in enumerate(figs): + row = i // cols + 1 + col = i % cols + 1 + + for trace in fig.data: + grid.add_trace(copy.deepcopy(trace), row=row, col=col) + + # Copy axis config from source figure to target cell + _copy_axis_config(fig, grid, row, col) + + return grid + + +# Axis properties safe to copy between figures (display-only, not structural). +_AXIS_PROPS_TO_COPY = ( + "title", + "type", + "tickformat", + "ticksuffix", + "tickprefix", + "dtick", + "tick0", + "nticks", + "showgrid", + "gridcolor", + "gridwidth", + "autorange", + "range", + "zeroline", + "zerolinecolor", + "zerolinewidth", +) + + +def _copy_axis_config(src: go.Figure, grid: go.Figure, row: int, col: int) -> None: + """Copy display-related axis properties from a source figure to a grid cell. + + Args: + src: Source figure whose axis config to copy. + grid: Target subplot grid figure. + row: Target row (1-indexed). + col: Target column (1-indexed). + """ + # Get the xaxis/yaxis objects for the target cell + xref, yref = grid.get_subplot(row, col) + + # Convert plotly axis objects to layout property names + # xref.plotly_name is e.g. "xaxis" or "xaxis2" + x_layout_key = xref.plotly_name + y_layout_key = yref.plotly_name + + src_xaxis = src.layout.xaxis or {} + src_yaxis = src.layout.yaxis or {} + + for prop in _AXIS_PROPS_TO_COPY: + xval = getattr(src_xaxis, prop, None) + if xval is not None: + grid.layout[x_layout_key][prop] = xval + + yval = getattr(src_yaxis, prop, None) + if yval is not None: + grid.layout[y_layout_key][prop] = yval + + def update_traces( fig: go.Figure, selector: dict[str, Any] | None = None, **kwargs: Any ) -> go.Figure: From dda84c778f0bf233387bd12f5466e3729021bfee Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 17 Mar 2026 12:28:49 +0100 Subject: [PATCH 3/7] docs: add subplots examples to combining notebook Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/examples/combining.ipynb | 163 ++++++++++++++++++++++++++++++---- 1 file changed, 147 insertions(+), 16 deletions(-) diff --git a/docs/examples/combining.ipynb b/docs/examples/combining.ipynb index 6755c20..36cc2fb 100644 --- a/docs/examples/combining.ipynb +++ b/docs/examples/combining.ipynb @@ -10,7 +10,9 @@ "xarray-plotly provides helper functions to combine multiple figures:\n", "\n", "- **`overlay`**: Overlay traces on the same axes\n", - "- **`add_secondary_y`**: Plot with two independent y-axes" + "- **`add_secondary_y`**: Plot with two independent y-axes\n", + "- **`subplots`**: Arrange independent figures in a grid\n", + "- **`slider_to_dropdown`**: Convert animation slider to a dropdown menu" ] }, { @@ -23,7 +25,7 @@ "import plotly.express as px\n", "import xarray as xr\n", "\n", - "from xarray_plotly import add_secondary_y, config, overlay, xpx\n", + "from xarray_plotly import add_secondary_y, config, overlay, subplots, xpx\n", "\n", "config.notebook()" ] @@ -440,6 +442,134 @@ "cell_type": "markdown", "id": "27", "metadata": {}, + "source": [ + "## subplots\n", + "\n", + "Arrange independent figures side-by-side in a grid. Each figure gets its own\n", + "subplot cell with axes and title automatically derived from the source figure." + ] + }, + { + "cell_type": "markdown", + "id": "28", + "metadata": {}, + "source": [ + "### Different Variables Side by Side" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29", + "metadata": {}, + "outputs": [], + "source": [ + "# One figure per variable, arranged in a row\n", + "us_pop = population.sel(country=\"United States\")\n", + "us_gdp = gdp_per_capita.sel(country=\"United States\")\n", + "us_life = life_expectancy.sel(country=\"United States\")\n", + "\n", + "pop_fig = xpx(us_pop).bar(title=\"Population\")\n", + "gdp_fig = xpx(us_gdp).line(title=\"GDP per Capita\")\n", + "life_fig = xpx(us_life).line(title=\"Life Expectancy\")\n", + "\n", + "grid = subplots(pop_fig, gdp_fig, life_fig, cols=3)\n", + "grid.update_layout(title=\"United States Overview\", height=350, showlegend=False)\n", + "grid" + ] + }, + { + "cell_type": "markdown", + "id": "30", + "metadata": {}, + "source": [ + "### 2x2 Grid\n", + "\n", + "Use `cols=2` and pass four figures for a 2x2 layout." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31", + "metadata": {}, + "outputs": [], + "source": [ + "# One subplot per country\n", + "fig_us = xpx(population.sel(country=\"United States\")).bar(title=\"United States\")\n", + "fig_cn = xpx(population.sel(country=\"China\")).bar(title=\"China\")\n", + "fig_de = xpx(population.sel(country=\"Germany\")).bar(title=\"Germany\")\n", + "fig_br = xpx(population.sel(country=\"Brazil\")).bar(title=\"Brazil\")\n", + "\n", + "grid = subplots(fig_us, fig_cn, fig_de, fig_br, cols=2)\n", + "grid.update_layout(height=500, showlegend=False)\n", + "grid" + ] + }, + { + "cell_type": "markdown", + "id": "32", + "metadata": {}, + "source": [ + "### Mixed Chart Types\n", + "\n", + "Each cell can use a different chart type. Subplot titles fall back to the\n", + "y-axis label when no explicit title is set." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33", + "metadata": {}, + "outputs": [], + "source": [ + "# No explicit title — subplot titles come from the y-axis label (DataArray name)\n", + "pop_bar = xpx(us_pop).bar()\n", + "gdp_line = xpx(us_gdp).line()\n", + "life_scatter = xpx(us_life).scatter()\n", + "\n", + "grid = subplots(pop_bar, gdp_line, life_scatter, cols=3)\n", + "grid.update_layout(height=350, showlegend=False)\n", + "grid" + ] + }, + { + "cell_type": "markdown", + "id": "34", + "metadata": {}, + "source": [ + "### Limitations\n", + "\n", + "`subplots` requires single-panel figures — faceted and animated figures are not supported." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35", + "metadata": {}, + "outputs": [], + "source": [ + "# Faceted figure → rejected\n", + "faceted = xpx(population).line(facet_col=\"country\")\n", + "try:\n", + " subplots(faceted)\n", + "except ValueError as e:\n", + " print(f\"ValueError: {e}\")\n", + "\n", + "# Animated figure → rejected\n", + "animated = xpx(population).bar(animation_frame=\"country\")\n", + "try:\n", + " subplots(animated)\n", + "except ValueError as e:\n", + " print(f\"ValueError: {e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "36", + "metadata": {}, "source": [ "---\n", "\n", @@ -450,7 +580,7 @@ }, { "cell_type": "markdown", - "id": "28", + "id": "37", "metadata": {}, "source": [ "### overlay: Mismatched Facet Structure\n", @@ -461,7 +591,7 @@ { "cell_type": "code", "execution_count": null, - "id": "29", + "id": "38", "metadata": {}, "outputs": [], "source": [ @@ -479,7 +609,7 @@ }, { "cell_type": "markdown", - "id": "30", + "id": "39", "metadata": {}, "source": [ "### overlay: Animated Overlay on Static Base\n", @@ -490,7 +620,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "40", "metadata": {}, "outputs": [], "source": [ @@ -508,7 +638,7 @@ }, { "cell_type": "markdown", - "id": "32", + "id": "41", "metadata": {}, "source": [ "### overlay: Mismatched Animation Frames\n", @@ -519,7 +649,7 @@ { "cell_type": "code", "execution_count": null, - "id": "33", + "id": "42", "metadata": {}, "outputs": [], "source": [ @@ -535,7 +665,7 @@ }, { "cell_type": "markdown", - "id": "34", + "id": "43", "metadata": {}, "source": [ "### add_secondary_y: Mismatched Facet Structure\n", @@ -546,7 +676,7 @@ { "cell_type": "code", "execution_count": null, - "id": "35", + "id": "44", "metadata": {}, "outputs": [], "source": [ @@ -564,7 +694,7 @@ }, { "cell_type": "markdown", - "id": "36", + "id": "45", "metadata": {}, "source": [ "### add_secondary_y: Animated Secondary on Static Base\n", @@ -575,7 +705,7 @@ { "cell_type": "code", "execution_count": null, - "id": "37", + "id": "46", "metadata": {}, "outputs": [], "source": [ @@ -593,7 +723,7 @@ }, { "cell_type": "markdown", - "id": "38", + "id": "47", "metadata": {}, "source": [ "### add_secondary_y: Mismatched Animation Frames" @@ -602,7 +732,7 @@ { "cell_type": "code", "execution_count": null, - "id": "39", + "id": "48", "metadata": {}, "outputs": [], "source": [ @@ -618,7 +748,7 @@ }, { "cell_type": "markdown", - "id": "40", + "id": "49", "metadata": {}, "source": [ "## Summary\n", @@ -626,7 +756,8 @@ "| Function | Facets | Animation | Static + Animated |\n", "|----------|--------|-----------|-------------------|\n", "| `overlay` | Yes (must match) | Yes (frames must match) | Static overlay on animated base OK |\n", - "| `add_secondary_y` | Yes (must match) | Yes (frames must match) | Static secondary on animated base OK |" + "| `add_secondary_y` | Yes (must match) | Yes (frames must match) | Static secondary on animated base OK |\n", + "| `subplots` | No (single-panel only) | No | N/A |" ] } ], From a9d624747c23e6916441cf040572ece67931245b Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 17 Mar 2026 14:15:30 +0100 Subject: [PATCH 4/7] feat: support faceted figures in subplots() Rewrote subplots() to use manual axis domain management instead of make_subplots. Each figure's internal axes are remapped with scaled domains to fit within the grid cell, so faceted figures now work. Updated notebook with faceted subplots example. Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/examples/combining.ipynb | 24 ++-- tests/test_figures.py | 18 ++- xarray_plotly/figures.py | 211 ++++++++++++++++++++++++++-------- 3 files changed, 187 insertions(+), 66 deletions(-) diff --git a/docs/examples/combining.ipynb b/docs/examples/combining.ipynb index 36cc2fb..12ceb81 100644 --- a/docs/examples/combining.ipynb +++ b/docs/examples/combining.ipynb @@ -539,9 +539,9 @@ "id": "34", "metadata": {}, "source": [ - "### Limitations\n", + "### With Facets\n", "\n", - "`subplots` requires single-panel figures — faceted and animated figures are not supported." + "Faceted figures can be composed — each figure's internal subplots are remapped into the grid cell." ] }, { @@ -551,19 +551,13 @@ "metadata": {}, "outputs": [], "source": [ - "# Faceted figure → rejected\n", - "faceted = xpx(population).line(facet_col=\"country\")\n", - "try:\n", - " subplots(faceted)\n", - "except ValueError as e:\n", - " print(f\"ValueError: {e}\")\n", + "# Faceted bar on top, faceted line below\n", + "pop_faceted = xpx(population).bar(facet_col=\"country\")\n", + "gdp_faceted = xpx(gdp_per_capita).line(facet_col=\"country\")\n", "\n", - "# Animated figure → rejected\n", - "animated = xpx(population).bar(animation_frame=\"country\")\n", - "try:\n", - " subplots(animated)\n", - "except ValueError as e:\n", - " print(f\"ValueError: {e}\")" + "grid = subplots(pop_faceted, gdp_faceted, cols=1)\n", + "grid.update_layout(height=600, showlegend=False)\n", + "grid" ] }, { @@ -757,7 +751,7 @@ "|----------|--------|-----------|-------------------|\n", "| `overlay` | Yes (must match) | Yes (frames must match) | Static overlay on animated base OK |\n", "| `add_secondary_y` | Yes (must match) | Yes (frames must match) | Static secondary on animated base OK |\n", - "| `subplots` | No (single-panel only) | No | N/A |" + "| `subplots` | Yes (remapped into cells) | No | N/A |" ] } ], diff --git a/tests/test_figures.py b/tests/test_figures.py index 3e9b2b3..529d101 100644 --- a/tests/test_figures.py +++ b/tests/test_figures.py @@ -820,7 +820,7 @@ def test_titles_from_figure_title(self) -> None: fig2 = xpx(da).bar(title="Other Title") grid = subplots(fig1, fig2, cols=2) titles = [ann.text for ann in grid.layout.annotations] - assert titles == ["My Title", "Other Title"] + assert titles == ["My Title", "Other Title"] def test_titles_from_yaxis_label(self) -> None: da1 = xr.DataArray([1, 2, 3], dims=["x"], name="Temperature") @@ -829,7 +829,7 @@ def test_titles_from_yaxis_label(self) -> None: fig2 = xpx(da2).line() grid = subplots(fig1, fig2, cols=2) titles = [ann.text for ann in grid.layout.annotations] - assert titles == ["Temperature", "Pressure"] + assert titles == ["Temperature", "Pressure"] def test_titles_fallback_empty(self) -> None: grid = subplots(go.Figure(), go.Figure(), cols=2) @@ -860,15 +860,21 @@ def test_invalid_cols_raises(self) -> None: with pytest.raises(ValueError, match="cols must be >= 1"): subplots(go.Figure(), cols=0) - def test_faceted_figure_raises(self) -> None: + def test_faceted_figures_stacked(self) -> None: + """Faceted figures can be stacked in a subplot grid.""" da = xr.DataArray( np.random.rand(10, 3), dims=["x", "facet"], coords={"facet": ["A", "B", "C"]}, ) - fig = xpx(da).line(facet_col="facet") - with pytest.raises(ValueError, match="internal subplots"): - subplots(fig) + fig1 = xpx(da).bar(facet_col="facet") + fig2 = xpx(da).line(facet_col="facet") + grid = subplots(fig1, fig2, cols=1) + # 3 bar traces + 3 line traces + assert len(grid.data) == 6 + # All traces should have unique axis assignments + axes = {(t.xaxis, t.yaxis) for t in grid.data} + assert len(axes) == 6 def test_animated_figure_raises(self) -> None: da = xr.DataArray(np.random.rand(10, 3), dims=["x", "time"]) diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py index b18c4c6..512b17e 100644 --- a/xarray_plotly/figures.py +++ b/xarray_plotly/figures.py @@ -624,7 +624,9 @@ def subplots(*figs: go.Figure, cols: int = 1) -> go.Figure: """Arrange multiple figures into a subplot grid. Creates a new figure with each input figure placed in its own cell. - Subplot titles are derived from each figure's title or y-axis label. + Figures may contain internal subplots (facets) — their axes are remapped + to fit within the grid cell. Subplot titles are derived from each + figure's title or y-axis label. Args: *figs: One or more Plotly figures to arrange. @@ -635,7 +637,7 @@ def subplots(*figs: go.Figure, cols: int = 1) -> go.Figure: Raises: ValueError: If no figures are provided, cols < 1, or a figure has - internal subplots (facets) or animation frames. + animation frames. Example: >>> import numpy as np @@ -650,21 +652,14 @@ def subplots(*figs: go.Figure, cols: int = 1) -> go.Figure: """ import math - from plotly.subplots import make_subplots + import plotly.graph_objects as go if not figs: raise ValueError("At least one figure is required.") if cols < 1: raise ValueError(f"cols must be >= 1, got {cols}.") - # Validate inputs for i, fig in enumerate(figs): - axes = _get_subplot_axes(fig) - if len(axes) > 1: - raise ValueError( - f"Figure at position {i} has internal subplots (facets). " - "Use single-panel figures with subplots()." - ) if fig.frames: raise ValueError( f"Figure at position {i} has animation frames. " @@ -672,26 +667,57 @@ def subplots(*figs: go.Figure, cols: int = 1) -> go.Figure: ) rows = math.ceil(len(figs) / cols) + combined = go.Figure() - # Derive subplot titles - titles = [_get_figure_title(f) for f in figs] - # Pad for empty trailing cells - titles.extend("" for _ in range(rows * cols - len(figs))) + # Grid spacing + h_gap = 0.05 + v_gap = 0.08 + cell_w = (1.0 - h_gap * (cols - 1)) / cols + cell_h = (1.0 - v_gap * (rows - 1)) / rows - grid = make_subplots(rows=rows, cols=cols, subplot_titles=titles) + next_x_num = 1 + next_y_num = 1 - # Add traces from each figure to the correct cell for i, fig in enumerate(figs): - row = i // cols + 1 - col = i % cols + 1 + row = i // cols # 0-indexed, top to bottom + col = i % cols + + # Cell boundaries (clamped to [0, 1]) + cell_x0 = max(0.0, col * (cell_w + h_gap)) + cell_x1 = min(1.0, cell_x0 + cell_w) + cell_y1 = min(1.0, 1.0 - row * (cell_h + v_gap)) # top-down + cell_y0 = max(0.0, cell_y1 - cell_h) + + # Build axis remapping: old axis ref → new axis ref + axis_map, next_x_num, next_y_num = _remap_figure_axes( + fig, combined, next_x_num, next_y_num, cell_x0, cell_x1, cell_y0, cell_y1 + ) + # Add traces with remapped axis refs for trace in fig.data: - grid.add_trace(copy.deepcopy(trace), row=row, col=col) - - # Copy axis config from source figure to target cell - _copy_axis_config(fig, grid, row, col) + tc = copy.deepcopy(trace) + old_x = getattr(tc, "xaxis", None) or "x" + old_y = getattr(tc, "yaxis", None) or "y" + tc.xaxis = axis_map[old_x]["new_x"] + tc.yaxis = axis_map[old_y]["new_y"] + combined.add_trace(tc) + + # Add subplot title as annotation + title = _get_figure_title(fig) + if title: + combined.add_annotation( + text=f"{title}", + x=(cell_x0 + cell_x1) / 2, + y=cell_y1, + xref="paper", + yref="paper", + xanchor="center", + yanchor="bottom", + showarrow=False, + font={"size": 14}, + ) - return grid + return combined # Axis properties safe to copy between figures (display-only, not structural). @@ -712,37 +738,132 @@ def subplots(*figs: go.Figure, cols: int = 1) -> go.Figure: "zeroline", "zerolinecolor", "zerolinewidth", + "showticklabels", ) -def _copy_axis_config(src: go.Figure, grid: go.Figure, row: int, col: int) -> None: - """Copy display-related axis properties from a source figure to a grid cell. +def _axis_layout_key(ref: str) -> str: + """Convert axis reference to layout property name. - Args: - src: Source figure whose axis config to copy. - grid: Target subplot grid figure. - row: Target row (1-indexed). - col: Target column (1-indexed). + ``"x"`` → ``"xaxis"``, ``"x2"`` → ``"xaxis2"``, + ``"y"`` → ``"yaxis"``, ``"y3"`` → ``"yaxis3"``. """ - # Get the xaxis/yaxis objects for the target cell - xref, yref = grid.get_subplot(row, col) + if ref in ("x", "y"): + return f"{ref}axis" + prefix = ref[0] # "x" or "y" + num = ref[1:] + return f"{prefix}axis{num}" - # Convert plotly axis objects to layout property names - # xref.plotly_name is e.g. "xaxis" or "xaxis2" - x_layout_key = xref.plotly_name - y_layout_key = yref.plotly_name - src_xaxis = src.layout.xaxis or {} - src_yaxis = src.layout.yaxis or {} +def _new_axis_ref(prefix: str, num: int) -> str: + """Build an axis reference string. ``_new_axis_ref("x", 1)`` → ``"x"``, ``("x", 3)`` → ``"x3"``.""" + return prefix if num == 1 else f"{prefix}{num}" - for prop in _AXIS_PROPS_TO_COPY: - xval = getattr(src_xaxis, prop, None) - if xval is not None: - grid.layout[x_layout_key][prop] = xval - yval = getattr(src_yaxis, prop, None) - if yval is not None: - grid.layout[y_layout_key][prop] = yval +def _remap_figure_axes( + fig: go.Figure, + combined: go.Figure, + next_x_num: int, + next_y_num: int, + cell_x0: float, + cell_x1: float, + cell_y0: float, + cell_y1: float, +) -> tuple[dict[str, dict[str, str]], int, int]: + """Remap a figure's axes into a grid cell, adding axis configs to the combined layout. + + Args: + fig: Source figure. + combined: Target combined figure (mutated — axis configs added to layout). + next_x_num: Next available x-axis number. + next_y_num: Next available y-axis number. + cell_x0, cell_x1: Horizontal cell bounds in paper coordinates. + cell_y0, cell_y1: Vertical cell bounds in paper coordinates. + + Returns: + Tuple of (axis_map, next_x_num, next_y_num). + axis_map maps old axis refs to ``{"new_x": ...}`` or ``{"new_y": ...}``. + """ + cell_w = cell_x1 - cell_x0 + cell_h = cell_y1 - cell_y0 + src_layout = fig.layout.to_plotly_json() + + x_remap: dict[str, str] = {} + y_remap: dict[str, str] = {} + + # Get all unique axis refs + x_refs: set[str] = set() + y_refs: set[str] = set() + for trace in fig.data: + x_refs.add(getattr(trace, "xaxis", None) or "x") + y_refs.add(getattr(trace, "yaxis", None) or "y") + + # Remap x-axes + for old_xref in sorted(x_refs, key=lambda r: int(r[1:]) if len(r) > 1 else 1): + new_xref = _new_axis_ref("x", next_x_num) + x_remap[old_xref] = new_xref + + src_config = src_layout.get(_axis_layout_key(old_xref), {}) + src_domain = src_config.get("domain", [0.0, 1.0]) + new_domain = [ + max(0.0, cell_x0 + src_domain[0] * cell_w), + min(1.0, cell_x0 + src_domain[1] * cell_w), + ] + + new_config: dict[str, Any] = {"domain": new_domain} + for prop in _AXIS_PROPS_TO_COPY: + if prop in src_config: + new_config[prop] = src_config[prop] + + combined.layout[_axis_layout_key(new_xref)] = new_config + next_x_num += 1 + + # Remap y-axes + for old_yref in sorted(y_refs, key=lambda r: int(r[1:]) if len(r) > 1 else 1): + new_yref = _new_axis_ref("y", next_y_num) + y_remap[old_yref] = new_yref + + src_config = src_layout.get(_axis_layout_key(old_yref), {}) + src_domain = src_config.get("domain", [0.0, 1.0]) + new_domain = [ + max(0.0, cell_y0 + src_domain[0] * cell_h), + min(1.0, cell_y0 + src_domain[1] * cell_h), + ] + + new_config = {"domain": new_domain} + for prop in _AXIS_PROPS_TO_COPY: + if prop in src_config: + new_config[prop] = src_config[prop] + + combined.layout[_axis_layout_key(new_yref)] = new_config + next_y_num += 1 + + # Set anchors between paired axes + for trace in fig.data: + old_x = getattr(trace, "xaxis", None) or "x" + old_y = getattr(trace, "yaxis", None) or "y" + combined.layout[_axis_layout_key(x_remap[old_x])]["anchor"] = y_remap[old_y] + combined.layout[_axis_layout_key(y_remap[old_y])]["anchor"] = x_remap[old_x] + + # Propagate matches relationships + for old_ref, new_ref in x_remap.items(): + src_config = src_layout.get(_axis_layout_key(old_ref), {}) + if "matches" in src_config and src_config["matches"] in x_remap: + combined.layout[_axis_layout_key(new_ref)]["matches"] = x_remap[src_config["matches"]] + + for old_ref, new_ref in y_remap.items(): + src_config = src_layout.get(_axis_layout_key(old_ref), {}) + if "matches" in src_config and src_config["matches"] in y_remap: + combined.layout[_axis_layout_key(new_ref)]["matches"] = y_remap[src_config["matches"]] + + # Build combined return mapping + result: dict[str, dict[str, str]] = {} + for old_x, new_x in x_remap.items(): + result[old_x] = {"new_x": new_x} + for old_y, new_y in y_remap.items(): + result[old_y] = {"new_y": new_y} + + return result, next_x_num, next_y_num def update_traces( From 0cb6e6ddc2fba3ba607bdb52919397b8187806f0 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 17 Mar 2026 16:02:41 +0100 Subject: [PATCH 5/7] =?UTF-8?q?fix:=20review=20findings=20=E2=80=94=20date?= =?UTF-8?q?time=20axis,=20bar=20baseline,=20type=20narrowing,=20stale=20bu?= =?UTF-8?q?llet?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Skip datetime64/timedelta64 axes in _fix_animation_axis_ranges to prevent float epoch corruption; leave them on autorange - Include zero in axis range for bar traces so bars grow from baseline - Use isinstance(str) check in _get_figure_title for mypy type narrowing - Remove stale slider_to_dropdown bullet from notebook intro Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/examples/combining.ipynb | 3 +-- tests/test_figures.py | 34 +++++++++++++++++++++++++ xarray_plotly/figures.py | 47 ++++++++++++++++++++++++----------- 3 files changed, 67 insertions(+), 17 deletions(-) diff --git a/docs/examples/combining.ipynb b/docs/examples/combining.ipynb index 12ceb81..a72d752 100644 --- a/docs/examples/combining.ipynb +++ b/docs/examples/combining.ipynb @@ -11,8 +11,7 @@ "\n", "- **`overlay`**: Overlay traces on the same axes\n", "- **`add_secondary_y`**: Plot with two independent y-axes\n", - "- **`subplots`**: Arrange independent figures in a grid\n", - "- **`slider_to_dropdown`**: Convert animation slider to a dropdown menu" + "- **`subplots`**: Arrange independent figures in a grid" ] }, { diff --git a/tests/test_figures.py b/tests/test_figures.py index 529d101..26143ea 100644 --- a/tests/test_figures.py +++ b/tests/test_figures.py @@ -757,6 +757,40 @@ def test_overlay_animation_frames_preserve_style(self) -> None: assert frame.data[1].line.color == "red" +class TestAnimationAxisRanges: + """Tests for _fix_animation_axis_ranges.""" + + def test_datetime_x_axis_not_corrupted(self) -> None: + """datetime64 x-axis should be left on autorange, not cast to float epochs.""" + dates = np.array(["2020-01-01", "2020-06-01", "2021-01-01"], dtype="datetime64[ns]") + da = xr.DataArray( + np.random.rand(3, 2), + dims=["date", "cat"], + coords={"date": dates, "cat": ["A", "B"]}, + name="value", + ) + fig1 = xpx(da).line(animation_frame="cat") + fig2 = xpx(da).scatter(animation_frame="cat") + combined = overlay(fig1, fig2) + + # x-axis range should NOT be set (dates left to autorange) + assert combined.layout.xaxis.range is None + + def test_bar_zero_baseline(self) -> None: + """Bar chart y-axis range should include zero.""" + da = xr.DataArray( + np.array([[100, 200], [150, 250]]), + dims=["x", "frame"], + name="val", + ) + fig = xpx(da).bar(animation_frame="frame") + # After overlay (which triggers _fix_animation_axis_ranges) + combined = overlay(fig, xpx(da).line(animation_frame="frame")) + + lo, hi = combined.layout.yaxis.range + assert lo <= 0, f"Bar y-axis range should include 0, got lo={lo}" + + class TestSubplotsBasic: """Basic tests for subplots function.""" diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py index 512b17e..3688ede 100644 --- a/xarray_plotly/figures.py +++ b/xarray_plotly/figures.py @@ -136,35 +136,49 @@ def _fix_animation_axis_ranges(fig: go.Figure) -> None: y_by_axis: dict[str, list[float]] = defaultdict(list) x_by_axis: dict[str, list[float]] = defaultdict(list) + # Track which axes have bar traces (for zero-baseline clamping) + y_has_vbar: set[str] = set() # vertical bars → y-axis includes 0 + x_has_hbar: set[str] = set() # horizontal bars → x-axis includes 0 + for trace in _iter_all_traces(fig): yaxis = getattr(trace, "yaxis", None) or "y" xaxis = getattr(trace, "xaxis", None) or "x" - y = getattr(trace, "y", None) - if y is not None: + # Track bar orientations + if getattr(trace, "type", None) == "bar": + orientation = getattr(trace, "orientation", None) or "v" + if orientation == "h": + x_has_hbar.add(xaxis) + else: + y_has_vbar.add(yaxis) + + for data_attr, axis_ref, by_axis in [ + ("y", yaxis, y_by_axis), + ("x", xaxis, x_by_axis), + ]: + vals = getattr(trace, data_attr, None) + if vals is None: + continue + arr = np.asarray(vals) + # Skip datetime/timedelta — leave those axes on autorange + if np.issubdtype(arr.dtype, np.datetime64) or np.issubdtype(arr.dtype, np.timedelta64): + continue try: - arr = np.asarray(y, dtype=float) + arr = arr.astype(float) finite = arr[np.isfinite(arr)] if len(finite): - y_by_axis[yaxis].extend(finite.tolist()) + by_axis[axis_ref].extend(finite.tolist()) except (ValueError, TypeError): pass # Non-numeric (categorical) — skip - x = getattr(trace, "x", None) - if x is not None: - try: - arr = np.asarray(x, dtype=float) - finite = arr[np.isfinite(arr)] - if len(finite): - x_by_axis[xaxis].extend(finite.tolist()) - except (ValueError, TypeError): - pass - # Apply ranges to layout for axis_ref, values in y_by_axis.items(): if not values: continue lo, hi = min(values), max(values) + if axis_ref in y_has_vbar: + lo = min(lo, 0.0) + hi = max(hi, 0.0) pad = (hi - lo) * 0.05 or 1 # 5% padding layout_prop = "yaxis" if axis_ref == "y" else f"yaxis{axis_ref[1:]}" fig.layout[layout_prop].range = [lo - pad, hi + pad] @@ -173,6 +187,9 @@ def _fix_animation_axis_ranges(fig: go.Figure) -> None: if not values: continue lo, hi = min(values), max(values) + if axis_ref in x_has_hbar: + lo = min(lo, 0.0) + hi = max(hi, 0.0) pad = (hi - lo) * 0.05 or 1 layout_prop = "xaxis" if axis_ref == "x" else f"xaxis{axis_ref[1:]}" fig.layout[layout_prop].range = [lo - pad, hi + pad] @@ -613,7 +630,7 @@ def _get_figure_title(fig: go.Figure) -> str: """ try: title = fig.layout.title.text - if title: + if isinstance(title, str) and title: return title except AttributeError: pass From 83ae5314418746b4eeeb7943829f99402cf62fc1 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 17 Mar 2026 16:33:08 +0100 Subject: [PATCH 6/7] fix: disambiguation condition and stale comment - Only disambiguate labels when all are non-empty and identical, preventing spurious suffixes when mixing named and unnamed figures - Update stale make_subplots comment in test Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_figures.py | 2 +- xarray_plotly/figures.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_figures.py b/tests/test_figures.py index 26143ea..9639285 100644 --- a/tests/test_figures.py +++ b/tests/test_figures.py @@ -867,7 +867,7 @@ def test_titles_from_yaxis_label(self) -> None: def test_titles_fallback_empty(self) -> None: grid = subplots(go.Figure(), go.Figure(), cols=2) - # make_subplots omits annotations for empty titles + # No annotations are created for empty titles titles = [ann.text for ann in grid.layout.annotations] assert titles == [] diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py index 3688ede..8a4bf00 100644 --- a/xarray_plotly/figures.py +++ b/xarray_plotly/figures.py @@ -56,9 +56,9 @@ def _ensure_legend_visibility( # --- Step 1: label unnamed traces from source y-axis titles ----------- labels = [_get_yaxis_title(f) for f in source_figs] - # If all labels are the same, disambiguate + # If all labels are non-empty and identical, disambiguate unique_labels = {lb for lb in labels if lb} - if len(unique_labels) == 1: + if len(unique_labels) == 1 and all(lb for lb in labels): labels = [f"{labels[0]} ({i + 1})" for i in range(len(labels))] for label, sl in zip(labels, trace_slices, strict=False): From f06987e9e9ba5311bb67367eb056725fa31709bb Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 17 Mar 2026 16:39:22 +0100 Subject: [PATCH 7/7] fix: unused variable lint error in test Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_figures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_figures.py b/tests/test_figures.py index 9639285..254a3bd 100644 --- a/tests/test_figures.py +++ b/tests/test_figures.py @@ -787,7 +787,7 @@ def test_bar_zero_baseline(self) -> None: # After overlay (which triggers _fix_animation_axis_ranges) combined = overlay(fig, xpx(da).line(animation_frame="frame")) - lo, hi = combined.layout.yaxis.range + lo, _hi = combined.layout.yaxis.range assert lo <= 0, f"Bar y-axis range should include 0, got lo={lo}"