diff --git a/docs/examples/combining.ipynb b/docs/examples/combining.ipynb
index 6755c20..a72d752 100644
--- a/docs/examples/combining.ipynb
+++ b/docs/examples/combining.ipynb
@@ -10,7 +10,8 @@
"xarray-plotly provides helper functions to combine multiple figures:\n",
"\n",
"- **`overlay`**: Overlay traces on the same axes\n",
- "- **`add_secondary_y`**: Plot with two independent y-axes"
+ "- **`add_secondary_y`**: Plot with two independent y-axes\n",
+ "- **`subplots`**: Arrange independent figures in a grid"
]
},
{
@@ -23,7 +24,7 @@
"import plotly.express as px\n",
"import xarray as xr\n",
"\n",
- "from xarray_plotly import add_secondary_y, config, overlay, xpx\n",
+ "from xarray_plotly import add_secondary_y, config, overlay, subplots, xpx\n",
"\n",
"config.notebook()"
]
@@ -440,6 +441,128 @@
"cell_type": "markdown",
"id": "27",
"metadata": {},
+ "source": [
+ "## subplots\n",
+ "\n",
+ "Arrange independent figures side-by-side in a grid. Each figure gets its own\n",
+ "subplot cell with axes and title automatically derived from the source figure."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "28",
+ "metadata": {},
+ "source": [
+ "### Different Variables Side by Side"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "29",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# One figure per variable, arranged in a row\n",
+ "us_pop = population.sel(country=\"United States\")\n",
+ "us_gdp = gdp_per_capita.sel(country=\"United States\")\n",
+ "us_life = life_expectancy.sel(country=\"United States\")\n",
+ "\n",
+ "pop_fig = xpx(us_pop).bar(title=\"Population\")\n",
+ "gdp_fig = xpx(us_gdp).line(title=\"GDP per Capita\")\n",
+ "life_fig = xpx(us_life).line(title=\"Life Expectancy\")\n",
+ "\n",
+ "grid = subplots(pop_fig, gdp_fig, life_fig, cols=3)\n",
+ "grid.update_layout(title=\"United States Overview\", height=350, showlegend=False)\n",
+ "grid"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "30",
+ "metadata": {},
+ "source": [
+ "### 2x2 Grid\n",
+ "\n",
+ "Use `cols=2` and pass four figures for a 2x2 layout."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "31",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# One subplot per country\n",
+ "fig_us = xpx(population.sel(country=\"United States\")).bar(title=\"United States\")\n",
+ "fig_cn = xpx(population.sel(country=\"China\")).bar(title=\"China\")\n",
+ "fig_de = xpx(population.sel(country=\"Germany\")).bar(title=\"Germany\")\n",
+ "fig_br = xpx(population.sel(country=\"Brazil\")).bar(title=\"Brazil\")\n",
+ "\n",
+ "grid = subplots(fig_us, fig_cn, fig_de, fig_br, cols=2)\n",
+ "grid.update_layout(height=500, showlegend=False)\n",
+ "grid"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "32",
+ "metadata": {},
+ "source": [
+ "### Mixed Chart Types\n",
+ "\n",
+ "Each cell can use a different chart type. Subplot titles fall back to the\n",
+ "y-axis label when no explicit title is set."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "33",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# No explicit title — subplot titles come from the y-axis label (DataArray name)\n",
+ "pop_bar = xpx(us_pop).bar()\n",
+ "gdp_line = xpx(us_gdp).line()\n",
+ "life_scatter = xpx(us_life).scatter()\n",
+ "\n",
+ "grid = subplots(pop_bar, gdp_line, life_scatter, cols=3)\n",
+ "grid.update_layout(height=350, showlegend=False)\n",
+ "grid"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "34",
+ "metadata": {},
+ "source": [
+ "### With Facets\n",
+ "\n",
+ "Faceted figures can be composed — each figure's internal subplots are remapped into the grid cell."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "35",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Faceted bar on top, faceted line below\n",
+ "pop_faceted = xpx(population).bar(facet_col=\"country\")\n",
+ "gdp_faceted = xpx(gdp_per_capita).line(facet_col=\"country\")\n",
+ "\n",
+ "grid = subplots(pop_faceted, gdp_faceted, cols=1)\n",
+ "grid.update_layout(height=600, showlegend=False)\n",
+ "grid"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "36",
+ "metadata": {},
"source": [
"---\n",
"\n",
@@ -450,7 +573,7 @@
},
{
"cell_type": "markdown",
- "id": "28",
+ "id": "37",
"metadata": {},
"source": [
"### overlay: Mismatched Facet Structure\n",
@@ -461,7 +584,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "29",
+ "id": "38",
"metadata": {},
"outputs": [],
"source": [
@@ -479,7 +602,7 @@
},
{
"cell_type": "markdown",
- "id": "30",
+ "id": "39",
"metadata": {},
"source": [
"### overlay: Animated Overlay on Static Base\n",
@@ -490,7 +613,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "31",
+ "id": "40",
"metadata": {},
"outputs": [],
"source": [
@@ -508,7 +631,7 @@
},
{
"cell_type": "markdown",
- "id": "32",
+ "id": "41",
"metadata": {},
"source": [
"### overlay: Mismatched Animation Frames\n",
@@ -519,7 +642,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "33",
+ "id": "42",
"metadata": {},
"outputs": [],
"source": [
@@ -535,7 +658,7 @@
},
{
"cell_type": "markdown",
- "id": "34",
+ "id": "43",
"metadata": {},
"source": [
"### add_secondary_y: Mismatched Facet Structure\n",
@@ -546,7 +669,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "35",
+ "id": "44",
"metadata": {},
"outputs": [],
"source": [
@@ -564,7 +687,7 @@
},
{
"cell_type": "markdown",
- "id": "36",
+ "id": "45",
"metadata": {},
"source": [
"### add_secondary_y: Animated Secondary on Static Base\n",
@@ -575,7 +698,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "37",
+ "id": "46",
"metadata": {},
"outputs": [],
"source": [
@@ -593,7 +716,7 @@
},
{
"cell_type": "markdown",
- "id": "38",
+ "id": "47",
"metadata": {},
"source": [
"### add_secondary_y: Mismatched Animation Frames"
@@ -602,7 +725,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "39",
+ "id": "48",
"metadata": {},
"outputs": [],
"source": [
@@ -618,7 +741,7 @@
},
{
"cell_type": "markdown",
- "id": "40",
+ "id": "49",
"metadata": {},
"source": [
"## Summary\n",
@@ -626,7 +749,8 @@
"| Function | Facets | Animation | Static + Animated |\n",
"|----------|--------|-----------|-------------------|\n",
"| `overlay` | Yes (must match) | Yes (frames must match) | Static overlay on animated base OK |\n",
- "| `add_secondary_y` | Yes (must match) | Yes (frames must match) | Static secondary on animated base OK |"
+ "| `add_secondary_y` | Yes (must match) | Yes (frames must match) | Static secondary on animated base OK |\n",
+ "| `subplots` | Yes (remapped into cells) | No | N/A |"
]
}
],
diff --git a/tests/test_figures.py b/tests/test_figures.py
index dacf19d..254a3bd 100644
--- a/tests/test_figures.py
+++ b/tests/test_figures.py
@@ -1,4 +1,4 @@
-"""Tests for the figures module (overlay, add_secondary_y)."""
+"""Tests for the figures module (overlay, add_secondary_y, subplots)."""
from __future__ import annotations
@@ -9,7 +9,7 @@
import pytest
import xarray as xr
-from xarray_plotly import add_secondary_y, overlay, xpx
+from xarray_plotly import add_secondary_y, overlay, subplots, xpx
class TestOverlayBasic:
@@ -616,3 +616,309 @@ def test_secondary_not_modified(self) -> None:
# Secondary traces should still use original yaxis
assert secondary.data[0].yaxis == original_yaxis
+
+
+class TestLegendVisibility:
+ """Tests that combined figures preserve legend visibility."""
+
+ def test_overlay_single_trace_figures_with_names(self) -> None:
+ """Overlay of named single-trace figures shows legend."""
+ da1 = xr.DataArray([1, 2, 3], dims=["x"], name="a")
+ da2 = xr.DataArray([4, 5, 6], dims=["x"], name="b")
+
+ fig1 = xpx(da1).line()
+ fig1.update_traces(name="Series A")
+ fig2 = xpx(da2).line()
+ fig2.update_traces(name="Series B")
+
+ combined = overlay(fig1, fig2)
+
+ assert combined.data[0].showlegend is True
+ assert combined.data[1].showlegend is True
+
+ def test_overlay_unnamed_traces_get_yaxis_title(self) -> None:
+ """Overlay of unnamed traces derives names from y-axis titles."""
+ da1 = xr.DataArray([1, 2, 3], dims=["x"], name="Temperature")
+ da2 = xr.DataArray([4, 5, 6], dims=["x"], name="Pressure")
+
+ fig1 = xpx(da1).line()
+ fig2 = xpx(da2).line()
+
+ combined = overlay(fig1, fig2)
+
+ # Names derived from y-axis titles (DataArray names)
+ assert combined.data[0].name == "Temperature"
+ assert combined.data[1].name == "Pressure"
+ assert combined.data[0].showlegend is True
+ assert combined.data[1].showlegend is True
+
+ def test_overlay_same_name_disambiguated(self) -> None:
+ """Overlay of figures with same y-axis title gets numeric suffix."""
+ da1 = xr.DataArray([1, 2, 3], dims=["x"], name="value")
+ da2 = xr.DataArray([4, 5, 6], dims=["x"], name="value")
+
+ fig1 = xpx(da1).line()
+ fig2 = xpx(da2).line()
+
+ combined = overlay(fig1, fig2)
+
+ assert combined.data[0].name == "value (1)"
+ assert combined.data[1].name == "value (2)"
+
+ def test_overlay_multi_trace_deduplicates_legend(self) -> None:
+ """Overlay of multi-trace figures deduplicates shared legendgroups."""
+ da = xr.DataArray(
+ np.random.rand(10, 3),
+ dims=["x", "cat"],
+ coords={"cat": ["A", "B", "C"]},
+ )
+ fig1 = xpx(da).area()
+ fig2 = xpx(da).line()
+
+ combined = overlay(fig1, fig2)
+
+ # First occurrence of each legendgroup should show, duplicates hidden
+ from collections import defaultdict
+
+ groups: dict[str, list[bool]] = defaultdict(list)
+ for trace in combined.data:
+ lg = trace.legendgroup
+ groups[lg].append(trace.showlegend is True)
+
+ for lg, flags in groups.items():
+ assert flags.count(True) == 1, f"legendgroup {lg!r} has {flags.count(True)} visible"
+
+ def test_add_secondary_y_single_trace_with_names(self) -> None:
+ """add_secondary_y of named single-trace figures shows legend."""
+ da1 = xr.DataArray([1, 2, 3], dims=["x"], name="temp")
+ da2 = xr.DataArray([100, 200, 300], dims=["x"], name="precip")
+
+ fig1 = xpx(da1).line()
+ fig1.update_traces(name="Temperature")
+ fig2 = xpx(da2).bar()
+ fig2.update_traces(name="Precipitation")
+
+ combined = add_secondary_y(fig1, fig2)
+
+ assert combined.data[0].showlegend is True
+ assert combined.data[1].showlegend is True
+
+ def test_overlay_faceted_legendgroup_dedup(self) -> None:
+ """Faceted overlay keeps only one showlegend=True per legendgroup."""
+ da = xr.DataArray(
+ np.random.rand(10, 2, 2),
+ dims=["x", "cat", "facet"],
+ coords={"cat": ["A", "B"], "facet": ["left", "right"]},
+ )
+ fig1 = xpx(da).area(facet_col="facet")
+ fig2 = xpx(da).line(facet_col="facet")
+
+ combined = overlay(fig1, fig2)
+
+ # Check each legendgroup has at least one showlegend=True
+ from collections import defaultdict
+
+ groups: dict[str, list[bool]] = defaultdict(list)
+ for trace in combined.data:
+ lg = trace.legendgroup or ""
+ if lg:
+ groups[lg].append(trace.showlegend is True)
+
+ for lg, flags in groups.items():
+ assert any(flags), f"legendgroup {lg!r} has no showlegend=True trace"
+
+ def test_overlay_animation_frames_preserve_style(self) -> None:
+ """Animation frame traces keep legend and color from fig.data."""
+ da = xr.DataArray(
+ np.random.rand(10, 3),
+ dims=["x", "time"],
+ coords={"time": [0, 1, 2]},
+ name="Population",
+ )
+ da_smooth = da.rolling(x=3, center=True).mean()
+ da_smooth.name = "Smoothed"
+
+ fig1 = xpx(da).bar(animation_frame="time")
+ fig1.update_traces(marker={"color": "steelblue"})
+ fig2 = xpx(da_smooth).line(animation_frame="time")
+ fig2.update_traces(line={"color": "red"})
+
+ combined = overlay(fig1, fig2)
+
+ for frame in combined.frames:
+ for i, ft in enumerate(frame.data):
+ src = combined.data[i]
+ assert ft.name == src.name
+ assert ft.showlegend == src.showlegend
+ assert ft.legendgroup == src.legendgroup
+ # Bar trace should keep steelblue
+ assert frame.data[0].marker.color == "steelblue"
+ # Line trace should keep red
+ assert frame.data[1].line.color == "red"
+
+
+class TestAnimationAxisRanges:
+ """Tests for _fix_animation_axis_ranges."""
+
+ def test_datetime_x_axis_not_corrupted(self) -> None:
+ """datetime64 x-axis should be left on autorange, not cast to float epochs."""
+ dates = np.array(["2020-01-01", "2020-06-01", "2021-01-01"], dtype="datetime64[ns]")
+ da = xr.DataArray(
+ np.random.rand(3, 2),
+ dims=["date", "cat"],
+ coords={"date": dates, "cat": ["A", "B"]},
+ name="value",
+ )
+ fig1 = xpx(da).line(animation_frame="cat")
+ fig2 = xpx(da).scatter(animation_frame="cat")
+ combined = overlay(fig1, fig2)
+
+ # x-axis range should NOT be set (dates left to autorange)
+ assert combined.layout.xaxis.range is None
+
+ def test_bar_zero_baseline(self) -> None:
+ """Bar chart y-axis range should include zero."""
+ da = xr.DataArray(
+ np.array([[100, 200], [150, 250]]),
+ dims=["x", "frame"],
+ name="val",
+ )
+ fig = xpx(da).bar(animation_frame="frame")
+ # After overlay (which triggers _fix_animation_axis_ranges)
+ combined = overlay(fig, xpx(da).line(animation_frame="frame"))
+
+ lo, _hi = combined.layout.yaxis.range
+ assert lo <= 0, f"Bar y-axis range should include 0, got lo={lo}"
+
+
+class TestSubplotsBasic:
+ """Basic tests for subplots function."""
+
+ @pytest.fixture(autouse=True)
+ def setup(self) -> None:
+ self.da1 = xr.DataArray([1, 2, 3], dims=["x"], name="Temperature")
+ self.da2 = xr.DataArray([10, 20, 30], dims=["x"], name="Rainfall")
+ self.da3 = xr.DataArray([100, 200, 300], dims=["x"], name="Wind")
+
+ def test_single_figure(self) -> None:
+ fig = xpx(self.da1).line()
+ grid = subplots(fig)
+ assert len(grid.data) == 1
+
+ def test_two_figures_one_column(self) -> None:
+ fig1 = xpx(self.da1).line()
+ fig2 = xpx(self.da2).bar()
+ grid = subplots(fig1, fig2, cols=1)
+ assert len(grid.data) == 2
+ # Should be on different y-axes (different rows)
+ assert grid.data[0].yaxis != grid.data[1].yaxis
+
+ def test_two_figures_two_columns(self) -> None:
+ fig1 = xpx(self.da1).line()
+ fig2 = xpx(self.da2).bar()
+ grid = subplots(fig1, fig2, cols=2)
+ assert len(grid.data) == 2
+ assert grid.data[0].xaxis != grid.data[1].xaxis
+
+ def test_three_figures_two_columns(self) -> None:
+ fig1 = xpx(self.da1).line()
+ fig2 = xpx(self.da2).bar()
+ fig3 = xpx(self.da3).scatter()
+ grid = subplots(fig1, fig2, fig3, cols=2)
+ assert len(grid.data) == 3
+
+ def test_trace_count_preserved(self) -> None:
+ da_multi = xr.DataArray(
+ np.random.rand(10, 3),
+ dims=["x", "cat"],
+ coords={"cat": ["A", "B", "C"]},
+ )
+ fig1 = xpx(da_multi).line() # 3 traces
+ fig2 = xpx(self.da1).bar() # 1 trace
+ grid = subplots(fig1, fig2, cols=2)
+ assert len(grid.data) == len(fig1.data) + len(fig2.data)
+
+ def test_with_empty_figure(self) -> None:
+ fig1 = xpx(self.da1).line()
+ fig2 = go.Figure()
+ grid = subplots(fig1, fig2, cols=2)
+ assert len(grid.data) == 1
+
+
+class TestSubplotsTitles:
+ """Tests for subplot title derivation."""
+
+ def test_titles_from_figure_title(self) -> None:
+ da = xr.DataArray([1, 2, 3], dims=["x"], name="val")
+ fig1 = xpx(da).line(title="My Title")
+ fig2 = xpx(da).bar(title="Other Title")
+ grid = subplots(fig1, fig2, cols=2)
+ titles = [ann.text for ann in grid.layout.annotations]
+ assert titles == ["My Title", "Other Title"]
+
+ def test_titles_from_yaxis_label(self) -> None:
+ da1 = xr.DataArray([1, 2, 3], dims=["x"], name="Temperature")
+ da2 = xr.DataArray([4, 5, 6], dims=["x"], name="Pressure")
+ fig1 = xpx(da1).line()
+ fig2 = xpx(da2).line()
+ grid = subplots(fig1, fig2, cols=2)
+ titles = [ann.text for ann in grid.layout.annotations]
+ assert titles == ["Temperature", "Pressure"]
+
+ def test_titles_fallback_empty(self) -> None:
+ grid = subplots(go.Figure(), go.Figure(), cols=2)
+ # No annotations are created for empty titles
+ titles = [ann.text for ann in grid.layout.annotations]
+ assert titles == []
+
+
+class TestSubplotsAxisConfig:
+ """Tests for axis configuration copying."""
+
+ def test_axis_titles_copied(self) -> None:
+ da = xr.DataArray([1, 2, 3], dims=["time"], name="Temperature")
+ fig = xpx(da).line()
+ grid = subplots(fig)
+ assert grid.layout.yaxis.title.text == "Temperature"
+ assert grid.layout.xaxis.title.text == "time"
+
+
+class TestSubplotsValidation:
+ """Tests for subplots input validation."""
+
+ def test_empty_raises(self) -> None:
+ with pytest.raises(ValueError, match="At least one figure"):
+ subplots()
+
+ def test_invalid_cols_raises(self) -> None:
+ with pytest.raises(ValueError, match="cols must be >= 1"):
+ subplots(go.Figure(), cols=0)
+
+ def test_faceted_figures_stacked(self) -> None:
+ """Faceted figures can be stacked in a subplot grid."""
+ da = xr.DataArray(
+ np.random.rand(10, 3),
+ dims=["x", "facet"],
+ coords={"facet": ["A", "B", "C"]},
+ )
+ fig1 = xpx(da).bar(facet_col="facet")
+ fig2 = xpx(da).line(facet_col="facet")
+ grid = subplots(fig1, fig2, cols=1)
+ # 3 bar traces + 3 line traces
+ assert len(grid.data) == 6
+ # All traces should have unique axis assignments
+ axes = {(t.xaxis, t.yaxis) for t in grid.data}
+ assert len(axes) == 6
+
+ def test_animated_figure_raises(self) -> None:
+ da = xr.DataArray(np.random.rand(10, 3), dims=["x", "time"])
+ fig = xpx(da).line(animation_frame="time")
+ with pytest.raises(ValueError, match="animation frames"):
+ subplots(fig)
+
+ def test_source_not_modified(self) -> None:
+ da = xr.DataArray([1, 2, 3], dims=["x"], name="val")
+ fig = xpx(da).line()
+ original_count = len(fig.data)
+ _ = subplots(fig, fig, cols=2)
+ assert len(fig.data) == original_count
diff --git a/xarray_plotly/__init__.py b/xarray_plotly/__init__.py
index 7cefbb5..c526173 100644
--- a/xarray_plotly/__init__.py
+++ b/xarray_plotly/__init__.py
@@ -56,6 +56,7 @@
from xarray_plotly.figures import (
add_secondary_y,
overlay,
+ subplots,
update_traces,
)
@@ -66,6 +67,7 @@
"auto",
"config",
"overlay",
+ "subplots",
"update_traces",
"xpx",
]
diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py
index 6ed743d..8a4bf00 100644
--- a/xarray_plotly/figures.py
+++ b/xarray_plotly/figures.py
@@ -13,6 +13,188 @@
import plotly.graph_objects as go
+def _get_yaxis_title(fig: go.Figure) -> str:
+ """Extract the primary y-axis title text from a figure.
+
+ Args:
+ fig: A Plotly figure.
+
+ Returns:
+ The y-axis title text, or empty string if not set.
+ """
+ try:
+ return fig.layout.yaxis.title.text or ""
+ except AttributeError:
+ return ""
+
+
+def _ensure_legend_visibility(
+ combined: go.Figure,
+ source_figs: list[go.Figure],
+ trace_slices: list[slice],
+) -> None:
+ """Fix legend visibility on a combined figure.
+
+ Handles three problems that arise when combining Plotly Express figures:
+
+ 1. **Unnamed traces** — PX sets ``name=""`` on single-trace (no color)
+ figures. We derive a name from each source figure's y-axis title.
+ 2. **Hidden named traces** — PX sets ``showlegend=False`` on single-trace
+ figures. We ensure at least one trace per ``legendgroup`` (or each
+ ungrouped named trace) has ``showlegend=True``.
+ 3. **Duplicate legend entries** — when two source figures share the same
+ ``legendgroup`` names, we deduplicate so only the first trace per
+ group shows in the legend.
+
+ Args:
+ combined: The combined Plotly figure (mutated in place).
+ source_figs: The original source figures, in trace order.
+ trace_slices: Slices into ``combined.data`` for each source figure.
+ """
+ from collections import defaultdict
+
+ # --- Step 1: label unnamed traces from source y-axis titles -----------
+ labels = [_get_yaxis_title(f) for f in source_figs]
+
+ # If all labels are non-empty and identical, disambiguate
+ unique_labels = {lb for lb in labels if lb}
+ if len(unique_labels) == 1 and all(lb for lb in labels):
+ labels = [f"{labels[0]} ({i + 1})" for i in range(len(labels))]
+
+ for label, sl in zip(labels, trace_slices, strict=False):
+ if not label:
+ continue
+ for trace in combined.data[sl]:
+ if not getattr(trace, "name", None):
+ trace.name = label
+ trace.legendgroup = label
+
+ # --- Step 2 & 3: fix showlegend per legendgroup -----------------------
+ grouped: dict[str, list[Any]] = defaultdict(list)
+ ungrouped: list[Any] = []
+
+ for trace in combined.data:
+ lg = getattr(trace, "legendgroup", None) or ""
+ if lg:
+ grouped[lg].append(trace)
+ else:
+ ungrouped.append(trace)
+
+ for traces in grouped.values():
+ has_visible = False
+ for t in traces:
+ if has_visible:
+ # Deduplicate: only first keeps showlegend
+ t.showlegend = False
+ elif getattr(t, "name", None):
+ t.showlegend = True
+ has_visible = True
+
+ # Ungrouped traces with a name should show in the legend
+ for trace in ungrouped:
+ if getattr(trace, "name", None):
+ trace.showlegend = True
+
+ # --- Step 4: propagate style properties to animation frame traces ------
+ # When Plotly animates, frame trace data overwrites fig.data properties.
+ # PX frame traces carry name="", showlegend=False and default colors,
+ # discarding any styling the user applied via update_traces() before
+ # combining. Propagate display properties from fig.data into every frame.
+ _STYLE_ATTRS = ("name", "legendgroup", "showlegend", "marker", "line", "opacity")
+ for frame in combined.frames or []:
+ for i, frame_trace in enumerate(frame.data):
+ if i < len(combined.data):
+ src = combined.data[i]
+ for attr in _STYLE_ATTRS:
+ src_val = getattr(src, attr, None)
+ if src_val is not None:
+ setattr(frame_trace, attr, src_val)
+
+
+def _fix_animation_axis_ranges(fig: go.Figure) -> None:
+ """Set axis ranges to encompass data across all animation frames.
+
+ Plotly.js computes autorange from ``fig.data`` only and does not
+ recalculate during animation. When different frames have very different
+ data ranges (e.g. population of Brazil vs China), values can go off-screen.
+ This function computes the global min/max for each axis across all frames
+ and sets explicit ranges on the layout.
+
+ Only numeric axes are handled; categorical/date axes are left to autorange.
+
+ Args:
+ fig: A Plotly figure with animation frames (mutated in place).
+ """
+ import numpy as np
+
+ if not fig.frames:
+ return
+
+ from collections import defaultdict
+
+ # Collect numeric y-values per axis across all traces (fig.data + frames)
+ y_by_axis: dict[str, list[float]] = defaultdict(list)
+ x_by_axis: dict[str, list[float]] = defaultdict(list)
+
+ # Track which axes have bar traces (for zero-baseline clamping)
+ y_has_vbar: set[str] = set() # vertical bars → y-axis includes 0
+ x_has_hbar: set[str] = set() # horizontal bars → x-axis includes 0
+
+ for trace in _iter_all_traces(fig):
+ yaxis = getattr(trace, "yaxis", None) or "y"
+ xaxis = getattr(trace, "xaxis", None) or "x"
+
+ # Track bar orientations
+ if getattr(trace, "type", None) == "bar":
+ orientation = getattr(trace, "orientation", None) or "v"
+ if orientation == "h":
+ x_has_hbar.add(xaxis)
+ else:
+ y_has_vbar.add(yaxis)
+
+ for data_attr, axis_ref, by_axis in [
+ ("y", yaxis, y_by_axis),
+ ("x", xaxis, x_by_axis),
+ ]:
+ vals = getattr(trace, data_attr, None)
+ if vals is None:
+ continue
+ arr = np.asarray(vals)
+ # Skip datetime/timedelta — leave those axes on autorange
+ if np.issubdtype(arr.dtype, np.datetime64) or np.issubdtype(arr.dtype, np.timedelta64):
+ continue
+ try:
+ arr = arr.astype(float)
+ finite = arr[np.isfinite(arr)]
+ if len(finite):
+ by_axis[axis_ref].extend(finite.tolist())
+ except (ValueError, TypeError):
+ pass # Non-numeric (categorical) — skip
+
+ # Apply ranges to layout
+ for axis_ref, values in y_by_axis.items():
+ if not values:
+ continue
+ lo, hi = min(values), max(values)
+ if axis_ref in y_has_vbar:
+ lo = min(lo, 0.0)
+ hi = max(hi, 0.0)
+ pad = (hi - lo) * 0.05 or 1 # 5% padding
+ layout_prop = "yaxis" if axis_ref == "y" else f"yaxis{axis_ref[1:]}"
+ fig.layout[layout_prop].range = [lo - pad, hi + pad]
+
+ for axis_ref, values in x_by_axis.items():
+ if not values:
+ continue
+ lo, hi = min(values), max(values)
+ if axis_ref in x_has_hbar:
+ lo = min(lo, 0.0)
+ hi = max(hi, 0.0)
+ pad = (hi - lo) * 0.05 or 1
+ layout_prop = "xaxis" if axis_ref == "x" else f"xaxis{axis_ref[1:]}"
+ fig.layout[layout_prop].range = [lo - pad, hi + pad]
+
+
def _iter_all_traces(fig: go.Figure) -> Iterator[Any]:
"""Iterate over all traces in a figure, including animation frames.
@@ -194,17 +376,11 @@ def overlay(base: go.Figure, *overlays: go.Figure) -> go.Figure:
_validate_compatible_structure(base, overlay)
_validate_animation_compatibility(base, overlay)
- # Create new figure with base's layout
- combined = go.Figure(layout=copy.deepcopy(base.layout))
-
- # Add all traces from base
- for trace in base.data:
- combined.add_trace(copy.deepcopy(trace))
-
- # Add all traces from overlays
+ # Create new figure with base's layout and all traces
+ all_traces = [copy.deepcopy(t) for t in base.data]
for overlay in overlays:
- for trace in overlay.data:
- combined.add_trace(copy.deepcopy(trace))
+ all_traces.extend(copy.deepcopy(t) for t in overlay.data)
+ combined = go.Figure(data=all_traces, layout=copy.deepcopy(base.layout))
# Handle animation frames
if base.frames:
@@ -213,6 +389,17 @@ def overlay(base: go.Figure, *overlays: go.Figure) -> go.Figure:
merged_frames = _merge_frames(base, list(overlays), base_trace_count, overlay_trace_counts)
combined.frames = merged_frames
+ # Build trace slices for legend fix
+ source_figs = [base, *overlays]
+ slices: list[slice] = []
+ offset = 0
+ for fig in source_figs:
+ n = len(fig.data)
+ slices.append(slice(offset, offset + n))
+ offset += n
+
+ _ensure_legend_visibility(combined, source_figs, slices)
+ _fix_animation_axis_ranges(combined)
return combined
@@ -315,19 +502,15 @@ def add_secondary_y(
rightmost_x = max(x_for_y.values(), key=lambda x: int(x[1:]) if x != "x" else 1)
rightmost_primary_y = next(y for y, x in x_for_y.items() if x == rightmost_x)
- # Create new figure with base's layout
- combined = go.Figure(layout=copy.deepcopy(base.layout))
-
- # Add all traces from base (primary y-axis)
- for trace in base.data:
- combined.add_trace(copy.deepcopy(trace))
-
- # Add all traces from secondary, remapped to secondary y-axes
+ # Build all traces: base (primary) + secondary (remapped to secondary y-axes)
+ all_traces = [copy.deepcopy(t) for t in base.data]
for trace in secondary.data:
trace_copy = copy.deepcopy(trace)
original_yaxis = getattr(trace_copy, "yaxis", None) or "y"
trace_copy.yaxis = y_mapping[original_yaxis]
- combined.add_trace(trace_copy)
+ all_traces.append(trace_copy)
+
+ combined = go.Figure(data=all_traces, layout=copy.deepcopy(base.layout))
# Get the rightmost secondary y-axis name for linking
rightmost_secondary_y = y_mapping[rightmost_primary_y]
@@ -368,6 +551,14 @@ def add_secondary_y(
merged_frames = _merge_secondary_y_frames(base, secondary, y_mapping)
combined.frames = merged_frames
+ base_n = len(base.data)
+ sec_n = len(secondary.data)
+ _ensure_legend_visibility(
+ combined,
+ [base, secondary],
+ [slice(0, base_n), slice(base_n, base_n + sec_n)],
+ )
+ _fix_animation_axis_ranges(combined)
return combined
@@ -426,6 +617,272 @@ def _merge_secondary_y_frames(
return merged_frames
+def _get_figure_title(fig: go.Figure) -> str:
+ """Extract a display title from a figure for use as a subplot title.
+
+ Checks, in order: the figure's title, then the y-axis title.
+
+ Args:
+ fig: A Plotly figure.
+
+ Returns:
+ A title string, or empty string if nothing is set.
+ """
+ try:
+ title = fig.layout.title.text
+ if isinstance(title, str) and title:
+ return title
+ except AttributeError:
+ pass
+ return _get_yaxis_title(fig)
+
+
+def subplots(*figs: go.Figure, cols: int = 1) -> go.Figure:
+ """Arrange multiple figures into a subplot grid.
+
+ Creates a new figure with each input figure placed in its own cell.
+ Figures may contain internal subplots (facets) — their axes are remapped
+ to fit within the grid cell. Subplot titles are derived from each
+ figure's title or y-axis label.
+
+ Args:
+ *figs: One or more Plotly figures to arrange.
+ cols: Number of columns in the grid. Rows are computed automatically.
+
+ Returns:
+ A new figure with subplot grid.
+
+ Raises:
+ ValueError: If no figures are provided, cols < 1, or a figure has
+ animation frames.
+
+ Example:
+ >>> import numpy as np
+ >>> import xarray as xr
+ >>> from xarray_plotly import xpx, subplots
+ >>>
+ >>> temp = xr.DataArray([20, 22, 25], dims=["time"], name="Temperature")
+ >>> rain = xr.DataArray([0, 5, 12], dims=["time"], name="Rainfall")
+ >>> fig1 = xpx(temp).line()
+ >>> fig2 = xpx(rain).bar()
+ >>> grid = subplots(fig1, fig2, cols=2)
+ """
+ import math
+
+ import plotly.graph_objects as go
+
+ if not figs:
+ raise ValueError("At least one figure is required.")
+ if cols < 1:
+ raise ValueError(f"cols must be >= 1, got {cols}.")
+
+ for i, fig in enumerate(figs):
+ if fig.frames:
+ raise ValueError(
+ f"Figure at position {i} has animation frames. "
+ "Animated figures are not supported in subplots()."
+ )
+
+ rows = math.ceil(len(figs) / cols)
+ combined = go.Figure()
+
+ # Grid spacing
+ h_gap = 0.05
+ v_gap = 0.08
+ cell_w = (1.0 - h_gap * (cols - 1)) / cols
+ cell_h = (1.0 - v_gap * (rows - 1)) / rows
+
+ next_x_num = 1
+ next_y_num = 1
+
+ for i, fig in enumerate(figs):
+ row = i // cols # 0-indexed, top to bottom
+ col = i % cols
+
+ # Cell boundaries (clamped to [0, 1])
+ cell_x0 = max(0.0, col * (cell_w + h_gap))
+ cell_x1 = min(1.0, cell_x0 + cell_w)
+ cell_y1 = min(1.0, 1.0 - row * (cell_h + v_gap)) # top-down
+ cell_y0 = max(0.0, cell_y1 - cell_h)
+
+ # Build axis remapping: old axis ref → new axis ref
+ axis_map, next_x_num, next_y_num = _remap_figure_axes(
+ fig, combined, next_x_num, next_y_num, cell_x0, cell_x1, cell_y0, cell_y1
+ )
+
+ # Add traces with remapped axis refs
+ for trace in fig.data:
+ tc = copy.deepcopy(trace)
+ old_x = getattr(tc, "xaxis", None) or "x"
+ old_y = getattr(tc, "yaxis", None) or "y"
+ tc.xaxis = axis_map[old_x]["new_x"]
+ tc.yaxis = axis_map[old_y]["new_y"]
+ combined.add_trace(tc)
+
+ # Add subplot title as annotation
+ title = _get_figure_title(fig)
+ if title:
+ combined.add_annotation(
+ text=f"{title}",
+ x=(cell_x0 + cell_x1) / 2,
+ y=cell_y1,
+ xref="paper",
+ yref="paper",
+ xanchor="center",
+ yanchor="bottom",
+ showarrow=False,
+ font={"size": 14},
+ )
+
+ return combined
+
+
+# Axis properties safe to copy between figures (display-only, not structural).
+_AXIS_PROPS_TO_COPY = (
+ "title",
+ "type",
+ "tickformat",
+ "ticksuffix",
+ "tickprefix",
+ "dtick",
+ "tick0",
+ "nticks",
+ "showgrid",
+ "gridcolor",
+ "gridwidth",
+ "autorange",
+ "range",
+ "zeroline",
+ "zerolinecolor",
+ "zerolinewidth",
+ "showticklabels",
+)
+
+
+def _axis_layout_key(ref: str) -> str:
+ """Convert axis reference to layout property name.
+
+ ``"x"`` → ``"xaxis"``, ``"x2"`` → ``"xaxis2"``,
+ ``"y"`` → ``"yaxis"``, ``"y3"`` → ``"yaxis3"``.
+ """
+ if ref in ("x", "y"):
+ return f"{ref}axis"
+ prefix = ref[0] # "x" or "y"
+ num = ref[1:]
+ return f"{prefix}axis{num}"
+
+
+def _new_axis_ref(prefix: str, num: int) -> str:
+ """Build an axis reference string. ``_new_axis_ref("x", 1)`` → ``"x"``, ``("x", 3)`` → ``"x3"``."""
+ return prefix if num == 1 else f"{prefix}{num}"
+
+
+def _remap_figure_axes(
+ fig: go.Figure,
+ combined: go.Figure,
+ next_x_num: int,
+ next_y_num: int,
+ cell_x0: float,
+ cell_x1: float,
+ cell_y0: float,
+ cell_y1: float,
+) -> tuple[dict[str, dict[str, str]], int, int]:
+ """Remap a figure's axes into a grid cell, adding axis configs to the combined layout.
+
+ Args:
+ fig: Source figure.
+ combined: Target combined figure (mutated — axis configs added to layout).
+ next_x_num: Next available x-axis number.
+ next_y_num: Next available y-axis number.
+ cell_x0, cell_x1: Horizontal cell bounds in paper coordinates.
+ cell_y0, cell_y1: Vertical cell bounds in paper coordinates.
+
+ Returns:
+ Tuple of (axis_map, next_x_num, next_y_num).
+ axis_map maps old axis refs to ``{"new_x": ...}`` or ``{"new_y": ...}``.
+ """
+ cell_w = cell_x1 - cell_x0
+ cell_h = cell_y1 - cell_y0
+ src_layout = fig.layout.to_plotly_json()
+
+ x_remap: dict[str, str] = {}
+ y_remap: dict[str, str] = {}
+
+ # Get all unique axis refs
+ x_refs: set[str] = set()
+ y_refs: set[str] = set()
+ for trace in fig.data:
+ x_refs.add(getattr(trace, "xaxis", None) or "x")
+ y_refs.add(getattr(trace, "yaxis", None) or "y")
+
+ # Remap x-axes
+ for old_xref in sorted(x_refs, key=lambda r: int(r[1:]) if len(r) > 1 else 1):
+ new_xref = _new_axis_ref("x", next_x_num)
+ x_remap[old_xref] = new_xref
+
+ src_config = src_layout.get(_axis_layout_key(old_xref), {})
+ src_domain = src_config.get("domain", [0.0, 1.0])
+ new_domain = [
+ max(0.0, cell_x0 + src_domain[0] * cell_w),
+ min(1.0, cell_x0 + src_domain[1] * cell_w),
+ ]
+
+ new_config: dict[str, Any] = {"domain": new_domain}
+ for prop in _AXIS_PROPS_TO_COPY:
+ if prop in src_config:
+ new_config[prop] = src_config[prop]
+
+ combined.layout[_axis_layout_key(new_xref)] = new_config
+ next_x_num += 1
+
+ # Remap y-axes
+ for old_yref in sorted(y_refs, key=lambda r: int(r[1:]) if len(r) > 1 else 1):
+ new_yref = _new_axis_ref("y", next_y_num)
+ y_remap[old_yref] = new_yref
+
+ src_config = src_layout.get(_axis_layout_key(old_yref), {})
+ src_domain = src_config.get("domain", [0.0, 1.0])
+ new_domain = [
+ max(0.0, cell_y0 + src_domain[0] * cell_h),
+ min(1.0, cell_y0 + src_domain[1] * cell_h),
+ ]
+
+ new_config = {"domain": new_domain}
+ for prop in _AXIS_PROPS_TO_COPY:
+ if prop in src_config:
+ new_config[prop] = src_config[prop]
+
+ combined.layout[_axis_layout_key(new_yref)] = new_config
+ next_y_num += 1
+
+ # Set anchors between paired axes
+ for trace in fig.data:
+ old_x = getattr(trace, "xaxis", None) or "x"
+ old_y = getattr(trace, "yaxis", None) or "y"
+ combined.layout[_axis_layout_key(x_remap[old_x])]["anchor"] = y_remap[old_y]
+ combined.layout[_axis_layout_key(y_remap[old_y])]["anchor"] = x_remap[old_x]
+
+ # Propagate matches relationships
+ for old_ref, new_ref in x_remap.items():
+ src_config = src_layout.get(_axis_layout_key(old_ref), {})
+ if "matches" in src_config and src_config["matches"] in x_remap:
+ combined.layout[_axis_layout_key(new_ref)]["matches"] = x_remap[src_config["matches"]]
+
+ for old_ref, new_ref in y_remap.items():
+ src_config = src_layout.get(_axis_layout_key(old_ref), {})
+ if "matches" in src_config and src_config["matches"] in y_remap:
+ combined.layout[_axis_layout_key(new_ref)]["matches"] = y_remap[src_config["matches"]]
+
+ # Build combined return mapping
+ result: dict[str, dict[str, str]] = {}
+ for old_x, new_x in x_remap.items():
+ result[old_x] = {"new_x": new_x}
+ for old_y, new_y in y_remap.items():
+ result[old_y] = {"new_y": new_y}
+
+ return result, next_x_num, next_y_num
+
+
def update_traces(
fig: go.Figure, selector: dict[str, Any] | None = None, **kwargs: Any
) -> go.Figure: