diff --git a/docs/examples/combining.ipynb b/docs/examples/combining.ipynb index 6755c20..a72d752 100644 --- a/docs/examples/combining.ipynb +++ b/docs/examples/combining.ipynb @@ -10,7 +10,8 @@ "xarray-plotly provides helper functions to combine multiple figures:\n", "\n", "- **`overlay`**: Overlay traces on the same axes\n", - "- **`add_secondary_y`**: Plot with two independent y-axes" + "- **`add_secondary_y`**: Plot with two independent y-axes\n", + "- **`subplots`**: Arrange independent figures in a grid" ] }, { @@ -23,7 +24,7 @@ "import plotly.express as px\n", "import xarray as xr\n", "\n", - "from xarray_plotly import add_secondary_y, config, overlay, xpx\n", + "from xarray_plotly import add_secondary_y, config, overlay, subplots, xpx\n", "\n", "config.notebook()" ] @@ -440,6 +441,128 @@ "cell_type": "markdown", "id": "27", "metadata": {}, + "source": [ + "## subplots\n", + "\n", + "Arrange independent figures side-by-side in a grid. Each figure gets its own\n", + "subplot cell with axes and title automatically derived from the source figure." + ] + }, + { + "cell_type": "markdown", + "id": "28", + "metadata": {}, + "source": [ + "### Different Variables Side by Side" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29", + "metadata": {}, + "outputs": [], + "source": [ + "# One figure per variable, arranged in a row\n", + "us_pop = population.sel(country=\"United States\")\n", + "us_gdp = gdp_per_capita.sel(country=\"United States\")\n", + "us_life = life_expectancy.sel(country=\"United States\")\n", + "\n", + "pop_fig = xpx(us_pop).bar(title=\"Population\")\n", + "gdp_fig = xpx(us_gdp).line(title=\"GDP per Capita\")\n", + "life_fig = xpx(us_life).line(title=\"Life Expectancy\")\n", + "\n", + "grid = subplots(pop_fig, gdp_fig, life_fig, cols=3)\n", + "grid.update_layout(title=\"United States Overview\", height=350, showlegend=False)\n", + "grid" + ] + }, + { + "cell_type": "markdown", + "id": "30", + "metadata": {}, + "source": [ + "### 2x2 Grid\n", + "\n", + "Use `cols=2` and pass four figures for a 2x2 layout." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31", + "metadata": {}, + "outputs": [], + "source": [ + "# One subplot per country\n", + "fig_us = xpx(population.sel(country=\"United States\")).bar(title=\"United States\")\n", + "fig_cn = xpx(population.sel(country=\"China\")).bar(title=\"China\")\n", + "fig_de = xpx(population.sel(country=\"Germany\")).bar(title=\"Germany\")\n", + "fig_br = xpx(population.sel(country=\"Brazil\")).bar(title=\"Brazil\")\n", + "\n", + "grid = subplots(fig_us, fig_cn, fig_de, fig_br, cols=2)\n", + "grid.update_layout(height=500, showlegend=False)\n", + "grid" + ] + }, + { + "cell_type": "markdown", + "id": "32", + "metadata": {}, + "source": [ + "### Mixed Chart Types\n", + "\n", + "Each cell can use a different chart type. Subplot titles fall back to the\n", + "y-axis label when no explicit title is set." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33", + "metadata": {}, + "outputs": [], + "source": [ + "# No explicit title — subplot titles come from the y-axis label (DataArray name)\n", + "pop_bar = xpx(us_pop).bar()\n", + "gdp_line = xpx(us_gdp).line()\n", + "life_scatter = xpx(us_life).scatter()\n", + "\n", + "grid = subplots(pop_bar, gdp_line, life_scatter, cols=3)\n", + "grid.update_layout(height=350, showlegend=False)\n", + "grid" + ] + }, + { + "cell_type": "markdown", + "id": "34", + "metadata": {}, + "source": [ + "### With Facets\n", + "\n", + "Faceted figures can be composed — each figure's internal subplots are remapped into the grid cell." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35", + "metadata": {}, + "outputs": [], + "source": [ + "# Faceted bar on top, faceted line below\n", + "pop_faceted = xpx(population).bar(facet_col=\"country\")\n", + "gdp_faceted = xpx(gdp_per_capita).line(facet_col=\"country\")\n", + "\n", + "grid = subplots(pop_faceted, gdp_faceted, cols=1)\n", + "grid.update_layout(height=600, showlegend=False)\n", + "grid" + ] + }, + { + "cell_type": "markdown", + "id": "36", + "metadata": {}, "source": [ "---\n", "\n", @@ -450,7 +573,7 @@ }, { "cell_type": "markdown", - "id": "28", + "id": "37", "metadata": {}, "source": [ "### overlay: Mismatched Facet Structure\n", @@ -461,7 +584,7 @@ { "cell_type": "code", "execution_count": null, - "id": "29", + "id": "38", "metadata": {}, "outputs": [], "source": [ @@ -479,7 +602,7 @@ }, { "cell_type": "markdown", - "id": "30", + "id": "39", "metadata": {}, "source": [ "### overlay: Animated Overlay on Static Base\n", @@ -490,7 +613,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "40", "metadata": {}, "outputs": [], "source": [ @@ -508,7 +631,7 @@ }, { "cell_type": "markdown", - "id": "32", + "id": "41", "metadata": {}, "source": [ "### overlay: Mismatched Animation Frames\n", @@ -519,7 +642,7 @@ { "cell_type": "code", "execution_count": null, - "id": "33", + "id": "42", "metadata": {}, "outputs": [], "source": [ @@ -535,7 +658,7 @@ }, { "cell_type": "markdown", - "id": "34", + "id": "43", "metadata": {}, "source": [ "### add_secondary_y: Mismatched Facet Structure\n", @@ -546,7 +669,7 @@ { "cell_type": "code", "execution_count": null, - "id": "35", + "id": "44", "metadata": {}, "outputs": [], "source": [ @@ -564,7 +687,7 @@ }, { "cell_type": "markdown", - "id": "36", + "id": "45", "metadata": {}, "source": [ "### add_secondary_y: Animated Secondary on Static Base\n", @@ -575,7 +698,7 @@ { "cell_type": "code", "execution_count": null, - "id": "37", + "id": "46", "metadata": {}, "outputs": [], "source": [ @@ -593,7 +716,7 @@ }, { "cell_type": "markdown", - "id": "38", + "id": "47", "metadata": {}, "source": [ "### add_secondary_y: Mismatched Animation Frames" @@ -602,7 +725,7 @@ { "cell_type": "code", "execution_count": null, - "id": "39", + "id": "48", "metadata": {}, "outputs": [], "source": [ @@ -618,7 +741,7 @@ }, { "cell_type": "markdown", - "id": "40", + "id": "49", "metadata": {}, "source": [ "## Summary\n", @@ -626,7 +749,8 @@ "| Function | Facets | Animation | Static + Animated |\n", "|----------|--------|-----------|-------------------|\n", "| `overlay` | Yes (must match) | Yes (frames must match) | Static overlay on animated base OK |\n", - "| `add_secondary_y` | Yes (must match) | Yes (frames must match) | Static secondary on animated base OK |" + "| `add_secondary_y` | Yes (must match) | Yes (frames must match) | Static secondary on animated base OK |\n", + "| `subplots` | Yes (remapped into cells) | No | N/A |" ] } ], diff --git a/tests/test_figures.py b/tests/test_figures.py index dacf19d..254a3bd 100644 --- a/tests/test_figures.py +++ b/tests/test_figures.py @@ -1,4 +1,4 @@ -"""Tests for the figures module (overlay, add_secondary_y).""" +"""Tests for the figures module (overlay, add_secondary_y, subplots).""" from __future__ import annotations @@ -9,7 +9,7 @@ import pytest import xarray as xr -from xarray_plotly import add_secondary_y, overlay, xpx +from xarray_plotly import add_secondary_y, overlay, subplots, xpx class TestOverlayBasic: @@ -616,3 +616,309 @@ def test_secondary_not_modified(self) -> None: # Secondary traces should still use original yaxis assert secondary.data[0].yaxis == original_yaxis + + +class TestLegendVisibility: + """Tests that combined figures preserve legend visibility.""" + + def test_overlay_single_trace_figures_with_names(self) -> None: + """Overlay of named single-trace figures shows legend.""" + da1 = xr.DataArray([1, 2, 3], dims=["x"], name="a") + da2 = xr.DataArray([4, 5, 6], dims=["x"], name="b") + + fig1 = xpx(da1).line() + fig1.update_traces(name="Series A") + fig2 = xpx(da2).line() + fig2.update_traces(name="Series B") + + combined = overlay(fig1, fig2) + + assert combined.data[0].showlegend is True + assert combined.data[1].showlegend is True + + def test_overlay_unnamed_traces_get_yaxis_title(self) -> None: + """Overlay of unnamed traces derives names from y-axis titles.""" + da1 = xr.DataArray([1, 2, 3], dims=["x"], name="Temperature") + da2 = xr.DataArray([4, 5, 6], dims=["x"], name="Pressure") + + fig1 = xpx(da1).line() + fig2 = xpx(da2).line() + + combined = overlay(fig1, fig2) + + # Names derived from y-axis titles (DataArray names) + assert combined.data[0].name == "Temperature" + assert combined.data[1].name == "Pressure" + assert combined.data[0].showlegend is True + assert combined.data[1].showlegend is True + + def test_overlay_same_name_disambiguated(self) -> None: + """Overlay of figures with same y-axis title gets numeric suffix.""" + da1 = xr.DataArray([1, 2, 3], dims=["x"], name="value") + da2 = xr.DataArray([4, 5, 6], dims=["x"], name="value") + + fig1 = xpx(da1).line() + fig2 = xpx(da2).line() + + combined = overlay(fig1, fig2) + + assert combined.data[0].name == "value (1)" + assert combined.data[1].name == "value (2)" + + def test_overlay_multi_trace_deduplicates_legend(self) -> None: + """Overlay of multi-trace figures deduplicates shared legendgroups.""" + da = xr.DataArray( + np.random.rand(10, 3), + dims=["x", "cat"], + coords={"cat": ["A", "B", "C"]}, + ) + fig1 = xpx(da).area() + fig2 = xpx(da).line() + + combined = overlay(fig1, fig2) + + # First occurrence of each legendgroup should show, duplicates hidden + from collections import defaultdict + + groups: dict[str, list[bool]] = defaultdict(list) + for trace in combined.data: + lg = trace.legendgroup + groups[lg].append(trace.showlegend is True) + + for lg, flags in groups.items(): + assert flags.count(True) == 1, f"legendgroup {lg!r} has {flags.count(True)} visible" + + def test_add_secondary_y_single_trace_with_names(self) -> None: + """add_secondary_y of named single-trace figures shows legend.""" + da1 = xr.DataArray([1, 2, 3], dims=["x"], name="temp") + da2 = xr.DataArray([100, 200, 300], dims=["x"], name="precip") + + fig1 = xpx(da1).line() + fig1.update_traces(name="Temperature") + fig2 = xpx(da2).bar() + fig2.update_traces(name="Precipitation") + + combined = add_secondary_y(fig1, fig2) + + assert combined.data[0].showlegend is True + assert combined.data[1].showlegend is True + + def test_overlay_faceted_legendgroup_dedup(self) -> None: + """Faceted overlay keeps only one showlegend=True per legendgroup.""" + da = xr.DataArray( + np.random.rand(10, 2, 2), + dims=["x", "cat", "facet"], + coords={"cat": ["A", "B"], "facet": ["left", "right"]}, + ) + fig1 = xpx(da).area(facet_col="facet") + fig2 = xpx(da).line(facet_col="facet") + + combined = overlay(fig1, fig2) + + # Check each legendgroup has at least one showlegend=True + from collections import defaultdict + + groups: dict[str, list[bool]] = defaultdict(list) + for trace in combined.data: + lg = trace.legendgroup or "" + if lg: + groups[lg].append(trace.showlegend is True) + + for lg, flags in groups.items(): + assert any(flags), f"legendgroup {lg!r} has no showlegend=True trace" + + def test_overlay_animation_frames_preserve_style(self) -> None: + """Animation frame traces keep legend and color from fig.data.""" + da = xr.DataArray( + np.random.rand(10, 3), + dims=["x", "time"], + coords={"time": [0, 1, 2]}, + name="Population", + ) + da_smooth = da.rolling(x=3, center=True).mean() + da_smooth.name = "Smoothed" + + fig1 = xpx(da).bar(animation_frame="time") + fig1.update_traces(marker={"color": "steelblue"}) + fig2 = xpx(da_smooth).line(animation_frame="time") + fig2.update_traces(line={"color": "red"}) + + combined = overlay(fig1, fig2) + + for frame in combined.frames: + for i, ft in enumerate(frame.data): + src = combined.data[i] + assert ft.name == src.name + assert ft.showlegend == src.showlegend + assert ft.legendgroup == src.legendgroup + # Bar trace should keep steelblue + assert frame.data[0].marker.color == "steelblue" + # Line trace should keep red + assert frame.data[1].line.color == "red" + + +class TestAnimationAxisRanges: + """Tests for _fix_animation_axis_ranges.""" + + def test_datetime_x_axis_not_corrupted(self) -> None: + """datetime64 x-axis should be left on autorange, not cast to float epochs.""" + dates = np.array(["2020-01-01", "2020-06-01", "2021-01-01"], dtype="datetime64[ns]") + da = xr.DataArray( + np.random.rand(3, 2), + dims=["date", "cat"], + coords={"date": dates, "cat": ["A", "B"]}, + name="value", + ) + fig1 = xpx(da).line(animation_frame="cat") + fig2 = xpx(da).scatter(animation_frame="cat") + combined = overlay(fig1, fig2) + + # x-axis range should NOT be set (dates left to autorange) + assert combined.layout.xaxis.range is None + + def test_bar_zero_baseline(self) -> None: + """Bar chart y-axis range should include zero.""" + da = xr.DataArray( + np.array([[100, 200], [150, 250]]), + dims=["x", "frame"], + name="val", + ) + fig = xpx(da).bar(animation_frame="frame") + # After overlay (which triggers _fix_animation_axis_ranges) + combined = overlay(fig, xpx(da).line(animation_frame="frame")) + + lo, _hi = combined.layout.yaxis.range + assert lo <= 0, f"Bar y-axis range should include 0, got lo={lo}" + + +class TestSubplotsBasic: + """Basic tests for subplots function.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + self.da1 = xr.DataArray([1, 2, 3], dims=["x"], name="Temperature") + self.da2 = xr.DataArray([10, 20, 30], dims=["x"], name="Rainfall") + self.da3 = xr.DataArray([100, 200, 300], dims=["x"], name="Wind") + + def test_single_figure(self) -> None: + fig = xpx(self.da1).line() + grid = subplots(fig) + assert len(grid.data) == 1 + + def test_two_figures_one_column(self) -> None: + fig1 = xpx(self.da1).line() + fig2 = xpx(self.da2).bar() + grid = subplots(fig1, fig2, cols=1) + assert len(grid.data) == 2 + # Should be on different y-axes (different rows) + assert grid.data[0].yaxis != grid.data[1].yaxis + + def test_two_figures_two_columns(self) -> None: + fig1 = xpx(self.da1).line() + fig2 = xpx(self.da2).bar() + grid = subplots(fig1, fig2, cols=2) + assert len(grid.data) == 2 + assert grid.data[0].xaxis != grid.data[1].xaxis + + def test_three_figures_two_columns(self) -> None: + fig1 = xpx(self.da1).line() + fig2 = xpx(self.da2).bar() + fig3 = xpx(self.da3).scatter() + grid = subplots(fig1, fig2, fig3, cols=2) + assert len(grid.data) == 3 + + def test_trace_count_preserved(self) -> None: + da_multi = xr.DataArray( + np.random.rand(10, 3), + dims=["x", "cat"], + coords={"cat": ["A", "B", "C"]}, + ) + fig1 = xpx(da_multi).line() # 3 traces + fig2 = xpx(self.da1).bar() # 1 trace + grid = subplots(fig1, fig2, cols=2) + assert len(grid.data) == len(fig1.data) + len(fig2.data) + + def test_with_empty_figure(self) -> None: + fig1 = xpx(self.da1).line() + fig2 = go.Figure() + grid = subplots(fig1, fig2, cols=2) + assert len(grid.data) == 1 + + +class TestSubplotsTitles: + """Tests for subplot title derivation.""" + + def test_titles_from_figure_title(self) -> None: + da = xr.DataArray([1, 2, 3], dims=["x"], name="val") + fig1 = xpx(da).line(title="My Title") + fig2 = xpx(da).bar(title="Other Title") + grid = subplots(fig1, fig2, cols=2) + titles = [ann.text for ann in grid.layout.annotations] + assert titles == ["My Title", "Other Title"] + + def test_titles_from_yaxis_label(self) -> None: + da1 = xr.DataArray([1, 2, 3], dims=["x"], name="Temperature") + da2 = xr.DataArray([4, 5, 6], dims=["x"], name="Pressure") + fig1 = xpx(da1).line() + fig2 = xpx(da2).line() + grid = subplots(fig1, fig2, cols=2) + titles = [ann.text for ann in grid.layout.annotations] + assert titles == ["Temperature", "Pressure"] + + def test_titles_fallback_empty(self) -> None: + grid = subplots(go.Figure(), go.Figure(), cols=2) + # No annotations are created for empty titles + titles = [ann.text for ann in grid.layout.annotations] + assert titles == [] + + +class TestSubplotsAxisConfig: + """Tests for axis configuration copying.""" + + def test_axis_titles_copied(self) -> None: + da = xr.DataArray([1, 2, 3], dims=["time"], name="Temperature") + fig = xpx(da).line() + grid = subplots(fig) + assert grid.layout.yaxis.title.text == "Temperature" + assert grid.layout.xaxis.title.text == "time" + + +class TestSubplotsValidation: + """Tests for subplots input validation.""" + + def test_empty_raises(self) -> None: + with pytest.raises(ValueError, match="At least one figure"): + subplots() + + def test_invalid_cols_raises(self) -> None: + with pytest.raises(ValueError, match="cols must be >= 1"): + subplots(go.Figure(), cols=0) + + def test_faceted_figures_stacked(self) -> None: + """Faceted figures can be stacked in a subplot grid.""" + da = xr.DataArray( + np.random.rand(10, 3), + dims=["x", "facet"], + coords={"facet": ["A", "B", "C"]}, + ) + fig1 = xpx(da).bar(facet_col="facet") + fig2 = xpx(da).line(facet_col="facet") + grid = subplots(fig1, fig2, cols=1) + # 3 bar traces + 3 line traces + assert len(grid.data) == 6 + # All traces should have unique axis assignments + axes = {(t.xaxis, t.yaxis) for t in grid.data} + assert len(axes) == 6 + + def test_animated_figure_raises(self) -> None: + da = xr.DataArray(np.random.rand(10, 3), dims=["x", "time"]) + fig = xpx(da).line(animation_frame="time") + with pytest.raises(ValueError, match="animation frames"): + subplots(fig) + + def test_source_not_modified(self) -> None: + da = xr.DataArray([1, 2, 3], dims=["x"], name="val") + fig = xpx(da).line() + original_count = len(fig.data) + _ = subplots(fig, fig, cols=2) + assert len(fig.data) == original_count diff --git a/xarray_plotly/__init__.py b/xarray_plotly/__init__.py index 7cefbb5..c526173 100644 --- a/xarray_plotly/__init__.py +++ b/xarray_plotly/__init__.py @@ -56,6 +56,7 @@ from xarray_plotly.figures import ( add_secondary_y, overlay, + subplots, update_traces, ) @@ -66,6 +67,7 @@ "auto", "config", "overlay", + "subplots", "update_traces", "xpx", ] diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py index 6ed743d..8a4bf00 100644 --- a/xarray_plotly/figures.py +++ b/xarray_plotly/figures.py @@ -13,6 +13,188 @@ import plotly.graph_objects as go +def _get_yaxis_title(fig: go.Figure) -> str: + """Extract the primary y-axis title text from a figure. + + Args: + fig: A Plotly figure. + + Returns: + The y-axis title text, or empty string if not set. + """ + try: + return fig.layout.yaxis.title.text or "" + except AttributeError: + return "" + + +def _ensure_legend_visibility( + combined: go.Figure, + source_figs: list[go.Figure], + trace_slices: list[slice], +) -> None: + """Fix legend visibility on a combined figure. + + Handles three problems that arise when combining Plotly Express figures: + + 1. **Unnamed traces** — PX sets ``name=""`` on single-trace (no color) + figures. We derive a name from each source figure's y-axis title. + 2. **Hidden named traces** — PX sets ``showlegend=False`` on single-trace + figures. We ensure at least one trace per ``legendgroup`` (or each + ungrouped named trace) has ``showlegend=True``. + 3. **Duplicate legend entries** — when two source figures share the same + ``legendgroup`` names, we deduplicate so only the first trace per + group shows in the legend. + + Args: + combined: The combined Plotly figure (mutated in place). + source_figs: The original source figures, in trace order. + trace_slices: Slices into ``combined.data`` for each source figure. + """ + from collections import defaultdict + + # --- Step 1: label unnamed traces from source y-axis titles ----------- + labels = [_get_yaxis_title(f) for f in source_figs] + + # If all labels are non-empty and identical, disambiguate + unique_labels = {lb for lb in labels if lb} + if len(unique_labels) == 1 and all(lb for lb in labels): + labels = [f"{labels[0]} ({i + 1})" for i in range(len(labels))] + + for label, sl in zip(labels, trace_slices, strict=False): + if not label: + continue + for trace in combined.data[sl]: + if not getattr(trace, "name", None): + trace.name = label + trace.legendgroup = label + + # --- Step 2 & 3: fix showlegend per legendgroup ----------------------- + grouped: dict[str, list[Any]] = defaultdict(list) + ungrouped: list[Any] = [] + + for trace in combined.data: + lg = getattr(trace, "legendgroup", None) or "" + if lg: + grouped[lg].append(trace) + else: + ungrouped.append(trace) + + for traces in grouped.values(): + has_visible = False + for t in traces: + if has_visible: + # Deduplicate: only first keeps showlegend + t.showlegend = False + elif getattr(t, "name", None): + t.showlegend = True + has_visible = True + + # Ungrouped traces with a name should show in the legend + for trace in ungrouped: + if getattr(trace, "name", None): + trace.showlegend = True + + # --- Step 4: propagate style properties to animation frame traces ------ + # When Plotly animates, frame trace data overwrites fig.data properties. + # PX frame traces carry name="", showlegend=False and default colors, + # discarding any styling the user applied via update_traces() before + # combining. Propagate display properties from fig.data into every frame. + _STYLE_ATTRS = ("name", "legendgroup", "showlegend", "marker", "line", "opacity") + for frame in combined.frames or []: + for i, frame_trace in enumerate(frame.data): + if i < len(combined.data): + src = combined.data[i] + for attr in _STYLE_ATTRS: + src_val = getattr(src, attr, None) + if src_val is not None: + setattr(frame_trace, attr, src_val) + + +def _fix_animation_axis_ranges(fig: go.Figure) -> None: + """Set axis ranges to encompass data across all animation frames. + + Plotly.js computes autorange from ``fig.data`` only and does not + recalculate during animation. When different frames have very different + data ranges (e.g. population of Brazil vs China), values can go off-screen. + This function computes the global min/max for each axis across all frames + and sets explicit ranges on the layout. + + Only numeric axes are handled; categorical/date axes are left to autorange. + + Args: + fig: A Plotly figure with animation frames (mutated in place). + """ + import numpy as np + + if not fig.frames: + return + + from collections import defaultdict + + # Collect numeric y-values per axis across all traces (fig.data + frames) + y_by_axis: dict[str, list[float]] = defaultdict(list) + x_by_axis: dict[str, list[float]] = defaultdict(list) + + # Track which axes have bar traces (for zero-baseline clamping) + y_has_vbar: set[str] = set() # vertical bars → y-axis includes 0 + x_has_hbar: set[str] = set() # horizontal bars → x-axis includes 0 + + for trace in _iter_all_traces(fig): + yaxis = getattr(trace, "yaxis", None) or "y" + xaxis = getattr(trace, "xaxis", None) or "x" + + # Track bar orientations + if getattr(trace, "type", None) == "bar": + orientation = getattr(trace, "orientation", None) or "v" + if orientation == "h": + x_has_hbar.add(xaxis) + else: + y_has_vbar.add(yaxis) + + for data_attr, axis_ref, by_axis in [ + ("y", yaxis, y_by_axis), + ("x", xaxis, x_by_axis), + ]: + vals = getattr(trace, data_attr, None) + if vals is None: + continue + arr = np.asarray(vals) + # Skip datetime/timedelta — leave those axes on autorange + if np.issubdtype(arr.dtype, np.datetime64) or np.issubdtype(arr.dtype, np.timedelta64): + continue + try: + arr = arr.astype(float) + finite = arr[np.isfinite(arr)] + if len(finite): + by_axis[axis_ref].extend(finite.tolist()) + except (ValueError, TypeError): + pass # Non-numeric (categorical) — skip + + # Apply ranges to layout + for axis_ref, values in y_by_axis.items(): + if not values: + continue + lo, hi = min(values), max(values) + if axis_ref in y_has_vbar: + lo = min(lo, 0.0) + hi = max(hi, 0.0) + pad = (hi - lo) * 0.05 or 1 # 5% padding + layout_prop = "yaxis" if axis_ref == "y" else f"yaxis{axis_ref[1:]}" + fig.layout[layout_prop].range = [lo - pad, hi + pad] + + for axis_ref, values in x_by_axis.items(): + if not values: + continue + lo, hi = min(values), max(values) + if axis_ref in x_has_hbar: + lo = min(lo, 0.0) + hi = max(hi, 0.0) + pad = (hi - lo) * 0.05 or 1 + layout_prop = "xaxis" if axis_ref == "x" else f"xaxis{axis_ref[1:]}" + fig.layout[layout_prop].range = [lo - pad, hi + pad] + + def _iter_all_traces(fig: go.Figure) -> Iterator[Any]: """Iterate over all traces in a figure, including animation frames. @@ -194,17 +376,11 @@ def overlay(base: go.Figure, *overlays: go.Figure) -> go.Figure: _validate_compatible_structure(base, overlay) _validate_animation_compatibility(base, overlay) - # Create new figure with base's layout - combined = go.Figure(layout=copy.deepcopy(base.layout)) - - # Add all traces from base - for trace in base.data: - combined.add_trace(copy.deepcopy(trace)) - - # Add all traces from overlays + # Create new figure with base's layout and all traces + all_traces = [copy.deepcopy(t) for t in base.data] for overlay in overlays: - for trace in overlay.data: - combined.add_trace(copy.deepcopy(trace)) + all_traces.extend(copy.deepcopy(t) for t in overlay.data) + combined = go.Figure(data=all_traces, layout=copy.deepcopy(base.layout)) # Handle animation frames if base.frames: @@ -213,6 +389,17 @@ def overlay(base: go.Figure, *overlays: go.Figure) -> go.Figure: merged_frames = _merge_frames(base, list(overlays), base_trace_count, overlay_trace_counts) combined.frames = merged_frames + # Build trace slices for legend fix + source_figs = [base, *overlays] + slices: list[slice] = [] + offset = 0 + for fig in source_figs: + n = len(fig.data) + slices.append(slice(offset, offset + n)) + offset += n + + _ensure_legend_visibility(combined, source_figs, slices) + _fix_animation_axis_ranges(combined) return combined @@ -315,19 +502,15 @@ def add_secondary_y( rightmost_x = max(x_for_y.values(), key=lambda x: int(x[1:]) if x != "x" else 1) rightmost_primary_y = next(y for y, x in x_for_y.items() if x == rightmost_x) - # Create new figure with base's layout - combined = go.Figure(layout=copy.deepcopy(base.layout)) - - # Add all traces from base (primary y-axis) - for trace in base.data: - combined.add_trace(copy.deepcopy(trace)) - - # Add all traces from secondary, remapped to secondary y-axes + # Build all traces: base (primary) + secondary (remapped to secondary y-axes) + all_traces = [copy.deepcopy(t) for t in base.data] for trace in secondary.data: trace_copy = copy.deepcopy(trace) original_yaxis = getattr(trace_copy, "yaxis", None) or "y" trace_copy.yaxis = y_mapping[original_yaxis] - combined.add_trace(trace_copy) + all_traces.append(trace_copy) + + combined = go.Figure(data=all_traces, layout=copy.deepcopy(base.layout)) # Get the rightmost secondary y-axis name for linking rightmost_secondary_y = y_mapping[rightmost_primary_y] @@ -368,6 +551,14 @@ def add_secondary_y( merged_frames = _merge_secondary_y_frames(base, secondary, y_mapping) combined.frames = merged_frames + base_n = len(base.data) + sec_n = len(secondary.data) + _ensure_legend_visibility( + combined, + [base, secondary], + [slice(0, base_n), slice(base_n, base_n + sec_n)], + ) + _fix_animation_axis_ranges(combined) return combined @@ -426,6 +617,272 @@ def _merge_secondary_y_frames( return merged_frames +def _get_figure_title(fig: go.Figure) -> str: + """Extract a display title from a figure for use as a subplot title. + + Checks, in order: the figure's title, then the y-axis title. + + Args: + fig: A Plotly figure. + + Returns: + A title string, or empty string if nothing is set. + """ + try: + title = fig.layout.title.text + if isinstance(title, str) and title: + return title + except AttributeError: + pass + return _get_yaxis_title(fig) + + +def subplots(*figs: go.Figure, cols: int = 1) -> go.Figure: + """Arrange multiple figures into a subplot grid. + + Creates a new figure with each input figure placed in its own cell. + Figures may contain internal subplots (facets) — their axes are remapped + to fit within the grid cell. Subplot titles are derived from each + figure's title or y-axis label. + + Args: + *figs: One or more Plotly figures to arrange. + cols: Number of columns in the grid. Rows are computed automatically. + + Returns: + A new figure with subplot grid. + + Raises: + ValueError: If no figures are provided, cols < 1, or a figure has + animation frames. + + Example: + >>> import numpy as np + >>> import xarray as xr + >>> from xarray_plotly import xpx, subplots + >>> + >>> temp = xr.DataArray([20, 22, 25], dims=["time"], name="Temperature") + >>> rain = xr.DataArray([0, 5, 12], dims=["time"], name="Rainfall") + >>> fig1 = xpx(temp).line() + >>> fig2 = xpx(rain).bar() + >>> grid = subplots(fig1, fig2, cols=2) + """ + import math + + import plotly.graph_objects as go + + if not figs: + raise ValueError("At least one figure is required.") + if cols < 1: + raise ValueError(f"cols must be >= 1, got {cols}.") + + for i, fig in enumerate(figs): + if fig.frames: + raise ValueError( + f"Figure at position {i} has animation frames. " + "Animated figures are not supported in subplots()." + ) + + rows = math.ceil(len(figs) / cols) + combined = go.Figure() + + # Grid spacing + h_gap = 0.05 + v_gap = 0.08 + cell_w = (1.0 - h_gap * (cols - 1)) / cols + cell_h = (1.0 - v_gap * (rows - 1)) / rows + + next_x_num = 1 + next_y_num = 1 + + for i, fig in enumerate(figs): + row = i // cols # 0-indexed, top to bottom + col = i % cols + + # Cell boundaries (clamped to [0, 1]) + cell_x0 = max(0.0, col * (cell_w + h_gap)) + cell_x1 = min(1.0, cell_x0 + cell_w) + cell_y1 = min(1.0, 1.0 - row * (cell_h + v_gap)) # top-down + cell_y0 = max(0.0, cell_y1 - cell_h) + + # Build axis remapping: old axis ref → new axis ref + axis_map, next_x_num, next_y_num = _remap_figure_axes( + fig, combined, next_x_num, next_y_num, cell_x0, cell_x1, cell_y0, cell_y1 + ) + + # Add traces with remapped axis refs + for trace in fig.data: + tc = copy.deepcopy(trace) + old_x = getattr(tc, "xaxis", None) or "x" + old_y = getattr(tc, "yaxis", None) or "y" + tc.xaxis = axis_map[old_x]["new_x"] + tc.yaxis = axis_map[old_y]["new_y"] + combined.add_trace(tc) + + # Add subplot title as annotation + title = _get_figure_title(fig) + if title: + combined.add_annotation( + text=f"{title}", + x=(cell_x0 + cell_x1) / 2, + y=cell_y1, + xref="paper", + yref="paper", + xanchor="center", + yanchor="bottom", + showarrow=False, + font={"size": 14}, + ) + + return combined + + +# Axis properties safe to copy between figures (display-only, not structural). +_AXIS_PROPS_TO_COPY = ( + "title", + "type", + "tickformat", + "ticksuffix", + "tickprefix", + "dtick", + "tick0", + "nticks", + "showgrid", + "gridcolor", + "gridwidth", + "autorange", + "range", + "zeroline", + "zerolinecolor", + "zerolinewidth", + "showticklabels", +) + + +def _axis_layout_key(ref: str) -> str: + """Convert axis reference to layout property name. + + ``"x"`` → ``"xaxis"``, ``"x2"`` → ``"xaxis2"``, + ``"y"`` → ``"yaxis"``, ``"y3"`` → ``"yaxis3"``. + """ + if ref in ("x", "y"): + return f"{ref}axis" + prefix = ref[0] # "x" or "y" + num = ref[1:] + return f"{prefix}axis{num}" + + +def _new_axis_ref(prefix: str, num: int) -> str: + """Build an axis reference string. ``_new_axis_ref("x", 1)`` → ``"x"``, ``("x", 3)`` → ``"x3"``.""" + return prefix if num == 1 else f"{prefix}{num}" + + +def _remap_figure_axes( + fig: go.Figure, + combined: go.Figure, + next_x_num: int, + next_y_num: int, + cell_x0: float, + cell_x1: float, + cell_y0: float, + cell_y1: float, +) -> tuple[dict[str, dict[str, str]], int, int]: + """Remap a figure's axes into a grid cell, adding axis configs to the combined layout. + + Args: + fig: Source figure. + combined: Target combined figure (mutated — axis configs added to layout). + next_x_num: Next available x-axis number. + next_y_num: Next available y-axis number. + cell_x0, cell_x1: Horizontal cell bounds in paper coordinates. + cell_y0, cell_y1: Vertical cell bounds in paper coordinates. + + Returns: + Tuple of (axis_map, next_x_num, next_y_num). + axis_map maps old axis refs to ``{"new_x": ...}`` or ``{"new_y": ...}``. + """ + cell_w = cell_x1 - cell_x0 + cell_h = cell_y1 - cell_y0 + src_layout = fig.layout.to_plotly_json() + + x_remap: dict[str, str] = {} + y_remap: dict[str, str] = {} + + # Get all unique axis refs + x_refs: set[str] = set() + y_refs: set[str] = set() + for trace in fig.data: + x_refs.add(getattr(trace, "xaxis", None) or "x") + y_refs.add(getattr(trace, "yaxis", None) or "y") + + # Remap x-axes + for old_xref in sorted(x_refs, key=lambda r: int(r[1:]) if len(r) > 1 else 1): + new_xref = _new_axis_ref("x", next_x_num) + x_remap[old_xref] = new_xref + + src_config = src_layout.get(_axis_layout_key(old_xref), {}) + src_domain = src_config.get("domain", [0.0, 1.0]) + new_domain = [ + max(0.0, cell_x0 + src_domain[0] * cell_w), + min(1.0, cell_x0 + src_domain[1] * cell_w), + ] + + new_config: dict[str, Any] = {"domain": new_domain} + for prop in _AXIS_PROPS_TO_COPY: + if prop in src_config: + new_config[prop] = src_config[prop] + + combined.layout[_axis_layout_key(new_xref)] = new_config + next_x_num += 1 + + # Remap y-axes + for old_yref in sorted(y_refs, key=lambda r: int(r[1:]) if len(r) > 1 else 1): + new_yref = _new_axis_ref("y", next_y_num) + y_remap[old_yref] = new_yref + + src_config = src_layout.get(_axis_layout_key(old_yref), {}) + src_domain = src_config.get("domain", [0.0, 1.0]) + new_domain = [ + max(0.0, cell_y0 + src_domain[0] * cell_h), + min(1.0, cell_y0 + src_domain[1] * cell_h), + ] + + new_config = {"domain": new_domain} + for prop in _AXIS_PROPS_TO_COPY: + if prop in src_config: + new_config[prop] = src_config[prop] + + combined.layout[_axis_layout_key(new_yref)] = new_config + next_y_num += 1 + + # Set anchors between paired axes + for trace in fig.data: + old_x = getattr(trace, "xaxis", None) or "x" + old_y = getattr(trace, "yaxis", None) or "y" + combined.layout[_axis_layout_key(x_remap[old_x])]["anchor"] = y_remap[old_y] + combined.layout[_axis_layout_key(y_remap[old_y])]["anchor"] = x_remap[old_x] + + # Propagate matches relationships + for old_ref, new_ref in x_remap.items(): + src_config = src_layout.get(_axis_layout_key(old_ref), {}) + if "matches" in src_config and src_config["matches"] in x_remap: + combined.layout[_axis_layout_key(new_ref)]["matches"] = x_remap[src_config["matches"]] + + for old_ref, new_ref in y_remap.items(): + src_config = src_layout.get(_axis_layout_key(old_ref), {}) + if "matches" in src_config and src_config["matches"] in y_remap: + combined.layout[_axis_layout_key(new_ref)]["matches"] = y_remap[src_config["matches"]] + + # Build combined return mapping + result: dict[str, dict[str, str]] = {} + for old_x, new_x in x_remap.items(): + result[old_x] = {"new_x": new_x} + for old_y, new_y in y_remap.items(): + result[old_y] = {"new_y": new_y} + + return result, next_x_num, next_y_num + + def update_traces( fig: go.Figure, selector: dict[str, Any] | None = None, **kwargs: Any ) -> go.Figure: