diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 014c3cc5..02934fbc 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -358,8 +358,7 @@ def _render_shapes( color_key = ( [_hex_no_alpha(x) for x in color_vector.categories.values] - if (type(color_vector) is pd.core.arrays.categorical.Categorical) - and (len(color_vector.categories.values) > 1) + if isinstance(color_vector.dtype, pd.CategoricalDtype) and (len(color_vector.categories.values) > 1) else None ) @@ -854,8 +853,7 @@ def _render_points( color_key: list[str] | None = ( list(color_vector.categories.values) - if (type(color_vector) is pd.core.arrays.categorical.Categorical) - and (len(color_vector.categories.values) > 1) + if isinstance(color_vector.dtype, pd.CategoricalDtype) and (len(color_vector.categories.values) > 1) else None ) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 9d3b2954..ae55a9c8 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -1120,7 +1120,10 @@ def _set_color_source_vec( raise ValueError("Unable to create color palette.") # do not rename categories, as colors need not be unique - color_vector = color_source_vector.map(color_mapping) + # pd.Categorical.map() demotes to object dtype when mapped values aren't unique + # (e.g. two categories share a color). Wrapping back in pd.Categorical ensures + # downstream consumers always receive a Categorical for categorical data. + color_vector = pd.Categorical(color_source_vector.map(color_mapping, na_action="ignore")) return color_source_vector, color_vector, True @@ -1146,7 +1149,7 @@ def _map_color_seg( ) -> ArrayLike: cell_id = np.array(cell_id) - if pd.api.types.is_categorical_dtype(color_vector.dtype): + if isinstance(color_vector.dtype, pd.CategoricalDtype): # Case A: users wants to plot a categorical column if np.any(color_source_vector.isna()): cell_id[color_source_vector.isna()] = 0