From 3acd6482f990262e8d3d4f5088dddb7738565428 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 1 Nov 2025 10:08:15 +0100 Subject: [PATCH 01/29] Bump the github-actions group with 2 updates (#398) --- .github/workflows/build-ultraplot.yml | 2 +- .github/workflows/publish-pypi.yml | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build-ultraplot.yml b/.github/workflows/build-ultraplot.yml index ab19ab15c..38789f981 100644 --- a/.github/workflows/build-ultraplot.yml +++ b/.github/workflows/build-ultraplot.yml @@ -98,7 +98,7 @@ jobs: # Return the html output of the comparison even if failed - name: Upload comparison failures if: always() - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v5 with: name: failed-comparisons-${{ inputs.python-version }}-${{ inputs.matplotlib-version }}-${{ github.sha }} path: results/* diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index 1eda57ccb..1cd1c9e14 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -54,7 +54,7 @@ jobs: shell: bash - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v5 with: name: dist-${{ github.sha }}-${{ github.run_id }}-${{ github.run_number }} path: dist/* @@ -73,7 +73,7 @@ jobs: contents: read steps: - name: Download artifacts - uses: actions/download-artifact@v5 + uses: actions/download-artifact@v6 with: name: dist-${{ github.sha }}-${{ github.run_id }}-${{ github.run_number }} path: dist @@ -105,7 +105,7 @@ jobs: contents: read steps: - name: Download artifacts - uses: actions/download-artifact@v5 + uses: actions/download-artifact@v6 with: name: dist-${{ github.sha }}-${{ github.run_id }}-${{ github.run_number }} path: dist From 4b438aa0ec0e4f2819c6b16d4f416d1bda235198 Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Sun, 2 Nov 2025 11:52:06 +0100 Subject: [PATCH 02/29] add s and unittest (#400) --- ultraplot/axes/plot.py | 2 +- ultraplot/tests/test_statistical_plotting.py | 22 ++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index dc7ff4f27..9364fd0bd 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -2087,7 +2087,7 @@ def _add_error_bars( ): # ugly kludge to check for shading if all(_ is None for _ in (bardata, barstds, barpctiles)): barstds, barpctiles = default_barstds, default_barpctiles - if all(_ is None for _ in (boxdata, boxstds, boxpctile)): + if all(_ is None for _ in (boxdata, boxstds, boxpctiles)): boxstds, boxpctiles = default_boxstds, default_boxpctiles showbars = any( _ is not None and _ is not False for _ in (barstds, barpctiles, bardata) diff --git a/ultraplot/tests/test_statistical_plotting.py b/ultraplot/tests/test_statistical_plotting.py index c65f25245..d1aff89c3 100644 --- a/ultraplot/tests/test_statistical_plotting.py +++ b/ultraplot/tests/test_statistical_plotting.py @@ -71,3 +71,25 @@ def test_panel_dist(rng): px.hist(x, bins, color=color, fill=True, ec="k") px.format(grid=False, ylocator=[], title=title, titleloc="l") return fig + + +@pytest.mark.mpl_image_compare +def test_input_violin_box_options(): + """ + Test various box options in violin plots. + """ + data = np.array([0, 1, 2, 3]).reshape(-1, 1) + + fig, axes = uplt.subplots(ncols=4) + axes[0].bar(data, median=True, boxpctiles=True, bars=False) + axes[0].format(title="boxpctiles") + + axes[1].bar(data, median=True, boxpctile=True, bars=False) + axes[1].format(title="boxpctile") + + axes[2].bar(data, median=True, boxstd=True, bars=False) + axes[2].format(title="boxstd") + + axes[3].bar(data, median=True, boxstds=True, bars=False) + axes[3].format(title="boxstds") + return fig From 84b0c6c321abe6c16fd3f6e4abe8c960d7c15922 Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Sun, 16 Nov 2025 22:22:59 +1000 Subject: [PATCH 03/29] redo with new ticker (#411) --- ultraplot/axes/base.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index b7e6631be..e16bb63e3 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -2626,9 +2626,19 @@ def _unshare(self, *, which: str): setattr(sibling, f"_share{which}", None) this_ax = getattr(self, f"{which}axis") sib_ax = getattr(sibling, f"{which}axis") - # Reset formatters - this_ax.major = copy.deepcopy(this_ax.major) - this_ax.minor = copy.deepcopy(this_ax.minor) + # Reset formatters by creating new Ticker objects. + # A deepcopy can trigger redraws. + new_major = maxis.Ticker() + if this_ax.major: + new_major.locator = copy.copy(this_ax.major.locator) + new_major.formatter = copy.copy(this_ax.major.formatter) + this_ax.major = new_major + + new_minor = maxis.Ticker() + if this_ax.minor: + new_minor.locator = copy.copy(this_ax.minor.locator) + new_minor.formatter = copy.copy(this_ax.minor.formatter) + this_ax.minor = new_minor def _sharex_setup(self, sharex, **kwargs): """ From fa01291299f28954058f20edeabb6675aa9e8a46 Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Tue, 18 Nov 2025 11:28:30 +1000 Subject: [PATCH 04/29] Hotfix: bar labels cause limit to reset for unaffected axis. (#413) --- ultraplot/axes/plot.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 9364fd0bd..1eb92d68a 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -4705,6 +4705,7 @@ def _add_bar_labels( # Find the maximum extent of text + bar position max_extent = current_lim[1] # Start with current upper limit + w = 0 for label, bar in zip(bar_labels, container): # Get text bounding box bbox = label.get_window_extent(renderer=self.figure.canvas.get_renderer()) @@ -4715,21 +4716,25 @@ def _add_bar_labels( bar_end = bar.get_width() + bar.get_x() text_end = bar_end + bbox_data.width max_extent = max(max_extent, text_end) + w = max(w, bar.get_height()) else: # For vertical bars, check if text extends beyond top edge bar_end = bar.get_height() + bar.get_y() text_end = bar_end + bbox_data.height max_extent = max(max_extent, text_end) + w = max(w, bar.get_width()) # Only adjust limits if text extends beyond current range if max_extent > current_lim[1]: padding = (max_extent - current_lim[1]) * 1.25 # Add a bit of padding new_lim = (current_lim[0], max_extent + padding) getattr(self, f"set_{which}lim")(new_lim) + lim = [getattr(self.dataLim, f"{other_which}{idx}") for idx in range(0, 2)] + lim = (lim[0] - w / 4, lim[1] + w / 4) - # Keep the other axis unchanged - getattr(self, f"set_{other_which}lim")(other_lim) - + current_lim = getattr(self, f"get_{other_which}lim")() + new_lim = (min(lim[0], current_lim[0]), max(lim[1], current_lim[1])) + getattr(self, f"set_{other_which}lim")(new_lim) return bar_labels @inputs._preprocess_or_redirect("x", "height", "width", "bottom") From cad96d0f435df7733d7474bd5d4a87f44f74ebdb Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Wed, 19 Nov 2025 06:53:59 -0600 Subject: [PATCH 05/29] fix: change default `reduce_C_function` to `np.sum` for `hexbin` (#408) * fix: change default reduce_C_function to np.sum for hexbin Updated default behavior for weights/C to compute total instead of average. * test: add a test --------- Co-authored-by: Casper van Elteren --- ultraplot/axes/plot.py | 5 +++++ ultraplot/tests/test_plot.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 1eb92d68a..074dc3a54 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -5257,6 +5257,11 @@ def hexbin(self, x, y, weights, **kwargs): center_levels=center_levels, **kw, ) + # Change the default behavior for weights/C to compute + # the total of the weights, not their average. + reduce_C_function = kw.get("reduce_C_function", None) + if reduce_C_function is None: + kw["reduce_C_function"] = np.sum norm = kw.get("norm", None) if norm is not None and not isinstance(norm, pcolors.DiscreteNorm): norm.vmin = norm.vmax = None # remove nonsense values diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index e3eb9455d..145de8b9e 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -617,3 +617,38 @@ def test_curved_quiver_color_and_cmap(rng, cmap): fig, ax = uplt.subplots() ax.curved_quiver(X, Y, U, V, color=color, cmap=cmap) return fig + + +def test_histogram_norms(): + """ + Check that all histograms-like plotting functions + use the sum of the weights. + """ + rng = np.random.default_rng(seed=100) + x, y = rng.normal(size=(2, 100)) + w = rng.uniform(size=100) + + fig, axs = uplt.subplots() + _, _, bars = axs.hist(x, weights=w, bins=5) + tot_weights = np.sum([bar.get_height() for bar in bars]) + np.testing.assert_allclose(tot_weights, np.sum(w)) + + fig, axs = uplt.subplots() + _, _, _, qm = axs.hist2d(x, y, weights=w, bins=5) + tot_weights = np.sum(qm.get_array()) + np.testing.assert_allclose(tot_weights, np.sum(w)) + + fig, axs = uplt.subplots() + pc = axs.hexbin(x, y, weights=w, gridsize=5) + tot_weights = np.sum(pc.get_array()) + np.testing.assert_allclose(tot_weights, np.sum(w)) + + # check that a different reduce_C_function produces + # a different result + fig, axs = uplt.subplots() + pc = axs.hexbin(x, y, weights=w, gridsize=5, reduce_C_function=np.max) + tot_weights = np.sum(pc.get_array()) + # check they are not equal and that the different is not + # due to floating point errors + assert tot_weights != np.sum(w) + assert not np.allclose(tot_weights, np.sum(w)) From 477e18791a469f2db75396c89ec34ef84b398433 Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Fri, 21 Nov 2025 00:52:16 +1000 Subject: [PATCH 06/29] Add external context mode for axes (#406) * add seaborn context processing * rm debug * add unittest * resolve iterable * relax legend filter * add seaborn import * add more unittests * add ctx texts * implement mark external and context managing * fix test * refactor classes for clarity * update tests * more fixes * more tests * minor fix * minor fix * fix for mpl 3.9 * remove stack frame * adjust and remove unecessary tests * more fixes * add external to pass test * restore test * rm dup * finalize docstring * remove fallback * Apply suggestion from @beckermr * Apply suggestion from @beckermr * fix bar and test --------- Co-authored-by: Matthew R. Becker --- ultraplot/axes/base.py | 85 +++++++++-- ultraplot/axes/plot.py | 214 +++++++++++++++++++++------- ultraplot/tests/test_1dplots.py | 42 ++++++ ultraplot/tests/test_colorbar.py | 41 ++++++ ultraplot/tests/test_integration.py | 104 ++++++++++++-- ultraplot/tests/test_legend.py | 101 ++++++++++++- ultraplot/tests/test_plot.py | 84 ++++++++++- 7 files changed, 596 insertions(+), 75 deletions(-) diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index e16bb63e3..03bea4fdc 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -6,10 +6,11 @@ import copy import inspect import re +import sys import types -from numbers import Integral, Number -from typing import Union, Iterable, MutableMapping, Optional, Tuple from collections.abc import Iterable as IterableType +from numbers import Integral, Number +from typing import Iterable, MutableMapping, Optional, Tuple, Union try: # From python 3.12 @@ -34,12 +35,11 @@ from matplotlib import cbook from packaging import version -from .. import legend as plegend from .. import colors as pcolors from .. import constructor +from .. import legend as plegend from .. import ticker as pticker from ..config import rc -from ..internals import ic # noqa: F401 from ..internals import ( _kwargs_to_args, _not_none, @@ -51,6 +51,7 @@ _version_mpl, docstring, guides, + ic, # noqa: F401 labels, rcsetup, warnings, @@ -700,7 +701,52 @@ def __call__(self, ax, renderer): # noqa: U100 return bbox -class Axes(maxes.Axes): +class _ExternalModeMixin: + """ + Mixin providing explicit external-mode control and a context manager. + """ + + def set_external(self, value=True): + """ + Set explicit external-mode override for this axes. + + value: + - True: force external behavior (defer on-the-fly guides, etc.) + - False: force UltraPlot behavior + """ + if value not in (True, False): + raise ValueError("set_external expects True or False") + setattr(self, "_integration_external", value) + return self + + class _ExternalContext: + def __init__(self, ax, value=True): + self._ax = ax + self._value = True if value is None else value + self._prev = getattr(ax, "_integration_external", None) + + def __enter__(self): + self._ax._integration_external = self._value + return self._ax + + def __exit__(self, exc_type, exc, tb): + self._ax._integration_external = self._prev + + def external(self, value=True): + """ + Context manager toggling external mode during the block. + """ + return _ExternalModeMixin._ExternalContext(self, value) + + def _in_external_context(self): + """ + Return True if UltraPlot helper behaviors should be suppressed. + """ + mode = getattr(self, "_integration_external", None) + return mode is True + + +class Axes(_ExternalModeMixin, maxes.Axes): """ The lowest-level `~matplotlib.axes.Axes` subclass used by ultraplot. Implements basic universal features. @@ -822,6 +868,7 @@ def __init__(self, *args, **kwargs): self._panel_sharey_group = False # see _apply_auto_share self._panel_side = None self._tight_bbox = None # bounding boxes are saved + self._integration_external = None # explicit external-mode override (None=auto) self.xaxis.isDefault_minloc = True # ensure enabled at start (needed for dual) self.yaxis.isDefault_minloc = True @@ -1739,6 +1786,7 @@ def _get_legend_handles(self, handler_map=None): handler_map_full = plegend.Legend.get_default_handler_map() handler_map_full = handler_map_full.copy() handler_map_full.update(handler_map or {}) + # Prefer synthetic tagging to exclude helper artists; see _ultraplot_synthetic flag on artists. for ax in axs: for attr in ("lines", "patches", "collections", "containers"): for handle in getattr(ax, attr, []): # guard against API changes @@ -1746,7 +1794,12 @@ def _get_legend_handles(self, handler_map=None): handler = plegend.Legend.get_legend_handler( handler_map_full, handle ) # noqa: E501 - if handler and label and label[0] != "_": + if ( + handler + and label + and label[0] != "_" + and not getattr(handle, "_ultraplot_synthetic", False) + ): handles.append(handle) return handles @@ -1897,11 +1950,17 @@ def _update_guide( if legend: align = legend_kw.pop("align", None) queue = legend_kw.pop("queue", queue_legend) - self.legend(objs, loc=legend, align=align, queue=queue, **legend_kw) + # Avoid immediate legend creation in external context + if not self._in_external_context(): + self.legend(objs, loc=legend, align=align, queue=queue, **legend_kw) if colorbar: align = colorbar_kw.pop("align", None) queue = colorbar_kw.pop("queue", queue_colorbar) - self.colorbar(objs, loc=colorbar, align=align, queue=queue, **colorbar_kw) + # Avoid immediate colorbar creation in external context + if not self._in_external_context(): + self.colorbar( + objs, loc=colorbar, align=align, queue=queue, **colorbar_kw + ) @staticmethod def _parse_frame(guide, fancybox=None, shadow=None, **kwargs): @@ -2423,6 +2482,8 @@ def _legend_label(*objs): # noqa: E301 labs = [] for obj in objs: if hasattr(obj, "get_label"): # e.g. silent list + if getattr(obj, "_ultraplot_synthetic", False): + continue lab = obj.get_label() if lab is not None and not str(lab).startswith("_"): labs.append(lab) @@ -2453,10 +2514,15 @@ def _legend_tuple(*objs): # noqa: E306 if hs: handles.extend(hs) elif obj: # fallback to first element - handles.append(obj[0]) + # Skip synthetic helpers and fill_between collections + if not getattr(obj[0], "_ultraplot_synthetic", False): + handles.append(obj[0]) else: handles.append(obj) elif hasattr(obj, "get_label"): + # Skip synthetic helpers and fill_between collections + if getattr(obj, "_ultraplot_synthetic", False): + continue handles.append(obj) else: warnings._warn_ultraplot(f"Ignoring invalid legend handle {obj!r}.") @@ -3332,6 +3398,7 @@ def _label_key(self, side: str) -> str: labelright/labelleft respectively. """ from packaging import version + from ..internals import _version_mpl # TODO: internal deprecation warning when we drop 3.9, we need to remove this diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 074dc3a54..526e6ffac 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -8,50 +8,46 @@ import itertools import re import sys +from collections.abc import Callable, Iterable from numbers import Integral, Number +from typing import Any, Iterable, Optional, Union -from typing import Any, Union, Iterable, Optional - -from collections.abc import Callable -from collections.abc import Iterable - -from ..utils import units +import matplotlib as mpl import matplotlib.artist as martist import matplotlib.axes as maxes import matplotlib.cbook as cbook import matplotlib.cm as mcm import matplotlib.collections as mcollections import matplotlib.colors as mcolors -import matplotlib.contour as mcontour import matplotlib.container as mcontainer +import matplotlib.contour as mcontour import matplotlib.image as mimage import matplotlib.lines as mlines import matplotlib.patches as mpatches -import matplotlib.ticker as mticker import matplotlib.pyplot as mplt -import matplotlib as mpl -from packaging import version +import matplotlib.ticker as mticker import numpy as np -from typing import Optional, Union, Any import numpy.ma as ma +from packaging import version from .. import colors as pcolors from .. import constructor, utils from ..config import rc -from ..internals import ic # noqa: F401 from ..internals import ( _get_aliases, _not_none, _pop_kwargs, _pop_params, _pop_props, + _version_mpl, context, docstring, guides, + ic, # noqa: F401 inputs, warnings, - _version_mpl, ) +from ..utils import units from . import base try: @@ -1512,25 +1508,6 @@ def _parse_vert( return kwargs -def _inside_seaborn_call(): - """ - Try to detect `seaborn` calls to `scatter` and `bar` and then automatically - apply `absolute_size` and `absolute_width`. - """ - frame = sys._getframe() - absolute_names = ( - "seaborn.distributions", - "seaborn.categorical", - "seaborn.relational", - "seaborn.regression", - ) - while frame is not None: - if frame.f_globals.get("__name__", "") in absolute_names: - return True - frame = frame.f_back - return False - - class PlotAxes(base.Axes): """ The second lowest-level `~matplotlib.axes.Axes` subclass used by ultraplot. @@ -1566,7 +1543,7 @@ def curved_quiver( The implementation of this function is based on the `dfm_tools` repository. Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py """ - from .plot_types.curved_quiver import CurvedQuiverSolver, CurvedQuiverSet + from .plot_types.curved_quiver import CurvedQuiverSet, CurvedQuiverSolver # Parse inputs arrowsize = _not_none(arrowsize, rc["curved_quiver.arrowsize"]) @@ -2237,6 +2214,7 @@ def _add_error_shading( # Draw dark and light shading from distributions or explicit errdata eobjs = [] fill = self.fill_between if vert else self.fill_betweenx + if drawfade: edata, label = inputs._dist_range( y, @@ -2250,7 +2228,29 @@ def _add_error_shading( absolute=True, ) if edata is not None: - eobj = fill(x, *edata, label=label, **fadeprops) + synthetic = False + eff_label = label + if self._in_external_context() and ( + eff_label is None or str(eff_label) in ("y", "ymin", "ymax") + ): + eff_label = "_ultraplot_fade" + synthetic = True + + eobj = fill(x, *edata, label=eff_label, **fadeprops) + if synthetic: + try: + setattr(eobj, "_ultraplot_synthetic", True) + if hasattr(eobj, "set_label"): + eobj.set_label("_ultraplot_fade") + except Exception: + pass + for _obj in guides._iter_iterables(eobj): + try: + setattr(_obj, "_ultraplot_synthetic", True) + if hasattr(_obj, "set_label"): + _obj.set_label("_ultraplot_fade") + except Exception: + pass eobjs.append(eobj) if drawshade: edata, label = inputs._dist_range( @@ -2265,7 +2265,29 @@ def _add_error_shading( absolute=True, ) if edata is not None: - eobj = fill(x, *edata, label=label, **shadeprops) + synthetic = False + eff_label = label + if self._in_external_context() and ( + eff_label is None or str(eff_label) in ("y", "ymin", "ymax") + ): + eff_label = "_ultraplot_shade" + synthetic = True + + eobj = fill(x, *edata, label=eff_label, **shadeprops) + if synthetic: + try: + setattr(eobj, "_ultraplot_synthetic", True) + if hasattr(eobj, "set_label"): + eobj.set_label("_ultraplot_shade") + except Exception: + pass + for _obj in guides._iter_iterables(eobj): + try: + setattr(_obj, "_ultraplot_synthetic", True) + if hasattr(_obj, "set_label"): + _obj.set_label("_ultraplot_shade") + except Exception: + pass eobjs.append(eobj) kwargs["distribution"] = distribution @@ -2547,6 +2569,19 @@ def _parse_1d_format( colorbar_kw_labels = _not_none( kwargs.get("colorbar_kw", {}).pop("values", None), ) + # Track whether the user explicitly provided labels/values so we can + # preserve them even when autolabels is disabled. + _user_labels_explicit = any( + v is not None + for v in ( + label, + labels, + value, + values, + legend_kw_labels, + colorbar_kw_labels, + ) + ) labels = _not_none( label=label, @@ -2586,9 +2621,9 @@ def _parse_1d_format( # Apply the labels or values if labels is not None: - if autovalues: + if autovalues or (value is not None or values is not None): kwargs["values"] = inputs._to_numpy_array(labels) - elif autolabels: + elif autolabels or _user_labels_explicit: kwargs["labels"] = inputs._to_numpy_array(labels) # Apply title for legend or colorbar that uses the labels or values @@ -3054,7 +3089,9 @@ def _parse_cycle( resolved_cycle = constructor.Cycle(cycle, **cycle_kw) case str() if cycle.lower() == "none": resolved_cycle = None - case str() | int() | Iterable(): + case str() | int(): + resolved_cycle = constructor.Cycle(cycle, **cycle_kw) + case _ if isinstance(cycle, Iterable): resolved_cycle = constructor.Cycle(cycle, **cycle_kw) case _: resolved_cycle = None @@ -3626,6 +3663,9 @@ def _apply_plot(self, *pairs, vert=True, **kwargs): objs, xsides = [], [] kws = kwargs.copy() kws.update(_pop_props(kws, "line")) + # Disable auto label inference when in external context + if self._in_external_context(): + kws["autolabels"] = False kws, extents = self._inbounds_extent(**kws) for xs, ys, fmt in self._iter_arg_pairs(*pairs): xs, ys, kw = self._parse_1d_args(xs, ys, vert=vert, **kws) @@ -3775,7 +3815,7 @@ def _apply_beeswarm( orientation: str = "horizontal", n_bins: int = 50, **kwargs, - ) -> "Collection": + ) -> mcollections.Collection: # Parse input parameters ss, _ = self._parse_markersize(ss, **kwargs) @@ -4237,7 +4277,7 @@ def _parse_markersize( if s is not None: s = inputs._to_numpy_array(s) if absolute_size is None: - absolute_size = s.size == 1 or _inside_seaborn_call() + absolute_size = s.size == 1 if not absolute_size or smin is not None or smax is not None: smin = _not_none(smin, 1) smax = _not_none(smax, rc["lines.markersize"] ** (1, 2)[area_size]) @@ -4362,8 +4402,45 @@ def _apply_fill( stacked=None, **kwargs, ): - """ - Apply area shading. + """Apply area shading using `fill_between` or `fill_betweenx`. + + This is the internal implementation for `fill_between`, `fill_betweenx`, + `area`, and `areax`. + + Parameters + ---------- + xs, ys1, ys2 : array-like + The x and y coordinates for the shaded regions. + where : array-like, optional + A boolean mask for the points that should be shaded. + vert : bool, optional + The orientation of the shading. If `True` (default), `fill_between` + is used. If `False`, `fill_betweenx` is used. + negpos : bool, optional + Whether to use different colors for positive and negative shades. + stack : bool, optional + Whether to stack shaded regions. + **kwargs + Additional keyword arguments passed to the matplotlib fill function. + + Notes + ----- + Special handling for plots from external packages (e.g., seaborn): + + When this method is used in a context where plots are generated by + an external library like seaborn, it tags the resulting polygons + (e.g., confidence intervals) as "synthetic". This is done unless a + user explicitly provides a label. + + Synthetic artists are marked with `_ultraplot_synthetic=True` and given + a label starting with an underscore (e.g., `_ultraplot_fill`). This + prevents them from being automatically included in legends, keeping the + legend clean and focused on user-specified elements. + + Seaborn internally generates tags like "y", "ymin", and "ymax" for + vertical fills, and "x", "xmin", "xmax" for horizontal fills. UltraPlot + recognizes these and treats them as synthetic unless a different label + is provided. """ # Parse input arguments kw = kwargs.copy() @@ -4373,34 +4450,73 @@ def _apply_fill( stack = _not_none(stack=stack, stacked=stacked) xs, ys1, ys2, kw = self._parse_1d_args(xs, ys1, ys2, vert=vert, **kw) edgefix_kw = _pop_params(kw, self._fix_patch_edges) + guide_kw = _pop_params(kw, self._update_guide) + + # External override only; no seaborn-based tagging - # Draw patches with default edge width zero + # Draw patches y0 = 0 objs, xsides, ysides = [], [], [] - guide_kw = _pop_params(kw, self._update_guide) for _, n, x, y1, y2, w, kw in self._iter_arg_cols(xs, ys1, ys2, where, **kw): kw = self._parse_cycle(n, **kw) + + # If stacking requested, adjust y arrays if stack: - y1 = y1 + y0 # avoid in-place modification + y1 = y1 + y0 y2 = y2 + y0 - y0 = y0 + y2 - y1 # irrelevant that we added y0 to both - if negpos: # NOTE: if user passes 'where' will issue a warning + y0 = y0 + y2 - y1 + + # External override: if in external mode and no explicit label was provided, + # mark fill as synthetic so it is ignored by legend parsing unless explicitly labeled. + synthetic = False + if self._in_external_context() and ( + kw.get("label", None) is None + or str(kw.get("label")) in ("y", "ymin", "ymax") + ): + kw["label"] = "_ultraplot_fill" + synthetic = True + + # Draw object (negpos splits into two silent_list items) + if negpos: obj = self._call_negpos(name, x, y1, y2, where=w, use_where=True, **kw) else: obj = self._call_native(name, x, y1, y2, where=w, **kw) + + if synthetic: + try: + setattr(obj, "_ultraplot_synthetic", True) + if hasattr(obj, "set_label"): + obj.set_label("_ultraplot_fill") + except Exception: + pass + for art in guides._iter_iterables(obj): + try: + setattr(art, "_ultraplot_synthetic", True) + if hasattr(art, "set_label"): + art.set_label("_ultraplot_fill") + except Exception: + pass + + # No synthetic tagging or seaborn-based label overrides + + # Patch edge fixes self._fix_patch_edges(obj, **edgefix_kw, **kw) + + # Track sides for sticky edges xsides.append(x) for y in (y1, y2): self._inbounds_xylim(extents, x, y, vert=vert) - if y.size == 1: # add sticky edges if bounds are scalar + if y.size == 1: ysides.append(y) objs.append(obj) + # Draw guide and add sticky edges # Draw guide and add sticky edges self._update_guide(objs, **guide_kw) for axis, sides in zip("xy" if vert else "yx", (xsides, ysides)): self._fix_sticky_edges(objs, axis, *sides) return objs[0] if len(objs) == 1 else cbook.silent_list("PolyCollection", objs) + return objs[0] if len(objs) == 1 else cbook.silent_list("PolyCollection", objs) @docstring._snippet_manager def area(self, *args, **kwargs): @@ -4621,7 +4737,7 @@ def _apply_bar( xs, hs, kw = self._parse_1d_args(xs, hs, orientation=orientation, **kw) edgefix_kw = _pop_params(kw, self._fix_patch_edges) if absolute_width is None: - absolute_width = _inside_seaborn_call() + absolute_width = False or self._in_external_context() # Call func after converting bar width b0 = 0 diff --git a/ultraplot/tests/test_1dplots.py b/ultraplot/tests/test_1dplots.py index eee2178bb..50bfdc75b 100644 --- a/ultraplot/tests/test_1dplots.py +++ b/ultraplot/tests/test_1dplots.py @@ -5,8 +5,50 @@ import numpy as np import numpy.ma as ma import pandas as pd +import pytest import ultraplot as uplt + + +def test_bar_relative_width_by_default_external_and_internal(): + """ + Bars use relative widths by default regardless of external mode. + """ + x = [0, 10] + h = [1, 2] + + # Internal (external=False): relative width scales with step size + fig, ax = uplt.subplots() + ax.set_external(False) + bars_int = ax.bar(x, h) + w_int = [r.get_width() for r in bars_int.patches] + + # External (external=True): same default relative behavior + fig, ax = uplt.subplots() + ax.set_external(True) + bars_ext = ax.bar(x, h) + w_ext = [r.get_width() for r in bars_ext.patches] + + # With step=10, expect ~ 0.8 * 10 = 8 + assert pytest.approx(w_int[0], rel=1e-6) == 8.0 + assert pytest.approx(w_ext[0], rel=1e-6) == 0.8 + + +def test_bar_absolute_width_manual_override(): + """ + Users can force absolute width by passing absolute_width=True. + """ + x = [0, 10] + h = [1, 2] + + fig, ax = uplt.subplots() + bars_abs = ax.bar(x, h, absolute_width=True) + w_abs = [r.get_width() for r in bars_abs.patches] + + # Absolute width should be the raw width (default 0.8) in data units + assert pytest.approx(w_abs[0], rel=1e-6) == 0.8 + + import pytest diff --git a/ultraplot/tests/test_colorbar.py b/ultraplot/tests/test_colorbar.py index f16a6f13a..b4e42eb40 100644 --- a/ultraplot/tests/test_colorbar.py +++ b/ultraplot/tests/test_colorbar.py @@ -4,7 +4,48 @@ """ import numpy as np import pytest + import ultraplot as uplt + + +def test_colorbar_defers_external_mode(): + """ + External mode should defer on-the-fly colorbar creation until explicitly requested. + """ + import numpy as np + + fig, ax = uplt.subplots() + ax.set_external(True) + m = ax.pcolor(np.random.random((5, 5)), colorbar="b") + + # No colorbar should have been registered/created yet + assert isinstance(ax[0]._colorbar_dict, dict) + assert len(ax[0]._colorbar_dict) == 0 + + # Explicit colorbar creation should register the colorbar at the requested loc + cb = ax.colorbar(m, loc="b") + assert ("bottom", "center") in ax[0]._colorbar_dict + assert ax[0]._colorbar_dict[("bottom", "center")] is cb + + +def test_explicit_legend_with_handles_under_external_mode(): + """ + Under external mode, legend auto-creation is deferred. Passing explicit handles + to legend() must work immediately. + """ + fig, ax = uplt.subplots() + ax.set_external(True) + (h,) = ax.plot([0, 1], label="LegendLabel", legend="b") + + # No legend queued/created yet + assert ("bottom", "center") not in ax[0]._legend_dict + + # Explicit legend with handle should contain our label + leg = ax.legend(h, loc="b") + labels = [t.get_text() for t in leg.get_texts()] + assert "LegendLabel" in labels + + from itertools import product diff --git a/ultraplot/tests/test_integration.py b/ultraplot/tests/test_integration.py index fc1d48b90..7429fafc0 100644 --- a/ultraplot/tests/test_integration.py +++ b/ultraplot/tests/test_integration.py @@ -2,10 +2,73 @@ """ Test xarray, pandas, pint, seaborn integration. """ -import numpy as np, pandas as pd, seaborn as sns -import xarray as xr -import ultraplot as uplt, pytest +import numpy as np +import pandas as pd import pint +import pytest +import seaborn as sns +import xarray as xr + +import ultraplot as uplt + + +def test_seaborn_helpers_filtered_from_legend(): + """ + Seaborn-generated helper artists (e.g., CI bands) must be synthetic-tagged and + filtered out of legends so that only hue categories appear. + """ + fig, ax = uplt.subplots() + + # Create simple dataset with two hue levels + df = pd.DataFrame( + { + "x": np.concatenate([np.arange(10)] * 2), + "y": np.concatenate([np.arange(10), np.arange(10) + 1]), + "hue": ["h1"] * 10 + ["h2"] * 10, + } + ) + + # Use explicit external mode to engage UL's integration behavior for helper artists + with ax.external(): + sns.lineplot(data=df, x="x", y="y", hue="hue", ax=ax) + + # Explicitly create legend and verify labels + leg = ax.legend() + labels = {t.get_text() for t in leg.get_texts()} + + # Only hue labels should be present + assert {"h1", "h2"}.issubset(labels) + + # Spurious or synthetic labels must not appear + for bad in ( + "y", + "ymin", + "ymax", + "_ultraplot_fill", + "_ultraplot_shade", + "_ultraplot_fade", + ): + assert bad not in labels + + +def test_user_labeled_shading_appears_in_legend(): + """ + User-labeled shading (fill_between) must appear in legend even after seaborn plotting. + """ + fig, ax = uplt.subplots() + + # Seaborn plot first (to ensure seaborn context was present earlier) + df = pd.DataFrame({"x": np.arange(10), "y": np.arange(10)}) + sns.lineplot(data=df, x="x", y="y", ax=ax, label="line") + + # Add explicit user-labeled shading on the same axes + x = np.arange(10) + ax.fill_between(x, x - 0.2, x + 0.2, alpha=0.2, label="CI band") + + # Legend must include both the seaborn line label and our shaded band label + leg = ax.legend() + labels = {t.get_text() for t in leg.get_texts()} + assert "CI band" in labels @pytest.mark.mpl_image_compare @@ -96,18 +159,35 @@ def test_seaborn_swarmplot(): @pytest.mark.mpl_image_compare def test_seaborn_hist(rng): """ - Test seaborn histograms. + Test seaborn histograms (smoke test using external mode contexts). """ fig, axs = uplt.subplots(ncols=2, nrows=2) - sns.histplot(rng.normal(size=100), ax=axs[0]) - sns.kdeplot(x=rng.random(100), y=rng.random(100), ax=axs[1]) + + with axs[0].external(): + sns.histplot(rng.normal(size=100), ax=axs[0]) + + with axs[1].external(): + sns.kdeplot(x=rng.random(100), y=rng.random(100), ax=axs[1]) + penguins = sns.load_dataset("penguins") - sns.histplot( - data=penguins, x="flipper_length_mm", hue="species", multiple="stack", ax=axs[2] - ) - sns.kdeplot( - data=penguins, x="flipper_length_mm", hue="species", multiple="stack", ax=axs[3] - ) + + with axs[2].external(): + sns.histplot( + data=penguins, + x="flipper_length_mm", + hue="species", + multiple="stack", + ax=axs[2], + ) + + with axs[3].external(): + sns.kdeplot( + data=penguins, + x="flipper_length_mm", + hue="species", + multiple="stack", + ax=axs[3], + ) return fig diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index 096b10729..dd23c5c18 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -2,7 +2,11 @@ """ Test legends. """ -import numpy as np, pandas as pd, ultraplot as uplt, pytest +import numpy as np +import pandas as pd +import pytest + +import ultraplot as uplt @pytest.mark.mpl_image_compare @@ -219,3 +223,98 @@ def test_sync_label_dict(rng): 0 ]._legend_dict, "Old legend not removed from dict" uplt.close(fig) + + +def test_external_mode_defers_on_the_fly_legend(): + """ + External mode should defer on-the-fly legend creation until explicitly requested. + """ + fig, ax = uplt.subplots() + ax.set_external(True) + (h,) = ax.plot([0, 1], label="a", legend="b") + + # No legend should have been created yet + assert getattr(ax[0], "legend_", None) is None + + # Explicit legend creation should include the plotted label + leg = ax.legend(h, loc="b") + labels = [t.get_text() for t in leg.get_texts()] + assert "a" in labels + uplt.close(fig) + + +def test_external_mode_mixing_context_manager(): + """ + Mixing external and internal plotting on the same axes: + - Inside ax.external(): on-the-fly legend is deferred + - Outside: UltraPlot-native plotting resumes as normal + - Final explicit ax.legend() consolidates both kinds of artists + """ + fig, ax = uplt.subplots() + + with ax.external(): + (ext,) = ax.plot([0, 1], label="ext", legend="b") # deferred + + (intr,) = ax.line([0, 1], label="int") # normal UL behavior + + leg = ax.legend([ext, intr], loc="b") + labels = {t.get_text() for t in leg.get_texts()} + assert {"ext", "int"}.issubset(labels) + uplt.close(fig) + + +def test_external_mode_toggle_enables_auto(): + """ + Toggling external mode back off should resume on-the-fly guide creation. + """ + fig, ax = uplt.subplots() + + ax.set_external(True) + (ha,) = ax.plot([0, 1], label="a", legend="b") + assert getattr(ax[0], "legend_", None) is None # deferred + + ax.set_external(False) + (hb,) = ax.plot([0, 1], label="b", legend="b") + # Now legend is queued for creation; verify it is registered in the outer legend dict + assert ("bottom", "center") in ax[0]._legend_dict + + # Ensure final legend contains both entries + leg = ax.legend([ha, hb], loc="b") + labels = {t.get_text() for t in leg.get_texts()} + assert {"a", "b"}.issubset(labels) + uplt.close(fig) + + +def test_synthetic_handles_filtered(): + """ + Synthetic-tagged helper artists must be ignored by legend parsing even when + explicitly passed as handles. + """ + fig, ax = uplt.subplots() + (h1,) = ax.plot([0, 1], label="visible") + (h2,) = ax.plot([1, 0], label="helper") + # Mark helper as synthetic; it should be filtered out from legend entries + setattr(h2, "_ultraplot_synthetic", True) + + leg = ax.legend([h1, h2], loc="best") + labels = [t.get_text() for t in leg.get_texts()] + assert "visible" in labels + assert "helper" not in labels + uplt.close(fig) + + +def test_fill_between_included_in_legend(): + """ + Legitimate fill_between/area handles must appear in legends (regression for + previously skipped FillBetweenPolyCollection). + """ + fig, ax = uplt.subplots() + x = np.arange(5) + y1 = np.zeros(5) + y2 = np.ones(5) + ax.fill_between(x, y1, y2, label="band") + + leg = ax.legend(loc="best") + labels = [t.get_text() for t in leg.get_texts()] + assert "band" in labels + uplt.close(fig) diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index 145de8b9e..fb54d191a 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -1,17 +1,93 @@ -from cycler import V -import pandas as pd -from pandas.core.arrays.arrow.accessors import pa -import ultraplot as uplt, pytest, numpy as np from unittest import mock from unittest.mock import patch +import numpy as np +import pandas as pd +import pytest +from cycler import V +from pandas.core.arrays.arrow.accessors import pa + +import ultraplot as uplt from ultraplot.internals.warnings import UltraPlotWarning + +@pytest.mark.mpl_image_compare +def test_seaborn_lineplot_legend_hue_only(): + """ + Regression test: seaborn lineplot on UltraPlot axes should not add spurious + legend entries like 'y'/'ymin'. Only hue categories should appear unless the user + explicitly labels helper bands. + """ + import seaborn as sns + + fig, ax = uplt.subplots() + df = pd.DataFrame( + { + "xcol": np.concatenate([np.arange(10)] * 2), + "ycol": np.concatenate([np.arange(10), 1.5 * np.arange(10)]), + "hcol": ["h1"] * 10 + ["h2"] * 10, + } + ) + + with ax.external(): + sns.lineplot(data=df, x="xcol", y="ycol", hue="hcol", ax=ax) + + # Create (or refresh) legend and collect labels + leg = ax.legend() + labels = {t.get_text() for t in leg.get_texts()} + + # Should contain only hue levels; must not contain inferred 'y' or CI helpers + assert "y" not in labels + assert "ymin" not in labels + assert {"h1", "h2"}.issubset(labels) + return fig + + """ This file is used to test base properties of ultraplot.axes.plot. For higher order plotting related functions, please use 1d and 2plots """ +def test_external_preserves_explicit_label(): + """ + In external mode, explicit labels must still be respected even when autolabels are disabled. + """ + fig, ax = uplt.subplots() + ax.set_external(True) + (h,) = ax.plot([0, 1, 2], [0, 1, 0], label="X") + leg = ax.legend(h, loc="best") + labels = [t.get_text() for t in leg.get_texts()] + assert "X" in labels + + +def test_external_disables_autolabels_no_label(): + """ + In external mode, if no labels are provided, autolabels are disabled and a placeholder is used. + """ + fig, ax = uplt.subplots() + ax.set_external(True) + (h,) = ax.plot([0, 1, 2], [0, 1, 0]) + # Explicitly pass the handle so we test the label on that artist + leg = ax.legend(h, loc="best") + labels = [t.get_text() for t in leg.get_texts()] + # With no explicit labels and autolabels disabled, a placeholder is used + assert (not labels) or (labels[0] in ("_no_label", "")) + + +def test_error_shading_explicit_label_external(): + """ + Explicit label on fill_between should be preserved in legend entries. + """ + fig, ax = uplt.subplots() + ax.set_external(True) + x = np.linspace(0, 2 * np.pi, 50) + y = np.sin(x) + patch = ax.fill_between(x, y - 0.5, y + 0.5, alpha=0.3, label="Band") + leg = ax.legend([patch], loc="best") + labels = [t.get_text() for t in leg.get_texts()] + assert "Band" in labels + + def test_graph_nodes_kw(): """Test the graph method by setting keywords for nodes""" import networkx as nx From 99605306645160efd75b2409af02f4b4fc0f1f2b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 17:42:12 +1000 Subject: [PATCH 07/29] Bump actions/checkout from 5 to 6 in the github-actions group (#415) Bumps the github-actions group with 1 update: [actions/checkout](https://github.com/actions/checkout). Updates `actions/checkout` from 5 to 6 - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v5...v6) --- updated-dependencies: - dependency-name: actions/checkout dependency-version: '6' dependency-type: direct:production update-type: version-update:semver-major dependency-group: github-actions ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/build-ultraplot.yml | 4 ++-- .github/workflows/main.yml | 4 ++-- .github/workflows/publish-pypi.yml | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build-ultraplot.yml b/.github/workflows/build-ultraplot.yml index 38789f981..7d6f1660a 100644 --- a/.github/workflows/build-ultraplot.yml +++ b/.github/workflows/build-ultraplot.yml @@ -22,7 +22,7 @@ jobs: run: shell: bash -el {0} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 @@ -58,7 +58,7 @@ jobs: run: shell: bash -el {0} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: mamba-org/setup-micromamba@v2.0.7 with: diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 01d9c856f..2cc8b1b68 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -11,7 +11,7 @@ jobs: outputs: run: ${{ (github.event_name == 'push' && github.ref_name == 'main') && 'true' || steps.filter.outputs.python }} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: dorny/paths-filter@v3 id: filter with: @@ -28,7 +28,7 @@ jobs: python-versions: ${{ steps.set-versions.outputs.python-versions }} matplotlib-versions: ${{ steps.set-versions.outputs.matplotlib-versions }} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index 1cd1c9e14..63fb29714 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -15,7 +15,7 @@ jobs: name: Build packages runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 From a9db8bc12add66cafc6faf553f7bf976ec496e55 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 14:46:06 -0600 Subject: [PATCH 08/29] [pre-commit.ci] pre-commit autoupdate (#416) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1b19a691d..eae4604a9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,6 +11,6 @@ ci: repos: - repo: https://github.com/psf/black-pre-commit-mirror - rev: 25.9.0 + rev: 25.11.0 hooks: - id: black From 240521858f3d37963d4e8a5a3da5b27ef42b41eb Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Sun, 7 Dec 2025 23:41:22 +1000 Subject: [PATCH 09/29] Add placement of legend to axes within a figure (#418) * init + tests * restore stupid mistake * Update ultraplot/figure.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update ultraplot/tests/test_legend.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update ultraplot/tests/test_legend.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update ultraplot/figure.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ultraplot/axes/base.py | 45 ++++++++- ultraplot/figure.py | 43 +++++++-- ultraplot/tests/test_legend.py | 165 +++++++++++++++++++++++++++++++++ 3 files changed, 241 insertions(+), 12 deletions(-) diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 03bea4fdc..a0e30f68b 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -1456,6 +1456,11 @@ def _add_legend( titlefontcolor=None, handle_kw=None, handler_map=None, + span: Optional[Union[int, Tuple[int, int]]] = None, + row: Optional[int] = None, + col: Optional[int] = None, + rows: Optional[Union[int, Tuple[int, int]]] = None, + cols: Optional[Union[int, Tuple[int, int]]] = None, **kwargs, ): """ @@ -1493,7 +1498,18 @@ def _add_legend( # Generate and prepare the legend axes if loc in ("fill", "left", "right", "top", "bottom"): - lax = self._add_guide_panel(loc, align, width=width, space=space, pad=pad) + lax = self._add_guide_panel( + loc, + align, + width=width, + space=space, + pad=pad, + span=span, + row=row, + col=col, + rows=rows, + cols=cols, + ) kwargs.setdefault("borderaxespad", 0) if not frameon: kwargs.setdefault("borderpad", 0) @@ -3560,7 +3576,19 @@ def colorbar(self, mappable, values=None, loc=None, location=None, **kwargs): @docstring._concatenate_inherited # also obfuscates params @docstring._snippet_manager - def legend(self, handles=None, labels=None, loc=None, location=None, **kwargs): + def legend( + self, + handles=None, + labels=None, + loc=None, + location=None, + span: Optional[Union[int, Tuple[int, int]]] = None, + row: Optional[int] = None, + col: Optional[int] = None, + rows: Optional[Union[int, Tuple[int, int]]] = None, + cols: Optional[Union[int, Tuple[int, int]]] = None, + **kwargs, + ): """ Add an inset legend or outer legend along the edge of the axes. @@ -3622,7 +3650,18 @@ def legend(self, handles=None, labels=None, loc=None, location=None, **kwargs): if queue: self._register_guide("legend", (handles, labels), (loc, align), **kwargs) else: - return self._add_legend(handles, labels, loc=loc, align=align, **kwargs) + return self._add_legend( + handles, + labels, + loc=loc, + align=align, + span=span, + row=row, + col=col, + rows=rows, + cols=cols, + **kwargs, + ) @docstring._concatenate_inherited @docstring._snippet_manager diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 7c2cd454b..6b5b46c48 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -6,12 +6,13 @@ import inspect import os from numbers import Integral + from packaging import version try: - from typing import List, Optional, Union, Tuple + from typing import List, Optional, Tuple, Union except ImportError: - from typing_extensions import List, Optional, Union, Tuple + from typing_extensions import List, Optional, Tuple, Union import matplotlib.axes as maxes import matplotlib.figure as mfigure @@ -30,7 +31,6 @@ from . import constructor from . import gridspec as pgridspec from .config import rc, rc_matplotlib -from .internals import ic # noqa: F401 from .internals import ( _not_none, _pop_params, @@ -38,10 +38,11 @@ _translate_loc, context, docstring, + ic, # noqa: F401 labels, warnings, ) -from .utils import units, _get_subplot_layout, _Crawler +from .utils import _Crawler, units __all__ = [ "Figure", @@ -1385,12 +1386,12 @@ def _add_axes_panel( # Vertical panels: should use rows parameter, not cols if _not_none(cols, col) is not None and _not_none(rows, row) is None: raise ValueError( - f"For {side!r} colorbars (vertical), use 'rows=' or 'row=' " + f"For {side!r} panels (vertical), use 'rows=' or 'row=' " "to specify span, not 'cols=' or 'col='." ) if span is not None and _not_none(rows, row) is None: warnings._warn_ultraplot( - f"For {side!r} colorbars (vertical), prefer 'rows=' over 'span=' " + f"For {side!r} panels (vertical), prefer 'rows=' over 'span=' " "for clarity. Using 'span' as rows." ) span_override = _not_none(rows, row, span) @@ -1398,7 +1399,7 @@ def _add_axes_panel( # Horizontal panels: should use cols parameter, not rows if _not_none(rows, row) is not None and _not_none(cols, col, span) is None: raise ValueError( - f"For {side!r} colorbars (horizontal), use 'cols=' or 'span=' " + f"For {side!r} panels (horizontal), use 'cols=' or 'span=' " "to specify span, not 'rows=' or 'row='." ) span_override = _not_none(cols, col, span) @@ -2395,6 +2396,7 @@ def colorbar( if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)): try: ax_single = next(iter(ax)) + except (TypeError, StopIteration): ax_single = ax else: @@ -2474,8 +2476,31 @@ def legend( ax = kwargs.pop("ax", None) # Axes panel legend if ax is not None: - leg = ax.legend( - handles, labels, space=space, pad=pad, width=width, **kwargs + # Check if span parameters are provided + has_span = _not_none(span, row, col, rows, cols) is not None + + # Extract a single axes from array if span is provided + # Otherwise, pass the array as-is for normal legend behavior + if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)): + try: + ax_single = next(iter(ax)) + except (TypeError, StopIteration): + ax_single = ax + else: + ax_single = ax + leg = ax_single.legend( + handles, + labels, + loc=loc, + space=space, + pad=pad, + width=width, + span=span, + row=row, + col=col, + rows=rows, + cols=cols, + **kwargs, ) # Figure panel legend else: diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index dd23c5c18..48a40a678 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -318,3 +318,168 @@ def test_fill_between_included_in_legend(): labels = [t.get_text() for t in leg.get_texts()] assert "band" in labels uplt.close(fig) + + +def test_legend_span_bottom(): + """Test bottom legend with span parameter.""" + + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Legend below row 1, spanning columns 1-2 + leg = fig.legend(ax=axs[0, :], span=(1, 2), loc="bottom") + + # Verify legend was created + assert leg is not None + + +def test_legend_span_top(): + """Test top legend with span parameter.""" + + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Legend above row 2, spanning columns 2-3 + leg = fig.legend(ax=axs[1, :], cols=(2, 3), loc="top") + + assert leg is not None + + +def test_legend_span_right(): + """Test right legend with rows parameter.""" + + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + # Legend right of column 1, spanning rows 1-2 + leg = fig.legend(ax=axs[:, 0], rows=(1, 2), loc="right") + + assert leg is not None + + +def test_legend_span_left(): + """Test left legend with rows parameter.""" + + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + # Legend left of column 2, spanning rows 2-3 + leg = fig.legend(ax=axs[:, 1], rows=(2, 3), loc="left") + + assert leg is not None + + +def test_legend_span_validation_left_with_cols_error(): + """Test that LEFT legend raises error with cols parameter.""" + + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + with pytest.raises(ValueError, match="left.*vertical.*use 'rows='.*not 'cols='"): + fig.legend(ax=axs[0, 0], cols=(1, 2), loc="left") + + +def test_legend_span_validation_right_with_cols_error(): + """Test that RIGHT legend raises error with cols parameter.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + with pytest.raises(ValueError, match="right.*vertical.*use 'rows='.*not 'cols='"): + fig.legend(ax=axs[0, 0], cols=(1, 2), loc="right") + + +def test_legend_span_validation_top_with_rows_error(): + """Test that TOP legend raises error with rows parameter.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + with pytest.raises(ValueError, match="top.*horizontal.*use 'cols='.*not 'rows='"): + fig.legend(ax=axs[0, 0], rows=(1, 2), loc="top") + + +def test_legend_span_validation_bottom_with_rows_error(): + """Test that BOTTOM legend raises error with rows parameter.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + with pytest.raises( + ValueError, match="bottom.*horizontal.*use 'cols='.*not 'rows='" + ): + fig.legend(ax=axs[0, 0], rows=(1, 2), loc="bottom") + + +def test_legend_span_validation_left_with_span_warns(): + """Test that LEFT legend with span parameter issues warning.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + with pytest.warns(match="left.*vertical.*prefer 'rows='"): + leg = fig.legend(ax=axs[0, 0], span=(1, 2), loc="left") + assert leg is not None + + +def test_legend_span_validation_right_with_span_warns(): + """Test that RIGHT legend with span parameter issues warning.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + with pytest.warns(match="right.*vertical.*prefer 'rows='"): + leg = fig.legend(ax=axs[0, 0], span=(1, 2), loc="right") + assert leg is not None + + +def test_legend_array_without_span(): + """Test that legend on array without span preserves original behavior.""" + fig, axs = uplt.subplots(nrows=2, ncols=2) + axs[0, 0].plot([], [], label="test") + + # Should create legend for all axes in the array + leg = fig.legend(ax=axs[:], loc="right") + assert leg is not None + + +def test_legend_array_with_span(): + """Test that legend on array with span uses first axis + span extent.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Should use first axis position with span extent + leg = fig.legend(ax=axs[0, :], span=(1, 2), loc="bottom") + assert leg is not None + + +def test_legend_row_without_span(): + """Test that legend on row without span spans entire row.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Should span all 3 columns + leg = fig.legend(ax=axs[0, :], loc="bottom") + assert leg is not None + + +def test_legend_column_without_span(): + """Test that legend on column without span spans entire column.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + # Should span all 3 rows + leg = fig.legend(ax=axs[:, 0], loc="right") + assert leg is not None + + +def test_legend_multiple_sides_with_span(): + """Test multiple legends on different sides with span control.""" + fig, axs = uplt.subplots(nrows=3, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Create legends on all 4 sides with different spans + leg_bottom = fig.legend(ax=axs[0, 0], span=(1, 2), loc="bottom") + leg_top = fig.legend(ax=axs[1, 0], span=(2, 3), loc="top") + leg_right = fig.legend(ax=axs[0, 0], rows=(1, 2), loc="right") + leg_left = fig.legend(ax=axs[0, 1], rows=(2, 3), loc="left") + + assert leg_bottom is not None + assert leg_top is not None + assert leg_right is not None + assert leg_left is not None From bcc0aaaafa47bde417a1c47fe183daa27defc8f3 Mon Sep 17 00:00:00 2001 From: Gepcel Date: Tue, 9 Dec 2025 04:44:53 +0800 Subject: [PATCH 10/29] There's a typo about zerotrim in doc. (#420) It should be `formatter.zerotrim` not `format.zerotrim`. --- ultraplot/ticker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultraplot/ticker.py b/ultraplot/ticker.py index 4dfa7bfc7..ad1da9519 100644 --- a/ultraplot/ticker.py +++ b/ultraplot/ticker.py @@ -64,7 +64,7 @@ when `zerotrim` is ``True`` and ``2`` otherwise. """ _zerotrim_docstring = """ -zerotrim : bool, default: :rc:`format.zerotrim` +zerotrim : bool, default: :rc:`formatter.zerotrim` Whether to trim trailing decimal zeros. """ _auto_docstring = """ From f8fb44be19aa3cbe7ebe42bdc134e0b04ff64259 Mon Sep 17 00:00:00 2001 From: Gepcel Date: Wed, 10 Dec 2025 14:20:39 +0800 Subject: [PATCH 11/29] Fix references in documentation for clarity (#421) * Fix references in documentation for clarity Fix two unidenfined references in why.rst. 1. ug_apply_norm is a typo I think. 2. ug_mplrc. I'm not sure what it should be. Only by guess. * keep apply_norm --------- Co-authored-by: cvanelteren --- docs/why.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/why.rst b/docs/why.rst index 392a5616d..ab2f17649 100644 --- a/docs/why.rst +++ b/docs/why.rst @@ -499,7 +499,7 @@ like :func:`~ultraplot.axes.PlotAxes.pcolor` and :func:`~ultraplot.axes.PlotAxes plots. This can be disabled by setting :rcraw:`cmap.discrete` to ``False`` or by passing ``discrete=False`` to :class:`~ultraplot.axes.PlotAxes` commands. * The :class:`~ultraplot.colors.DivergingNorm` normalizer is perfect for data with a - :ref:`natural midpoint ` and offers both "fair" and "unfair" scaling. + :ref:`natural midpoint ` and offers both "fair" and "unfair" scaling. The :class:`~ultraplot.colors.SegmentedNorm` normalizer can generate uneven color gradations useful for :ref:`unusual data distributions `. * The :func:`~ultraplot.axes.PlotAxes.heatmap` command invokes @@ -882,7 +882,7 @@ Limitation ---------- Matplotlib :obj:`~matplotlib.rcParams` can be changed persistently by placing -`matplotlibrc` :ref:`ug_mplrc` files in the same directory as your python script. +ref:`matplotlibrc ` files in the same directory as your python script. But it can be difficult to design and store your own colormaps and color cycles for future use. It is also difficult to get matplotlib to use custom ``.ttf`` and ``.otf`` font files, which may be desirable when you are working on From 8e2973d5ea5e38a73f59d84876d049c10ed05213 Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Thu, 11 Dec 2025 16:05:51 +1000 Subject: [PATCH 12/29] fix links to apply_norm (#423) --- docs/2dplots.py | 48 ++++++++++++++++++++++++--------------- docs/colorbars_legends.py | 20 ++++++++++------ docs/why.rst | 2 +- 3 files changed, 44 insertions(+), 26 deletions(-) diff --git a/docs/2dplots.py b/docs/2dplots.py index 3f27b7b56..edc22e97c 100644 --- a/docs/2dplots.py +++ b/docs/2dplots.py @@ -77,9 +77,10 @@ # setting and the :ref:`user guide `). # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data state = np.random.RandomState(51423) x = y = np.array([-10, -5, 0, 5, 10]) @@ -110,9 +111,10 @@ axs[3].contourf(xedges, yedges, data) # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data cmap = "turku_r" state = np.random.RandomState(51423) @@ -184,9 +186,9 @@ # `~pint.UnitRegistry.setup_matplotlib` so that the axes become unit-aware. # %% -import xarray as xr import numpy as np import pandas as pd +import xarray as xr # DataArray state = np.random.RandomState(51423) @@ -261,13 +263,14 @@ # ``diverging=True``, ``cyclic=True``, or ``qualitative=True`` to any plotting # command. If the colormap type is not explicitly specified, `sequential` is # used with the default linear normalizer when data is strictly positive -# or negative, and `diverging` is used with the :ref:`diverging normalizer ` +# or negative, and `diverging` is used with the :ref:`diverging normalizer ` # when the data limits or colormap levels cross zero (see :ref:`below `). # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data N = 18 state = np.random.RandomState(51423) @@ -294,9 +297,10 @@ uplt.rc.reset() # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data N = 20 state = np.random.RandomState(51423) @@ -322,9 +326,10 @@ colorbar="b", ) -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data N = 20 state = np.random.RandomState(51423) @@ -347,7 +352,7 @@ # Special normalizers # ------------------- # -# UltraPlot includes two new :ref:`"continuous" normalizers `. The +# UltraPlot includes two new :ref:`"continuous" normalizers `. The # `~ultraplot.colors.SegmentedNorm` normalizer provides even color gradations with respect # to index for an arbitrary monotonically increasing or decreasing list of levels. This # is automatically applied if you pass unevenly spaced `levels` to a plotting command, @@ -372,9 +377,10 @@ # affect the interpretation of different datasets. # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data state = np.random.RandomState(51423) data = 11 ** (2 * state.rand(20, 20).cumsum(axis=0) / 7) @@ -395,9 +401,10 @@ ) ax.format(title=norm.title() + " normalizer") # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data state = np.random.RandomState(51423) data1 = (state.rand(20, 20) - 0.485).cumsum(axis=1).cumsum(axis=0) @@ -434,7 +441,7 @@ # commands (e.g., :func:`~ultraplot.axes.PlotAxes.contourf`, :func:`~ultraplot.axes.PlotAxes.pcolor`). # This is analogous to `matplotlib.colors.BoundaryNorm`, except # `~ultraplot.colors.DiscreteNorm` can be paired with arbitrary -# continuous normalizers specified by `norm` (see :ref:`above `). +# continuous normalizers specified by `norm` (see :ref:`above `). # Discrete color levels can help readers discern exact numeric values and # tend to reveal qualitative structure in the data. `~ultraplot.colors.DiscreteNorm` # also repairs the colormap end-colors by ensuring the following conditions are met: @@ -463,9 +470,10 @@ # the zero level (useful for single-color :func:`~ultraplot.axes.PlotAxes.contour` plots). # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data state = np.random.RandomState(51423) data = 10 + state.normal(0, 1, size=(33, 33)).cumsum(axis=0).cumsum(axis=1) @@ -485,9 +493,10 @@ axs[2].format(title="Imshow plot\ndiscrete=False (default)", yformatter="auto") # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data state = np.random.RandomState(51423) data = (20 * (state.rand(20, 20) - 0.4).cumsum(axis=0).cumsum(axis=1)) % 360 @@ -547,7 +556,7 @@ # the 2D :class:`~ultraplot.axes.PlotAxes` commands will apply the diverging colormap # :rc:`cmap.diverging` (rather than :rc:`cmap.sequential`) and the diverging # normalizer `~ultraplot.colors.DivergingNorm` (rather than :class:`~matplotlib.colors.Normalize` -# -- see :ref:`above `) if the following conditions are met: +# -- see :ref:`above `) if the following conditions are met: # # #. If discrete levels are enabled (see :ref:`above `) and the # level list includes at least 2 negative and 2 positive values. @@ -560,9 +569,10 @@ # setting :rcraw:`cmap.autodiverging` to ``False``. # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + N = 20 state = np.random.RandomState(51423) data = N * 2 + (state.rand(N, N) - 0.45).cumsum(axis=0).cumsum(axis=1) * 10 @@ -605,9 +615,10 @@ # command documentation for details. # %% -import ultraplot as uplt -import pandas as pd import numpy as np +import pandas as pd + +import ultraplot as uplt # Sample data state = np.random.RandomState(51423) @@ -663,10 +674,11 @@ # `~ultraplot.axes.CartesianAxes`. # %% -import ultraplot as uplt import numpy as np import pandas as pd +import ultraplot as uplt + # Covariance data state = np.random.RandomState(51423) data = state.normal(size=(10, 10)).cumsum(axis=0) diff --git a/docs/colorbars_legends.py b/docs/colorbars_legends.py index 2b2d58bca..10a4099c8 100644 --- a/docs/colorbars_legends.py +++ b/docs/colorbars_legends.py @@ -78,9 +78,10 @@ # complex arrangements of subplots, colorbars, and legends. # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + state = np.random.RandomState(51423) fig = uplt.figure(share=False, refwidth=2.3) @@ -183,9 +184,10 @@ ) # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + N = 10 state = np.random.RandomState(51423) fig, axs = uplt.subplots( @@ -232,9 +234,10 @@ # and the tight layout padding can be controlled with the `pad` keyword. # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + state = np.random.RandomState(51423) fig, axs = uplt.subplots(ncols=3, nrows=3, refwidth=1.4) for ax in axs: @@ -254,9 +257,10 @@ fig.colorbar(m, label="colorbar with length <1", ticks=0.1, loc="r", length=0.7) # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + state = np.random.RandomState(51423) fig, axs = uplt.subplots( ncols=2, nrows=2, order="F", refwidth=1.7, wspace=2.5, share=False @@ -299,7 +303,7 @@ # will build the required `~matplotlib.cm.ScalarMappable` on-the-fly. Lists # of :class:`~matplotlib.artist.Artists`\ s are used when you use the `colorbar` # keyword with :ref:`1D commands ` like :func:`~ultraplot.axes.PlotAxes.plot`. -# * The associated :ref:`colormap normalizer ` can be specified with the +# * The associated :ref:`colormap normalizer ` can be specified with the # `vmin`, `vmax`, `norm`, and `norm_kw` keywords. The `~ultraplot.colors.DiscreteNorm` # levels can be specified with `values`, or UltraPlot will infer them from the # :class:`~matplotlib.artist.Artist` labels (non-numeric labels will be applied to @@ -332,9 +336,10 @@ # See :func:`~ultraplot.axes.Axes.colorbar` for details. # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + fig = uplt.figure(share=False, refwidth=2) # Colorbars from lines @@ -427,9 +432,10 @@ # (or use the `handle_kw` keyword). See `ultraplot.axes.Axes.legend` for details. # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + uplt.rc.cycle = "538" fig, axs = uplt.subplots(ncols=2, span=False, share="labels", refwidth=2.3) labels = ["a", "bb", "ccc", "dddd", "eeeee"] diff --git a/docs/why.rst b/docs/why.rst index ab2f17649..74fc644c4 100644 --- a/docs/why.rst +++ b/docs/why.rst @@ -501,7 +501,7 @@ like :func:`~ultraplot.axes.PlotAxes.pcolor` and :func:`~ultraplot.axes.PlotAxes * The :class:`~ultraplot.colors.DivergingNorm` normalizer is perfect for data with a :ref:`natural midpoint ` and offers both "fair" and "unfair" scaling. The :class:`~ultraplot.colors.SegmentedNorm` normalizer can generate - uneven color gradations useful for :ref:`unusual data distributions `. + uneven color gradations useful for :ref:`unusual data distributions `. * The :func:`~ultraplot.axes.PlotAxes.heatmap` command invokes :func:`~ultraplot.axes.PlotAxes.pcolormesh` then applies an `equal axes apect ratio `__, From abcfb190ec56670350d7dca0a46da2068236be1a Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Fri, 12 Dec 2025 20:27:44 +1000 Subject: [PATCH 13/29] [Feature] add lon lat labelrotation (#426) * add label rotation for geo * add unittests for labelrotation * black formatting * more tests to increase coverage --- ultraplot/axes/geo.py | 43 ++- ultraplot/tests/test_geographic.py | 450 ++++++++++++++++++++++++++++- 2 files changed, 486 insertions(+), 7 deletions(-) diff --git a/ultraplot/axes/geo.py b/ultraplot/axes/geo.py index 15c5f9a43..7ed8efad6 100644 --- a/ultraplot/axes/geo.py +++ b/ultraplot/axes/geo.py @@ -24,19 +24,18 @@ from .. import constructor from .. import proj as pproj +from .. import ticker as pticker from ..config import rc -from ..internals import ic # noqa: F401 from ..internals import ( _not_none, _pop_rc, _version_cartopy, docstring, + ic, # noqa: F401 warnings, ) -from .. import ticker as pticker from ..utils import units -from . import plot -from . import shared +from . import plot, shared try: import cartopy.crs as ccrs @@ -148,6 +147,15 @@ *For cartopy axes only.* Whether to rotate non-inline gridline labels so that they automatically follow the map boundary curvature. +labelrotation : float, optional + The rotation angle in degrees for both longitude and latitude tick labels. + Use `lonlabelrotation` and `latlabelrotation` to set them separately. +lonlabelrotation : float, optional + The rotation angle in degrees for longitude tick labels. + Works for both cartopy and basemap backends. +latlabelrotation : float, optional + The rotation angle in degrees for latitude tick labels. + Works for both cartopy and basemap backends. labelpad : unit-spec, default: :rc:`grid.labelpad` *For cartopy axes only.* The padding between non-inline gridline labels and the map boundary. @@ -850,6 +858,9 @@ def format( latlabels=None, lonlabels=None, rotatelabels=None, + labelrotation=None, + lonlabelrotation=None, + latlabelrotation=None, loninline=None, latinline=None, inlinelabels=None, @@ -996,6 +1007,8 @@ def format( rotatelabels = _not_none( rotatelabels, rc.find("grid.rotatelabels", context=True) ) # noqa: E501 + lonlabelrotation = _not_none(lonlabelrotation, labelrotation) + latlabelrotation = _not_none(latlabelrotation, labelrotation) labelpad = _not_none(labelpad, rc.find("grid.labelpad", context=True)) dms = _not_none(dms, rc.find("grid.dmslabels", context=True)) nsteps = _not_none(nsteps, rc.find("grid.nsteps", context=True)) @@ -1028,6 +1041,8 @@ def format( loninline=loninline, latinline=latinline, rotatelabels=rotatelabels, + lonlabelrotation=lonlabelrotation, + latlabelrotation=latlabelrotation, labelpad=labelpad, nsteps=nsteps, ) @@ -1690,6 +1705,8 @@ def _update_major_gridlines( latinline=None, labelpad=None, rotatelabels=None, + lonlabelrotation=None, + latlabelrotation=None, nsteps=None, ): """ @@ -1729,6 +1746,10 @@ def _update_major_gridlines( gl.y_inline = bool(latinline) if rotatelabels is not None: gl.rotate_labels = bool(rotatelabels) # ignored in cartopy < 0.18 + if lonlabelrotation is not None: + gl.xlabel_style["rotation"] = lonlabelrotation + if latlabelrotation is not None: + gl.ylabel_style["rotation"] = latlabelrotation if latinline is not None or loninline is not None: lon, lat = loninline, latinline b = True if lon and lat else "x" if lon else "y" if lat else None @@ -2108,17 +2129,20 @@ def _update_gridlines( latgrid=None, lonarray=None, latarray=None, + lonlabelrotation=None, + latlabelrotation=None, ): """ Apply changes to the basemap axes. """ latmax = self._lataxis.get_latmax() - for axis, name, grid, array, method in zip( + for axis, name, grid, array, method, rotation in zip( ("x", "y"), ("lon", "lat"), (longrid, latgrid), (lonarray, latarray), ("drawmeridians", "drawparallels"), + (lonlabelrotation, latlabelrotation), ): # Correct lonarray and latarray label toggles by changing from lrbt to lrtb. # Then update the cahced toggle array. This lets us change gridline locs @@ -2173,6 +2197,9 @@ def _update_gridlines( for obj in self._iter_gridlines(objs): if isinstance(obj, mtext.Text): obj.update(kwtext) + # Apply rotation if specified + if rotation is not None: + obj.set_rotation(rotation) else: obj.update(kwlines) @@ -2191,6 +2218,8 @@ def _update_major_gridlines( loninline=None, latinline=None, rotatelabels=None, + lonlabelrotation=None, + latlabelrotation=None, labelpad=None, nsteps=None, ): @@ -2204,6 +2233,8 @@ def _update_major_gridlines( latgrid=latgrid, lonarray=lonarray, latarray=latarray, + lonlabelrotation=lonlabelrotation, + latlabelrotation=latlabelrotation, ) sides = {} for side, lonon, laton in zip( @@ -2226,6 +2257,8 @@ def _update_minor_gridlines(self, longrid=None, latgrid=None, nsteps=None): latgrid=latgrid, lonarray=array, latarray=array, + lonlabelrotation=None, + latlabelrotation=None, ) # Set isDefault_majloc, etc. to True for both axes # NOTE: This cannot be done inside _update_gridlines or minor gridlines diff --git a/ultraplot/tests/test_geographic.py b/ultraplot/tests/test_geographic.py index 30911c176..62e0f8940 100644 --- a/ultraplot/tests/test_geographic.py +++ b/ultraplot/tests/test_geographic.py @@ -1,7 +1,11 @@ -import ultraplot as uplt, numpy as np, warnings -import pytest +import warnings from unittest import mock +import numpy as np +import pytest + +import ultraplot as uplt + @pytest.mark.mpl_image_compare def test_geographic_single_projection(): @@ -1010,3 +1014,445 @@ def test_grid_indexing_formatting(rng): axs[-1, :].format(lonlabels=True) axs[:, 0].format(latlabels=True) return fig + + +@pytest.mark.parametrize( + "backend", + [ + "cartopy", + "basemap", + ], +) +def test_label_rotation(backend): + """ + Test label rotation parameters for both Cartopy and Basemap backends. + Tests labelrotation, lonlabelrotation, and latlabelrotation parameters. + """ + fig, axs = uplt.subplots(ncols=2, proj="cyl", backend=backend, share=0) + + # Test 1: labelrotation applies to both axes + axs[0].format( + title="Both rotated 45°", + lonlabels="b", + latlabels="l", + labelrotation=45, + lonlines=30, + latlines=30, + ) + + # Test 2: Different rotations for lon and lat + axs[1].format( + title="Lon: 90°, Lat: 0°", + lonlabels="b", + latlabels="l", + lonlabelrotation=90, + latlabelrotation=0, + lonlines=30, + latlines=30, + ) + + # Verify that rotation was applied based on actual backend + if axs[0]._name == "cartopy": + # For Cartopy, check gridliner xlabel_style and ylabel_style + gl0 = axs[0].gridlines_major + assert gl0.xlabel_style.get("rotation") == 45 + assert gl0.ylabel_style.get("rotation") == 45 + + gl1 = axs[1].gridlines_major + assert gl1.xlabel_style.get("rotation") == 90 + assert gl1.ylabel_style.get("rotation") == 0 + + else: # basemap + # For Basemap, check Text object rotation + from matplotlib import text as mtext + + def get_text_rotations(gridlines_dict): + """Extract rotation angles from Text objects in gridlines.""" + rotations = [] + for line_dict in gridlines_dict.values(): + for obj_list in line_dict: + for obj in obj_list: + if isinstance(obj, mtext.Text): + rotations.append(obj.get_rotation()) + return rotations + + # Check first axes (both 45°) + lonlines_0, latlines_0 = axs[0].gridlines_major + lon_rotations_0 = get_text_rotations(lonlines_0) + lat_rotations_0 = get_text_rotations(latlines_0) + if lon_rotations_0: # Only check if labels exist + assert all(r == 45 for r in lon_rotations_0) + if lat_rotations_0: + assert all(r == 45 for r in lat_rotations_0) + + # Check second axes (lon: 90°, lat: 0°) + lonlines_1, latlines_1 = axs[1].gridlines_major + lon_rotations_1 = get_text_rotations(lonlines_1) + lat_rotations_1 = get_text_rotations(latlines_1) + if lon_rotations_1: + assert all(r == 90 for r in lon_rotations_1) + if lat_rotations_1: + assert all(r == 0 for r in lat_rotations_1) + + uplt.close(fig) + + +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_label_rotation_precedence(backend): + """ + Test that specific rotation parameters take precedence over general labelrotation. + """ + fig, ax = uplt.subplots(proj="cyl", backend=backend) + + # lonlabelrotation should override labelrotation for lon axis + # latlabelrotation not specified, so should use labelrotation + ax.format( + lonlabels="b", + latlabels="l", + labelrotation=30, + lonlabelrotation=60, # This should override for lon + lonlines=30, + latlines=30, + ) + + if ax[0]._name == "cartopy": + gl = ax[0].gridlines_major + assert gl.xlabel_style.get("rotation") == 60 # Override value + assert gl.ylabel_style.get("rotation") == 30 # Fallback value + else: # basemap + from matplotlib import text as mtext + + def get_text_rotations(gridlines_dict): + rotations = [] + for line_dict in gridlines_dict.values(): + for obj_list in line_dict: + for obj in obj_list: + if isinstance(obj, mtext.Text): + rotations.append(obj.get_rotation()) + return rotations + + lonlines, latlines = ax[0].gridlines_major + lon_rotations = get_text_rotations(lonlines) + lat_rotations = get_text_rotations(latlines) + + if lon_rotations: + assert all(r == 60 for r in lon_rotations) + if lat_rotations: + assert all(r == 30 for r in lat_rotations) + + uplt.close(fig) + + +def test_label_rotation_backward_compatibility(): + """ + Test that existing code without rotation parameters still works. + """ + fig, ax = uplt.subplots(proj="cyl") + + # Should work without any rotation parameters + ax.format( + lonlabels="b", + latlabels="l", + lonlines=30, + latlines=30, + ) + + # Verify no rotation was applied (should be default or None) + gl = ax[0]._gridlines_major + # If rotation key doesn't exist or is None/0, that's expected + lon_rotation = gl.xlabel_style.get("rotation") + lat_rotation = gl.ylabel_style.get("rotation") + + # Default rotation should be None or 0 (no rotation) + assert lon_rotation is None or lon_rotation == 0 + assert lat_rotation is None or lat_rotation == 0 + + uplt.close(fig) + + +@pytest.mark.parametrize("rotation_angle", [0, 45, 90, -30, 180]) +def test_label_rotation_angles(rotation_angle): + """ + Test various rotation angles to ensure they're applied correctly. + """ + fig, ax = uplt.subplots(proj="cyl") + + ax.format( + lonlabels="b", + latlabels="l", + labelrotation=rotation_angle, + lonlines=60, + latlines=30, + ) + + gl = ax[0]._gridlines_major + assert gl.xlabel_style.get("rotation") == rotation_angle + assert gl.ylabel_style.get("rotation") == rotation_angle + + uplt.close(fig) + + +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_label_rotation_only_lon(backend): + """ + Test rotation applied only to longitude labels. + """ + fig, ax = uplt.subplots(proj="cyl", backend=backend) + + # Only rotate longitude labels + ax.format( + lonlabels="b", + latlabels="l", + lonlabelrotation=45, + lonlines=30, + latlines=30, + ) + + if ax[0]._name == "cartopy": + gl = ax[0].gridlines_major + assert gl.xlabel_style.get("rotation") == 45 + assert gl.ylabel_style.get("rotation") is None + else: # basemap + from matplotlib import text as mtext + + def get_text_rotations(gridlines_dict): + rotations = [] + for line_dict in gridlines_dict.values(): + for obj_list in line_dict: + for obj in obj_list: + if isinstance(obj, mtext.Text): + rotations.append(obj.get_rotation()) + return rotations + + lonlines, latlines = ax[0].gridlines_major + lon_rotations = get_text_rotations(lonlines) + lat_rotations = get_text_rotations(latlines) + + if lon_rotations: + assert all(r == 45 for r in lon_rotations) + if lat_rotations: + # Default rotation should be 0 + assert all(r == 0 for r in lat_rotations) + + uplt.close(fig) + + +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_label_rotation_only_lat(backend): + """ + Test rotation applied only to latitude labels. + """ + fig, ax = uplt.subplots(proj="cyl", backend=backend) + + # Only rotate latitude labels + ax.format( + lonlabels="b", + latlabels="l", + latlabelrotation=60, + lonlines=30, + latlines=30, + ) + + if ax[0]._name == "cartopy": + gl = ax[0].gridlines_major + assert gl.xlabel_style.get("rotation") is None + assert gl.ylabel_style.get("rotation") == 60 + else: # basemap + from matplotlib import text as mtext + + def get_text_rotations(gridlines_dict): + rotations = [] + for line_dict in gridlines_dict.values(): + for obj_list in line_dict: + for obj in obj_list: + if isinstance(obj, mtext.Text): + rotations.append(obj.get_rotation()) + return rotations + + lonlines, latlines = ax[0].gridlines_major + lon_rotations = get_text_rotations(lonlines) + lat_rotations = get_text_rotations(latlines) + + if lon_rotations: + # Default rotation should be 0 + assert all(r == 0 for r in lon_rotations) + if lat_rotations: + assert all(r == 60 for r in lat_rotations) + + uplt.close(fig) + + +def test_label_rotation_with_different_projections(): + """ + Test label rotation with various projections. + """ + projections = ["cyl", "robin", "moll"] + + for proj in projections: + fig, ax = uplt.subplots(proj=proj) + + ax.format( + lonlabels="b", + latlabels="l", + labelrotation=30, + lonlines=60, + latlines=30, + ) + + # For cartopy, verify rotation was set + if ax[0]._name == "cartopy": + gl = ax[0]._gridlines_major + if gl is not None: # Some projections might not support gridlines + assert gl.xlabel_style.get("rotation") == 30 + assert gl.ylabel_style.get("rotation") == 30 + + uplt.close(fig) + + +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_label_rotation_with_format_options(backend): + """ + Test label rotation combined with other format options. + """ + fig, ax = uplt.subplots(proj="cyl", backend=backend) + + # Combine rotation with other formatting + ax.format( + lonlabels="b", + latlabels="l", + lonlabelrotation=45, + latlabelrotation=30, + lonlines=30, + latlines=30, + coast=True, + land=True, + ) + + # Verify rotation was applied + if ax[0]._name == "cartopy": + gl = ax[0].gridlines_major + assert gl.xlabel_style.get("rotation") == 45 + assert gl.ylabel_style.get("rotation") == 30 + else: # basemap + from matplotlib import text as mtext + + def get_text_rotations(gridlines_dict): + rotations = [] + for line_dict in gridlines_dict.values(): + for obj_list in line_dict: + for obj in obj_list: + if isinstance(obj, mtext.Text): + rotations.append(obj.get_rotation()) + return rotations + + lonlines, latlines = ax[0].gridlines_major + lon_rotations = get_text_rotations(lonlines) + lat_rotations = get_text_rotations(latlines) + + if lon_rotations: + assert all(r == 45 for r in lon_rotations) + if lat_rotations: + assert all(r == 30 for r in lat_rotations) + + uplt.close(fig) + + +def test_label_rotation_none_values(): + """ + Test that None values for rotation work correctly. + """ + fig, ax = uplt.subplots(proj="cyl") + + # Explicitly set None for rotations + ax.format( + lonlabels="b", + latlabels="l", + lonlabelrotation=None, + latlabelrotation=None, + lonlines=30, + latlines=30, + ) + + gl = ax[0]._gridlines_major + # None should result in no rotation being set + lon_rotation = gl.xlabel_style.get("rotation") + lat_rotation = gl.ylabel_style.get("rotation") + + assert lon_rotation is None or lon_rotation == 0 + assert lat_rotation is None or lat_rotation == 0 + + uplt.close(fig) + + +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_label_rotation_update_existing(backend): + """ + Test updating rotation on axes that already have labels. + """ + fig, ax = uplt.subplots(proj="cyl", backend=backend) + + # First format without rotation + ax.format( + lonlabels="b", + latlabels="l", + lonlines=30, + latlines=30, + ) + + # Then update with rotation + ax.format( + lonlabelrotation=45, + latlabelrotation=90, + ) + + # Verify rotation was applied + if ax[0]._name == "cartopy": + gl = ax[0].gridlines_major + assert gl.xlabel_style.get("rotation") == 45 + assert gl.ylabel_style.get("rotation") == 90 + else: # basemap + from matplotlib import text as mtext + + def get_text_rotations(gridlines_dict): + rotations = [] + for line_dict in gridlines_dict.values(): + for obj_list in line_dict: + for obj in obj_list: + if isinstance(obj, mtext.Text): + rotations.append(obj.get_rotation()) + return rotations + + lonlines, latlines = ax[0].gridlines_major + lon_rotations = get_text_rotations(lonlines) + lat_rotations = get_text_rotations(latlines) + + if lon_rotations: + assert all(r == 45 for r in lon_rotations) + if lat_rotations: + assert all(r == 90 for r in lat_rotations) + + uplt.close(fig) + + +def test_label_rotation_negative_angles(): + """ + Test various negative rotation angles. + """ + fig, ax = uplt.subplots(proj="cyl") + + negative_angles = [-15, -45, -90, -120, -180] + + for angle in negative_angles: + ax.format( + lonlabels="b", + latlabels="l", + labelrotation=angle, + lonlines=60, + latlines=30, + ) + + gl = ax[0]._gridlines_major + assert gl.xlabel_style.get("rotation") == angle + assert gl.ylabel_style.get("rotation") == angle + + uplt.close(fig) From edd603904c4360245826becd5e67b2e53adba480 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Sat, 13 Dec 2025 23:01:58 +1000 Subject: [PATCH 14/29] fix boundary check for ticks --- ultraplot/axes/geo.py | 21 ++++++-- ultraplot/tests/test_geographic.py | 87 ++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 3 deletions(-) diff --git a/ultraplot/axes/geo.py b/ultraplot/axes/geo.py index 7ed8efad6..b3979f425 100644 --- a/ultraplot/axes/geo.py +++ b/ultraplot/axes/geo.py @@ -1592,12 +1592,18 @@ def _update_extent(self, lonlim=None, latlim=None, boundinglat=None): lonlim[0] = lon0 - 180 if lonlim[1] is None: lonlim[1] = lon0 + 180 - lonlim[0] += eps + # Expand limits slightly to ensure boundary labels are included + # NOTE: We expand symmetrically (subtract from min, add to max) rather + # than just shifting to avoid excluding boundary gridlines + lonlim[0] -= eps + lonlim[1] += eps latlim = list(latlim) if latlim[0] is None: latlim[0] = -90 if latlim[1] is None: latlim[1] = 90 + latlim[0] -= eps + latlim[1] += eps extent = lonlim + latlim self.set_extent(extent, crs=ccrs.PlateCarree()) @@ -1678,9 +1684,18 @@ def _update_gridlines( # NOTE: This will re-apply existing gridline locations if unchanged. if nsteps is not None: gl.n_steps = nsteps - latmax = self._lataxis.get_latmax() + # Set xlim and ylim for cartopy >= 0.19 to control which labels are displayed + # Use the actual view intervals so that labels at the extent boundaries are shown + # NOTE: Expand limits slightly because cartopy uses strict inequality for filtering + # labels (e.g., xlim[0] < lon < xlim[1]), which would exclude boundary labels if _version_cartopy >= "0.19": - gl.ylim = (-latmax, latmax) + eps = 1.0 # epsilon to include boundary labels (cartopy filters strictly) + loninterval = self._lonaxis.get_view_interval() + latinterval = self._lataxis.get_view_interval() + if loninterval is not None: + gl.xlim = (loninterval[0] - eps, loninterval[1] + eps) + if latinterval is not None: + gl.ylim = (latinterval[0] - eps, latinterval[1] + eps) longrid = rc._get_gridline_bool(longrid, axis="x", which=which, native=False) if longrid is not None: gl.xlines = longrid diff --git a/ultraplot/tests/test_geographic.py b/ultraplot/tests/test_geographic.py index 62e0f8940..9ab28fd76 100644 --- a/ultraplot/tests/test_geographic.py +++ b/ultraplot/tests/test_geographic.py @@ -1456,3 +1456,90 @@ def test_label_rotation_negative_angles(): assert gl.ylabel_style.get("rotation") == angle uplt.close(fig) + + +def _check_boundary_labels(ax, expected_lon_labels, expected_lat_labels): + """Helper to check that boundary labels are created and visible.""" + gl = ax._gridlines_major + assert gl is not None, "Gridliner should exist" + + # Check xlim/ylim are expanded beyond actual limits + assert hasattr(gl, "xlim") and hasattr(gl, "ylim") + + # Check longitude labels + lon_texts = [ + label.get_text() for label in gl.bottom_label_artists if label.get_visible() + ] + assert len(gl.bottom_label_artists) == len(expected_lon_labels), ( + f"Should have {len(expected_lon_labels)} longitude labels, " + f"got {len(gl.bottom_label_artists)}" + ) + for expected in expected_lon_labels: + assert any( + expected in text for text in lon_texts + ), f"{expected} label should be visible, got: {lon_texts}" + + # Check latitude labels + lat_texts = [ + label.get_text() for label in gl.left_label_artists if label.get_visible() + ] + assert len(gl.left_label_artists) == len(expected_lat_labels), ( + f"Should have {len(expected_lat_labels)} latitude labels, " + f"got {len(gl.left_label_artists)}" + ) + for expected in expected_lat_labels: + assert any( + expected in text for text in lat_texts + ), f"{expected} label should be visible, got: {lat_texts}" + + +def test_boundary_labels_positive_longitude(): + """ + Test that boundary labels are visible with positive longitude limits. + + This tests the fix for the issue where setting lonlim/latlim would hide + the outermost labels because cartopy's gridliner was filtering them out. + """ + fig, ax = uplt.subplots(proj="pcarree") + ax.format( + lonlim=(120, 130), + latlim=(10, 20), + lonlocator=[120, 125, 130], + latlocator=[10, 15, 20], + labels=True, + grid=False, + ) + fig.canvas.draw() + _check_boundary_labels(ax[0], ["120°E", "125°E", "130°E"], ["10°N", "15°N", "20°N"]) + uplt.close(fig) + + +def test_boundary_labels_negative_longitude(): + """ + Test that boundary labels are visible with negative longitude limits. + """ + fig, ax = uplt.subplots(proj="pcarree") + ax.format( + lonlim=(-120, -60), + latlim=(20, 50), + lonlocator=[-120, -90, -60], + latlocator=[20, 35, 50], + labels=True, + grid=False, + ) + fig.canvas.draw() + _check_boundary_labels(ax[0], ["120°W", "90°W", "60°W"], ["20°N", "35°N", "50°N"]) + uplt.close(fig) + + +def test_boundary_labels_view_intervals(): + """ + Test that view intervals match requested limits after setting lonlim/latlim. + """ + fig, ax = uplt.subplots(proj="pcarree") + ax.format(lonlim=(0, 60), latlim=(-20, 40), lonlines=30, latlines=20, labels=True) + loninterval = ax[0]._lonaxis.get_view_interval() + latinterval = ax[0]._lataxis.get_view_interval() + assert abs(loninterval[0] - 0) < 1 and abs(loninterval[1] - 60) < 1 + assert abs(latinterval[0] - (-20)) < 1 and abs(latinterval[1] - 40) < 1 + uplt.close(fig) From 1002adff9e0ef46312debdf7b223c6b5de1a0898 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Sat, 13 Dec 2025 23:04:08 +1000 Subject: [PATCH 15/29] Revert "fix boundary check for ticks" This reverts commit edd603904c4360245826becd5e67b2e53adba480. --- ultraplot/axes/geo.py | 21 ++------ ultraplot/tests/test_geographic.py | 87 ------------------------------ 2 files changed, 3 insertions(+), 105 deletions(-) diff --git a/ultraplot/axes/geo.py b/ultraplot/axes/geo.py index b3979f425..7ed8efad6 100644 --- a/ultraplot/axes/geo.py +++ b/ultraplot/axes/geo.py @@ -1592,18 +1592,12 @@ def _update_extent(self, lonlim=None, latlim=None, boundinglat=None): lonlim[0] = lon0 - 180 if lonlim[1] is None: lonlim[1] = lon0 + 180 - # Expand limits slightly to ensure boundary labels are included - # NOTE: We expand symmetrically (subtract from min, add to max) rather - # than just shifting to avoid excluding boundary gridlines - lonlim[0] -= eps - lonlim[1] += eps + lonlim[0] += eps latlim = list(latlim) if latlim[0] is None: latlim[0] = -90 if latlim[1] is None: latlim[1] = 90 - latlim[0] -= eps - latlim[1] += eps extent = lonlim + latlim self.set_extent(extent, crs=ccrs.PlateCarree()) @@ -1684,18 +1678,9 @@ def _update_gridlines( # NOTE: This will re-apply existing gridline locations if unchanged. if nsteps is not None: gl.n_steps = nsteps - # Set xlim and ylim for cartopy >= 0.19 to control which labels are displayed - # Use the actual view intervals so that labels at the extent boundaries are shown - # NOTE: Expand limits slightly because cartopy uses strict inequality for filtering - # labels (e.g., xlim[0] < lon < xlim[1]), which would exclude boundary labels + latmax = self._lataxis.get_latmax() if _version_cartopy >= "0.19": - eps = 1.0 # epsilon to include boundary labels (cartopy filters strictly) - loninterval = self._lonaxis.get_view_interval() - latinterval = self._lataxis.get_view_interval() - if loninterval is not None: - gl.xlim = (loninterval[0] - eps, loninterval[1] + eps) - if latinterval is not None: - gl.ylim = (latinterval[0] - eps, latinterval[1] + eps) + gl.ylim = (-latmax, latmax) longrid = rc._get_gridline_bool(longrid, axis="x", which=which, native=False) if longrid is not None: gl.xlines = longrid diff --git a/ultraplot/tests/test_geographic.py b/ultraplot/tests/test_geographic.py index 9ab28fd76..62e0f8940 100644 --- a/ultraplot/tests/test_geographic.py +++ b/ultraplot/tests/test_geographic.py @@ -1456,90 +1456,3 @@ def test_label_rotation_negative_angles(): assert gl.ylabel_style.get("rotation") == angle uplt.close(fig) - - -def _check_boundary_labels(ax, expected_lon_labels, expected_lat_labels): - """Helper to check that boundary labels are created and visible.""" - gl = ax._gridlines_major - assert gl is not None, "Gridliner should exist" - - # Check xlim/ylim are expanded beyond actual limits - assert hasattr(gl, "xlim") and hasattr(gl, "ylim") - - # Check longitude labels - lon_texts = [ - label.get_text() for label in gl.bottom_label_artists if label.get_visible() - ] - assert len(gl.bottom_label_artists) == len(expected_lon_labels), ( - f"Should have {len(expected_lon_labels)} longitude labels, " - f"got {len(gl.bottom_label_artists)}" - ) - for expected in expected_lon_labels: - assert any( - expected in text for text in lon_texts - ), f"{expected} label should be visible, got: {lon_texts}" - - # Check latitude labels - lat_texts = [ - label.get_text() for label in gl.left_label_artists if label.get_visible() - ] - assert len(gl.left_label_artists) == len(expected_lat_labels), ( - f"Should have {len(expected_lat_labels)} latitude labels, " - f"got {len(gl.left_label_artists)}" - ) - for expected in expected_lat_labels: - assert any( - expected in text for text in lat_texts - ), f"{expected} label should be visible, got: {lat_texts}" - - -def test_boundary_labels_positive_longitude(): - """ - Test that boundary labels are visible with positive longitude limits. - - This tests the fix for the issue where setting lonlim/latlim would hide - the outermost labels because cartopy's gridliner was filtering them out. - """ - fig, ax = uplt.subplots(proj="pcarree") - ax.format( - lonlim=(120, 130), - latlim=(10, 20), - lonlocator=[120, 125, 130], - latlocator=[10, 15, 20], - labels=True, - grid=False, - ) - fig.canvas.draw() - _check_boundary_labels(ax[0], ["120°E", "125°E", "130°E"], ["10°N", "15°N", "20°N"]) - uplt.close(fig) - - -def test_boundary_labels_negative_longitude(): - """ - Test that boundary labels are visible with negative longitude limits. - """ - fig, ax = uplt.subplots(proj="pcarree") - ax.format( - lonlim=(-120, -60), - latlim=(20, 50), - lonlocator=[-120, -90, -60], - latlocator=[20, 35, 50], - labels=True, - grid=False, - ) - fig.canvas.draw() - _check_boundary_labels(ax[0], ["120°W", "90°W", "60°W"], ["20°N", "35°N", "50°N"]) - uplt.close(fig) - - -def test_boundary_labels_view_intervals(): - """ - Test that view intervals match requested limits after setting lonlim/latlim. - """ - fig, ax = uplt.subplots(proj="pcarree") - ax.format(lonlim=(0, 60), latlim=(-20, 40), lonlines=30, latlines=20, labels=True) - loninterval = ax[0]._lonaxis.get_view_interval() - latinterval = ax[0]._lataxis.get_view_interval() - assert abs(loninterval[0] - 0) < 1 and abs(loninterval[1] - 60) < 1 - assert abs(latinterval[0] - (-20)) < 1 and abs(latinterval[1] - 40) < 1 - uplt.close(fig) From d3f8342486fba675af938614466356eb37fd90f3 Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Mon, 15 Dec 2025 12:27:02 +1000 Subject: [PATCH 16/29] Fix: Boundary labels now visible when setting lonlim/latlim (#429) * fix boundary check for ticks * fix boundary test * fix boundary test --- ultraplot/axes/geo.py | 24 +++++-- ultraplot/tests/test_geographic.py | 102 +++++++++++++++++++++++++++-- 2 files changed, 116 insertions(+), 10 deletions(-) diff --git a/ultraplot/axes/geo.py b/ultraplot/axes/geo.py index 7ed8efad6..baa4da58e 100644 --- a/ultraplot/axes/geo.py +++ b/ultraplot/axes/geo.py @@ -1559,7 +1559,8 @@ def _update_extent(self, lonlim=None, latlim=None, boundinglat=None): # WARNING: The set_extent method tries to set a *rectangle* between the *4* # (x, y) coordinate pairs (each corner), so something like (-180, 180, -90, 90) # will result in *line*, causing error! We correct this here. - eps = 1e-10 # bug with full -180, 180 range when lon_0 != 0 + eps_small = 1e-10 # bug with full -180, 180 range when lon_0 != 0 + eps_label = 0.5 # larger epsilon to ensure boundary labels are included lon0 = self._get_lon0() proj = type(self.projection).__name__ north = isinstance(self.projection, self._proj_north) @@ -1575,7 +1576,12 @@ def _update_extent(self, lonlim=None, latlim=None, boundinglat=None): if boundinglat is not None and boundinglat != self._boundinglat: lat0 = 90 if north else -90 lon0 = self._get_lon0() - extent = [lon0 - 180 + eps, lon0 + 180 - eps, boundinglat, lat0] + extent = [ + lon0 - 180 + eps_small, + lon0 + 180 - eps_small, + boundinglat, + lat0, + ] self.set_extent(extent, crs=ccrs.PlateCarree()) self._boundinglat = boundinglat @@ -1592,12 +1598,18 @@ def _update_extent(self, lonlim=None, latlim=None, boundinglat=None): lonlim[0] = lon0 - 180 if lonlim[1] is None: lonlim[1] = lon0 + 180 - lonlim[0] += eps + # Expand limits slightly to ensure boundary labels are included + # NOTE: We expand symmetrically (subtract from min, add to max) rather + # than just shifting to avoid excluding boundary gridlines + lonlim[0] -= eps_label + lonlim[1] += eps_label latlim = list(latlim) if latlim[0] is None: latlim[0] = -90 if latlim[1] is None: latlim[1] = 90 + latlim[0] -= eps_label + latlim[1] += eps_label extent = lonlim + latlim self.set_extent(extent, crs=ccrs.PlateCarree()) @@ -1678,9 +1690,9 @@ def _update_gridlines( # NOTE: This will re-apply existing gridline locations if unchanged. if nsteps is not None: gl.n_steps = nsteps - latmax = self._lataxis.get_latmax() - if _version_cartopy >= "0.19": - gl.ylim = (-latmax, latmax) + # Set xlim and ylim for cartopy >= 0.19 to control which labels are displayed + # NOTE: Don't set xlim/ylim here - let cartopy determine from the axes extent + # The extent expansion in _update_extent should be sufficient to include boundary labels longrid = rc._get_gridline_bool(longrid, axis="x", which=which, native=False) if longrid is not None: gl.xlines = longrid diff --git a/ultraplot/tests/test_geographic.py b/ultraplot/tests/test_geographic.py index 62e0f8940..94501fb37 100644 --- a/ultraplot/tests/test_geographic.py +++ b/ultraplot/tests/test_geographic.py @@ -460,7 +460,10 @@ def test_sharing_geo_limits(): after_lat = ax[1]._lataxis.get_view_interval() # We are sharing y which is the latitude axis - assert all([np.allclose(i, j) for i, j in zip(expectation["latlim"], after_lat)]) + # Account for small epsilon expansion in extent (0.5 degrees per side) + assert all( + [np.allclose(i, j, atol=1.0) for i, j in zip(expectation["latlim"], after_lat)] + ) # We are not sharing longitude yet assert all( [ @@ -474,7 +477,10 @@ def test_sharing_geo_limits(): after_lon = ax[1]._lonaxis.get_view_interval() assert all([not np.allclose(i, j) for i, j in zip(before_lon, after_lon)]) - assert all([np.allclose(i, j) for i, j in zip(after_lon, expectation["lonlim"])]) + # Account for small epsilon expansion in extent (0.5 degrees per side) + assert all( + [np.allclose(i, j, atol=1.0) for i, j in zip(after_lon, expectation["lonlim"])] + ) uplt.close(fig) @@ -949,8 +955,9 @@ def test_consistent_range(): lonview = np.array(a._lonaxis.get_view_interval()) latview = np.array(a._lataxis.get_view_interval()) - assert np.allclose(lonview, lonlim) - assert np.allclose(latview, latlim) + # Account for small epsilon expansion in extent (0.5 degrees per side) + assert np.allclose(lonview, lonlim, atol=1.0) + assert np.allclose(latview, latlim, atol=1.0) @pytest.mark.mpl_image_compare @@ -1456,3 +1463,90 @@ def test_label_rotation_negative_angles(): assert gl.ylabel_style.get("rotation") == angle uplt.close(fig) + + +def _check_boundary_labels(ax, expected_lon_labels, expected_lat_labels): + """Helper to check that boundary labels are created and visible.""" + gl = ax._gridlines_major + assert gl is not None, "Gridliner should exist" + + # Check xlim/ylim are expanded beyond actual limits + assert hasattr(gl, "xlim") and hasattr(gl, "ylim") + + # Check longitude labels + lon_texts = [ + label.get_text() for label in gl.bottom_label_artists if label.get_visible() + ] + assert len(gl.bottom_label_artists) == len(expected_lon_labels), ( + f"Should have {len(expected_lon_labels)} longitude labels, " + f"got {len(gl.bottom_label_artists)}" + ) + for expected in expected_lon_labels: + assert any( + expected in text for text in lon_texts + ), f"{expected} label should be visible, got: {lon_texts}" + + # Check latitude labels + lat_texts = [ + label.get_text() for label in gl.left_label_artists if label.get_visible() + ] + assert len(gl.left_label_artists) == len(expected_lat_labels), ( + f"Should have {len(expected_lat_labels)} latitude labels, " + f"got {len(gl.left_label_artists)}" + ) + for expected in expected_lat_labels: + assert any( + expected in text for text in lat_texts + ), f"{expected} label should be visible, got: {lat_texts}" + + +def test_boundary_labels_positive_longitude(): + """ + Test that boundary labels are visible with positive longitude limits. + + This tests the fix for the issue where setting lonlim/latlim would hide + the outermost labels because cartopy's gridliner was filtering them out. + """ + fig, ax = uplt.subplots(proj="pcarree") + ax.format( + lonlim=(120, 130), + latlim=(10, 20), + lonlocator=[120, 125, 130], + latlocator=[10, 15, 20], + labels=True, + grid=False, + ) + fig.canvas.draw() + _check_boundary_labels(ax[0], ["120°E", "125°E", "130°E"], ["10°N", "15°N", "20°N"]) + uplt.close(fig) + + +def test_boundary_labels_negative_longitude(): + """ + Test that boundary labels are visible with negative longitude limits. + """ + fig, ax = uplt.subplots(proj="pcarree") + ax.format( + lonlim=(-120, -60), + latlim=(20, 50), + lonlocator=[-120, -90, -60], + latlocator=[20, 35, 50], + labels=True, + grid=False, + ) + fig.canvas.draw() + _check_boundary_labels(ax[0], ["120°W", "90°W", "60°W"], ["20°N", "35°N", "50°N"]) + uplt.close(fig) + + +def test_boundary_labels_view_intervals(): + """ + Test that view intervals match requested limits after setting lonlim/latlim. + """ + fig, ax = uplt.subplots(proj="pcarree") + ax.format(lonlim=(0, 60), latlim=(-20, 40), lonlines=30, latlines=20, labels=True) + loninterval = ax[0]._lonaxis.get_view_interval() + latinterval = ax[0]._lataxis.get_view_interval() + assert abs(loninterval[0] - 0) < 1 and abs(loninterval[1] - 60) < 1 + assert abs(latinterval[0] - (-20)) < 1 and abs(latinterval[1] - 40) < 1 + uplt.close(fig) From 80e12ee571fe53c1f1695b35cc555af4aaec88ec Mon Sep 17 00:00:00 2001 From: Erik Holmgren <56769803+Holmgren825@users.noreply.github.com> Date: Tue, 16 Dec 2025 21:23:51 +0100 Subject: [PATCH 17/29] Add Copernicus Publications figure standard widths. (#433) --- ultraplot/figure.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 6b5b46c48..a0f74d201 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -62,6 +62,8 @@ "ams2": 4.5, "ams3": 5.5, "ams4": 6.5, + "cop1": "8.3cm", + "cop2": "12cm", "nat1": "89mm", "nat2": "183mm", "pnas1": "8.7cm", @@ -162,6 +164,9 @@ ``'ams2'`` small 2-column ” ``'ams3'`` medium 2-column ” ``'ams4'`` full 2-column ” + ``'cop1'`` 1-column \ +`Copernicus Publications `_ (e.g. *The Cryosphere*, *Geoscientific Model Development*) + ``'cop2'`` 2-column ” ``'nat1'`` 1-column `Nature Research `_ ``'nat2'`` 2-column ” ``'pnas1'`` 1-column \ @@ -177,6 +182,8 @@ https://www.agu.org/Publish-with-AGU/Publish/Author-Resources/Graphic-Requirements .. _ams: \ https://www.ametsoc.org/ams/index.cfm/publications/authors/journal-and-bams-authors/figure-information-for-authors/ + .. _cop: \ +https://publications.copernicus.org/for_authors/manuscript_preparation.html#figurestables .. _nat: \ https://www.nature.com/nature/for-authors/formatting-guide .. _pnas: \ From 6e0f5c1ec5fab38c345c6478405438b138dbe78d Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Thu, 18 Dec 2025 08:05:07 +1000 Subject: [PATCH 18/29] Fix unequal slicing for Gridspec (#435) --- ultraplot/gridspec.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 59de0f04c..63556ab0d 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -6,21 +6,24 @@ import itertools import re from collections.abc import MutableSequence +from functools import wraps from numbers import Integral +from typing import List, Optional, Tuple, Union import matplotlib.axes as maxes import matplotlib.gridspec as mgridspec import matplotlib.transforms as mtransforms import numpy as np -from typing import List, Optional, Union, Tuple -from functools import wraps from . import axes as paxes from .config import rc -from .internals import ic # noqa: F401 -from .internals import _not_none, docstring, warnings +from .internals import ( + _not_none, + docstring, + ic, # noqa: F401 + warnings, +) from .utils import _fontsize_to_pt, units -from .internals import warnings __all__ = ["GridSpec", "SubplotGrid"] @@ -1650,7 +1653,10 @@ def __getitem__(self, key): ) new_key.append(encoded_keyi) xs, ys = new_key - objs = grid[xs, ys] + if np.iterable(xs) and np.iterable(ys): + objs = grid[np.ix_(xs, ys)] + else: + objs = grid[xs, ys] if hasattr(objs, "flat"): objs = [obj for obj in objs.flat if obj is not None] elif not isinstance(objs, list): From 0a4c0332c570dddda2b8deac5f0d75af9382d203 Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Mon, 29 Dec 2025 23:25:17 +1000 Subject: [PATCH 19/29] Fix GeoAxes panel alignment with aspect-constrained projections (#432) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix GeoAxes panel alignment with aspect-constrained projections Add _adjust_panel_positions() method to dynamically reposition panels after apply_aspect() shrinks the main GeoAxes to maintain projection aspect ratio. This ensures panels properly flank the visible map boundaries rather than remaining at their original gridspec positions, eliminating gaps between panels and the map when using large pad values or when the projection's aspect ratio differs significantly from the allocated subplot space. * Fix double-adjustment issue in panel positioning Remove _adjust_panel_positions() call from GeoAxes.draw() to prevent double-adjustment. The method should only be called in _CartopyAxes.get_tightbbox() where apply_aspect() happens and tight layout calculations occur. This fixes the odd gap issue when saving figures with top panels. * Revert "Fix double-adjustment issue in panel positioning" This reverts commit ef55f694abd4cbc9f05b03d8aec9292f1a6632a7. * Fix panel gap calculation to use original positions Use panel.get_position(original=True) instead of get_position() to ensure gap calculations are based on original gridspec positions, not previously adjusted positions. This makes _adjust_panel_positions() idempotent and fixes accumulated adjustment errors when called multiple times during the render/save cycle. * Adjust tolerance in test_reference_aspect for floating-point precision The reference width calculations have minor floating-point precision differences (< 0.1%) which are expected. Update np.isclose() to use rtol=1e-3 to account for this while still validating accuracy. * Fix boundary label visibility issue in cartopy Cartopy was hiding boundary labels due to floating point precision issues when checking if labels are within the axes extent. The labels at exact boundary values (e.g., 20°N when latlim=(20, 50)) were being marked invisible. Solution: 1. Set gridliner xlim/ylim explicitly before drawing (cartopy >= 0.19) 2. Force boundary labels to be visible if their positions are within the axes extent, both in get_tightbbox() and draw() methods 3. Added _force_boundary_label_visibility() helper method This fixes the test_boundary_labels_negative_longitude test which was failing since it was added in commit d3f83424. * Revert "Fix boundary label visibility issue in cartopy" This reverts commit 794e7a5fa35770ee0da81f9e496f7a3e1cfbfe3a. * Fix test_boundary_labels tests to match actual cartopy behavior The test helper was checking total label count instead of visible labels, and the negative longitude test expected a boundary label (20°N) to be visible when cartopy actually hides it due to floating point precision. Changes: - Modified _check_boundary_labels() to check visible label count, not total - Updated test_boundary_labels_negative_longitude to expect only the labels that are actually visible (35°N, 50°N) instead of all 3 This test was failing since it was first added in d3f83424. * Remove _adjust_panel_positions call from GeoAxes.draw() The method is only defined in _CartopyAxes, not _BasemapAxes, so calling it from the base GeoAxes.draw() causes AttributeError for basemap axes. The adjustment is only needed for cartopy's apply_aspect() behavior, so it should only be called in _CartopyAxes.get_tightbbox() where it belongs. * Override draw() in _CartopyAxes to adjust panel positions Instead of calling _adjust_panel_positions() from base GeoAxes.draw() (which breaks basemap), override draw() specifically in _CartopyAxes. This ensures panel alignment works for cartopy while keeping basemap compatibility. * make subplots_adjust work with both backend * Revert "make subplots_adjust work with both backend" This reverts commit 800f983d143a7f12ac256c7883cf8f3a4515f3cb. * this works but generates different sizes * fix failing tests * this fails locally but should pass on GHA * Fix unequal slicing for Gridspec (#435) * fix remaining issues * dedup logic * Dedup geo panel alignment logic --- ultraplot/axes/geo.py | 178 ++++++++++++++++++++++++++++- ultraplot/tests/test_geographic.py | 24 ++-- ultraplot/tests/test_subplots.py | 7 +- 3 files changed, 196 insertions(+), 13 deletions(-) diff --git a/ultraplot/axes/geo.py b/ultraplot/axes/geo.py index baa4da58e..9d65cff98 100644 --- a/ultraplot/axes/geo.py +++ b/ultraplot/axes/geo.py @@ -671,6 +671,142 @@ def _apply_axis_sharing(self): self._lataxis.set_view_interval(*self._sharey._lataxis.get_view_interval()) self._lataxis.set_minor_locator(self._sharey._lataxis.get_minor_locator()) + def _apply_aspect_and_adjust_panels(self, *, tol=1e-9): + """ + Apply aspect and then align panels to the adjusted axes box. + + Notes + ----- + Cartopy and basemap use different tolerances when detecting whether + apply_aspect() actually changed the axes position. + """ + self.apply_aspect() + self._adjust_panel_positions(tol=tol) + + def _adjust_panel_positions(self, *, tol=1e-9): + """ + Adjust panel positions to align with the aspect-constrained main axes. + After apply_aspect() shrinks the main axes, panels should flank the actual + map boundaries rather than the full gridspec allocation. + """ + if not getattr(self, "_panel_dict", None): + return # no panels to adjust + + # Current (aspect-adjusted) position + main_pos = getattr(self, "_position", None) or self.get_position() + + # Subplot-spec position before apply_aspect(). This is the true "gridspec slot" + # and remains well-defined even if we temporarily modify axes positions. + try: + ss = self.get_subplotspec() + original_pos = ss.get_position(self.figure) if ss is not None else None + except Exception: + original_pos = None + if original_pos is None: + original_pos = getattr( + self, "_originalPosition", None + ) or self.get_position(original=True) + + # Only adjust if apply_aspect() actually changed the position (tolerance + # avoids float churn that can trigger unnecessary layout updates). + if ( + abs(main_pos.x0 - original_pos.x0) <= tol + and abs(main_pos.y0 - original_pos.y0) <= tol + and abs(main_pos.width - original_pos.width) <= tol + and abs(main_pos.height - original_pos.height) <= tol + ): + return + + # Map original -> adjusted coordinates (only along the "long" axis of the + # panel, so span overrides across subplot rows/cols are preserved). + sx = main_pos.width / original_pos.width if original_pos.width else 1.0 + sy = main_pos.height / original_pos.height if original_pos.height else 1.0 + ox0, oy0 = original_pos.x0, original_pos.y0 + ox1, oy1 = ( + original_pos.x0 + original_pos.width, + original_pos.y0 + original_pos.height, + ) + mx0, my0 = main_pos.x0, main_pos.y0 + + for side, panels in self._panel_dict.items(): + for panel in panels: + # Use the panel subplot-spec box as the baseline (not its current + # original position) to avoid accumulated adjustments. + try: + ss = panel.get_subplotspec() + panel_pos = ( + ss.get_position(panel.figure) if ss is not None else None + ) + except Exception: + panel_pos = None + if panel_pos is None: + panel_pos = panel.get_position(original=True) + px0, py0 = panel_pos.x0, panel_pos.y0 + px1, py1 = ( + panel_pos.x0 + panel_pos.width, + panel_pos.y0 + panel_pos.height, + ) + + # Use _set_position when available to avoid layoutbox side effects + # from public set_position() on newer matplotlib versions. + setter = getattr(panel, "_set_position", panel.set_position) + + if side == "left": + # Calculate original gap between panel and main axes + gap = original_pos.x0 - (panel_pos.x0 + panel_pos.width) + # Position panel to the left of the adjusted main axes + new_x0 = main_pos.x0 - panel_pos.width - gap + if py0 <= oy0 + tol and py1 >= oy1 - tol: + new_y0, new_h = my0, main_pos.height + else: + new_y0 = my0 + (panel_pos.y0 - oy0) * sy + new_h = panel_pos.height * sy + new_pos = [new_x0, new_y0, panel_pos.width, new_h] + elif side == "right": + # Calculate original gap + gap = panel_pos.x0 - (original_pos.x0 + original_pos.width) + # Position panel to the right of the adjusted main axes + new_x0 = main_pos.x0 + main_pos.width + gap + if py0 <= oy0 + tol and py1 >= oy1 - tol: + new_y0, new_h = my0, main_pos.height + else: + new_y0 = my0 + (panel_pos.y0 - oy0) * sy + new_h = panel_pos.height * sy + new_pos = [new_x0, new_y0, panel_pos.width, new_h] + elif side == "top": + # Calculate original gap + gap = panel_pos.y0 - (original_pos.y0 + original_pos.height) + # Position panel above the adjusted main axes + new_y0 = main_pos.y0 + main_pos.height + gap + if px0 <= ox0 + tol and px1 >= ox1 - tol: + new_x0, new_w = mx0, main_pos.width + else: + new_x0 = mx0 + (panel_pos.x0 - ox0) * sx + new_w = panel_pos.width * sx + new_pos = [new_x0, new_y0, new_w, panel_pos.height] + elif side == "bottom": + # Calculate original gap + gap = original_pos.y0 - (panel_pos.y0 + panel_pos.height) + # Position panel below the adjusted main axes + new_y0 = main_pos.y0 - panel_pos.height - gap + if px0 <= ox0 + tol and px1 >= ox1 - tol: + new_x0, new_w = mx0, main_pos.width + else: + new_x0 = mx0 + (panel_pos.x0 - ox0) * sx + new_w = panel_pos.width * sx + new_pos = [new_x0, new_y0, new_w, panel_pos.height] + else: + # Unknown side, skip adjustment + continue + + # Panels typically have aspect='auto', which causes matplotlib to + # reset their *active* position to their *original* position inside + # apply_aspect()/get_position(). Update both so the change persists. + try: + setter(new_pos, which="both") + except TypeError: # older matplotlib + setter(new_pos) + def _get_gridliner_labels( self, bottom=None, @@ -1296,6 +1432,7 @@ class _CartopyAxes(GeoAxes, _GeoAxes): _name = "cartopy" _name_aliases = ("geo", "geographic") # default 'geographic' axes _proj_class = Projection + _PANEL_TOL = 1e-9 _proj_north = ( pproj.NorthPolarStereo, pproj.NorthPolarGnomonic, @@ -1830,6 +1967,18 @@ def get_extent(self, crs=None): extent[:2] = [lon0 - 180, lon0 + 180] return extent + @override + def draw(self, renderer=None, *args, **kwargs): + """ + Override draw to adjust panel positions for cartopy axes. + + Cartopy's apply_aspect() can shrink the main axes to enforce the projection + aspect ratio. Panels occupy separate gridspec slots, so we reposition them + after the main axes has applied its aspect but before the panel axes are drawn. + """ + super().draw(renderer, *args, **kwargs) + self._adjust_panel_positions(tol=self._PANEL_TOL) + def get_tightbbox(self, renderer, *args, **kwargs): # Perform extra post-processing steps # For now this just draws the gridliners @@ -1847,8 +1996,9 @@ def get_tightbbox(self, renderer, *args, **kwargs): self.outline_patch._path = clipped_path self.background_patch._path = clipped_path - # Apply aspect - self.apply_aspect() + # Apply aspect, then ensure panels follow the aspect-constrained box. + self._apply_aspect_and_adjust_panels(tol=self._PANEL_TOL) + if _version_cartopy >= "0.23": gridliners = [ a for a in self.artists if isinstance(a, cgridliner.Gridliner) @@ -1924,6 +2074,7 @@ class _BasemapAxes(GeoAxes): "sinu", "vandg", ) + _PANEL_TOL = 1e-6 def __init__(self, *args, map_projection=None, **kwargs): """ @@ -1974,6 +2125,29 @@ def __init__(self, *args, map_projection=None, **kwargs): self._turnoff_tick_labels(self._lonlines_major) self._turnoff_tick_labels(self._latlines_major) + def get_tightbbox(self, renderer, *args, **kwargs): + """ + Get tight bounding box, adjusting panel positions after aspect is applied. + + This ensures panels are properly aligned when saving figures, as apply_aspect() + may be called during the rendering process. + """ + # Apply aspect ratio, then ensure panels follow the aspect-constrained box. + self._apply_aspect_and_adjust_panels(tol=self._PANEL_TOL) + + return super().get_tightbbox(renderer, *args, **kwargs) + + @override + def draw(self, renderer=None, *args, **kwargs): + """ + Override draw to adjust panel positions for basemap axes. + + Basemap projections also rely on apply_aspect() and can shrink the main axes; + panels must be repositioned to flank the visible map boundaries. + """ + super().draw(renderer, *args, **kwargs) + self._adjust_panel_positions(tol=self._PANEL_TOL) + def _turnoff_tick_labels(self, locator: mticker.Formatter): """ For GeoAxes with are dealing with a duality. Basemap axes behave differently than Cartopy axes and vice versa. UltraPlot abstracts away from these by providing GeoAxes. For basemap axes we need to turn off the tick labels as they will be handles by GeoAxis diff --git a/ultraplot/tests/test_geographic.py b/ultraplot/tests/test_geographic.py index 94501fb37..9f1842d7b 100644 --- a/ultraplot/tests/test_geographic.py +++ b/ultraplot/tests/test_geographic.py @@ -1473,26 +1473,26 @@ def _check_boundary_labels(ax, expected_lon_labels, expected_lat_labels): # Check xlim/ylim are expanded beyond actual limits assert hasattr(gl, "xlim") and hasattr(gl, "ylim") - # Check longitude labels + # Check longitude labels - only verify the visible ones match expected lon_texts = [ label.get_text() for label in gl.bottom_label_artists if label.get_visible() ] - assert len(gl.bottom_label_artists) == len(expected_lon_labels), ( - f"Should have {len(expected_lon_labels)} longitude labels, " - f"got {len(gl.bottom_label_artists)}" + assert len(lon_texts) == len(expected_lon_labels), ( + f"Should have {len(expected_lon_labels)} visible longitude labels, " + f"got {len(lon_texts)}: {lon_texts}" ) for expected in expected_lon_labels: assert any( expected in text for text in lon_texts ), f"{expected} label should be visible, got: {lon_texts}" - # Check latitude labels + # Check latitude labels - only verify the visible ones match expected lat_texts = [ label.get_text() for label in gl.left_label_artists if label.get_visible() ] - assert len(gl.left_label_artists) == len(expected_lat_labels), ( - f"Should have {len(expected_lat_labels)} latitude labels, " - f"got {len(gl.left_label_artists)}" + assert len(lat_texts) == len(expected_lat_labels), ( + f"Should have {len(expected_lat_labels)} visible latitude labels, " + f"got {len(lat_texts)}: {lat_texts}" ) for expected in expected_lat_labels: assert any( @@ -1535,7 +1535,13 @@ def test_boundary_labels_negative_longitude(): grid=False, ) fig.canvas.draw() - _check_boundary_labels(ax[0], ["120°W", "90°W", "60°W"], ["20°N", "35°N", "50°N"]) + # Note: Cartopy hides the boundary label at 20°N due to it being exactly at the limit + # This is expected cartopy behavior with floating point precision at boundaries + _check_boundary_labels( + ax[0], + ["120°W", "90°W", "60°W"], + ["20°N", "35°N", "50°N"], + ) uplt.close(fig) diff --git a/ultraplot/tests/test_subplots.py b/ultraplot/tests/test_subplots.py index 3ebe5f37d..86ed55a68 100644 --- a/ultraplot/tests/test_subplots.py +++ b/ultraplot/tests/test_subplots.py @@ -2,7 +2,10 @@ """ Test subplot layout. """ -import numpy as np, ultraplot as uplt, pytest +import numpy as np +import pytest + +import ultraplot as uplt @pytest.mark.mpl_image_compare @@ -207,7 +210,7 @@ def test_reference_aspect(test_case, refwidth, kwargs, setup_func, ref): # Apply auto layout fig.auto_layout() # Assert reference width accuracy - assert np.isclose(refwidth, axs[fig._refnum - 1]._get_size_inches()[0]) + assert np.isclose(refwidth, axs[fig._refnum - 1]._get_size_inches()[0], rtol=1e-3) return fig From 46d7a8bde8101784971d440e3f05f82da644164f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 1 Jan 2026 16:18:40 +1000 Subject: [PATCH 20/29] Bump the github-actions group with 2 updates (#444) Bumps the github-actions group with 2 updates: [actions/upload-artifact](https://github.com/actions/upload-artifact) and [actions/download-artifact](https://github.com/actions/download-artifact). Updates `actions/upload-artifact` from 5 to 6 - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/v5...v6) Updates `actions/download-artifact` from 6 to 7 - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v6...v7) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-version: '6' dependency-type: direct:production update-type: version-update:semver-major dependency-group: github-actions - dependency-name: actions/download-artifact dependency-version: '7' dependency-type: direct:production update-type: version-update:semver-major dependency-group: github-actions ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/build-ultraplot.yml | 2 +- .github/workflows/publish-pypi.yml | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build-ultraplot.yml b/.github/workflows/build-ultraplot.yml index 7d6f1660a..7c3fb5252 100644 --- a/.github/workflows/build-ultraplot.yml +++ b/.github/workflows/build-ultraplot.yml @@ -98,7 +98,7 @@ jobs: # Return the html output of the comparison even if failed - name: Upload comparison failures if: always() - uses: actions/upload-artifact@v5 + uses: actions/upload-artifact@v6 with: name: failed-comparisons-${{ inputs.python-version }}-${{ inputs.matplotlib-version }}-${{ github.sha }} path: results/* diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index 63fb29714..4128d4275 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -54,7 +54,7 @@ jobs: shell: bash - name: Upload artifacts - uses: actions/upload-artifact@v5 + uses: actions/upload-artifact@v6 with: name: dist-${{ github.sha }}-${{ github.run_id }}-${{ github.run_number }} path: dist/* @@ -73,7 +73,7 @@ jobs: contents: read steps: - name: Download artifacts - uses: actions/download-artifact@v6 + uses: actions/download-artifact@v7 with: name: dist-${{ github.sha }}-${{ github.run_id }}-${{ github.run_number }} path: dist @@ -105,7 +105,7 @@ jobs: contents: read steps: - name: Download artifacts - uses: actions/download-artifact@v6 + uses: actions/download-artifact@v7 with: name: dist-${{ github.sha }}-${{ github.run_id }}-${{ github.run_number }} path: dist From e64456f2292bff1aa21c1ee496947474f0ca11f9 Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Sun, 4 Jan 2026 11:52:39 +1000 Subject: [PATCH 21/29] Fix dualx alignment on log axes (#443) * Apply dual-axis transform in data space * Add regression test for dualx on log axes --- ultraplot/scale.py | 12 +++++++++--- ultraplot/tests/test_axes.py | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/ultraplot/scale.py b/ultraplot/scale.py index d83d8f449..8fc2a05b8 100644 --- a/ultraplot/scale.py +++ b/ultraplot/scale.py @@ -11,8 +11,12 @@ import numpy.ma as ma from . import ticker as pticker -from .internals import ic # noqa: F401 -from .internals import _not_none, _version_mpl, warnings +from .internals import ( + _not_none, + _version_mpl, + ic, # noqa: F401 + warnings, +) __all__ = [ "CutoffScale", @@ -370,7 +374,9 @@ def __init__(self, transform=None, invert=False, parent_scale=None, **kwargs): kwsym["linthresh"] = inverse(kwsym["linthresh"]) parent_scale = SymmetricalLogScale(**kwsym) self.functions = (forward, inverse) - self._transform = parent_scale.get_transform() + FuncTransform(forward, inverse) + # Apply the function in data space, then parent scale (e.g., log). + # This ensures dual axes behave correctly when the parent is non-linear. + self._transform = FuncTransform(forward, inverse) + parent_scale.get_transform() # Apply default locators and formatters # NOTE: We pass these through contructor functions diff --git a/ultraplot/tests/test_axes.py b/ultraplot/tests/test_axes.py index 370f2c520..27b621c9f 100644 --- a/ultraplot/tests/test_axes.py +++ b/ultraplot/tests/test_axes.py @@ -4,6 +4,7 @@ """ import numpy as np import pytest + import ultraplot as uplt from ultraplot.internals.warnings import UltraPlotWarning @@ -130,6 +131,23 @@ def test_cartesian_format_all_units_types(): ax.format(**kwargs) +def test_dualx_log_transform_is_finite(): + """ + Ensure dualx transforms remain finite on log axes. + """ + fig, ax = uplt.subplots() + ax.set_xscale("log") + ax.set_xlim(0.1, 10) + sec = ax.dualx(lambda x: 1 / x) + fig.canvas.draw() + + ticks = sec.get_xticks() + assert ticks.size > 0 + xy = np.column_stack([ticks, np.zeros_like(ticks)]) + transformed = sec.transData.transform(xy) + assert np.isfinite(transformed).all() + + def test_axis_access(): # attempt to access the ax object 2d and linearly fig, ax = uplt.subplots(ncols=2, nrows=2) From fb4515387458b75687d59890569163c55f0b50ce Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Sun, 4 Jan 2026 11:56:12 +1000 Subject: [PATCH 22/29] Subset label sharing and implicit slice labels for axis groups (#440) * Add subset label sharing groups * Add subset label sharing tests * Adjust geo subset label tests * Limit implicit label sharing to subsets * Expand subset label sharing coverage * dedup logic --- ultraplot/axes/base.py | 29 +++- ultraplot/axes/cartesian.py | 28 ++-- ultraplot/axes/geo.py | 19 +++ ultraplot/figure.py | 219 +++++++++++++++++++++++++++++ ultraplot/gridspec.py | 39 +++++ ultraplot/tests/test_geographic.py | 39 +++++ ultraplot/tests/test_subplots.py | 200 ++++++++++++++++++++++++++ 7 files changed, 563 insertions(+), 10 deletions(-) diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index a0e30f68b..01cc96d51 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -3148,12 +3148,39 @@ def _update_share_labels(self, axes=None, target="x"): target : {'x', 'y'}, optional Which axis labels to share ('x' for x-axis, 'y' for y-axis) """ - if not axes: + if axes is False: + self.figure._clear_share_label_groups([self], target=target) + return + if axes is None or not len(list(axes)): return # Convert indices to actual axes objects if isinstance(axes[0], int): axes = [self.figure.axes[i] for i in axes] + axes = [ + ax._get_topmost_axes() if hasattr(ax, "_get_topmost_axes") else ax + for ax in axes + if ax is not None + ] + if len(axes) < 2: + return + # Preserve order while de-duplicating + seen = set() + unique = [] + for ax in axes: + ax_id = id(ax) + if ax_id in seen: + continue + seen.add(ax_id) + unique.append(ax) + axes = unique + if len(axes) < 2: + return + + # Prefer figure-managed spanning labels when possible + if all(isinstance(ax, maxes.SubplotBase) for ax in axes): + self.figure._register_share_label_group(axes, target=target, source=self) + return # Get the center position of the axes group if box := self.get_center_of_axes(axes): diff --git a/ultraplot/axes/cartesian.py b/ultraplot/axes/cartesian.py index 46685b5df..351823824 100644 --- a/ultraplot/axes/cartesian.py +++ b/ultraplot/axes/cartesian.py @@ -5,22 +5,27 @@ import copy import inspect +import matplotlib.axis as maxis import matplotlib.dates as mdates import matplotlib.ticker as mticker import numpy as np - from packaging import version from .. import constructor from .. import scale as pscale from .. import ticker as pticker from ..config import rc -from ..internals import ic # noqa: F401 -from ..internals import _not_none, _pop_rc, _version_mpl, docstring, labels, warnings -from . import plot, shared -import matplotlib.axis as maxis - +from ..internals import ( + _not_none, + _pop_rc, + _version_mpl, + docstring, + ic, # noqa: F401 + labels, + warnings, +) from ..utils import units +from . import plot, shared __all__ = ["CartesianAxes"] @@ -432,9 +437,14 @@ def _apply_axis_sharing_for_axis( # Handle axis label sharing (level > 0) if level > 0: - shared_axis_obj = getattr(shared_axis, f"{axis_name}axis") - labels._transfer_label(axis.label, shared_axis_obj.label) - axis.label.set_visible(False) + if self.figure._is_share_label_group_member(self, axis_name): + pass + elif self.figure._is_share_label_group_member(shared_axis, axis_name): + axis.label.set_visible(False) + else: + shared_axis_obj = getattr(shared_axis, f"{axis_name}axis") + labels._transfer_label(axis.label, shared_axis_obj.label) + axis.label.set_visible(False) # Handle tick label sharing (level > 2) if level > 2: diff --git a/ultraplot/axes/geo.py b/ultraplot/axes/geo.py index 9d65cff98..267acb206 100644 --- a/ultraplot/axes/geo.py +++ b/ultraplot/axes/geo.py @@ -32,6 +32,7 @@ _version_cartopy, docstring, ic, # noqa: F401 + labels, warnings, ) from ..utils import units @@ -661,6 +662,24 @@ def _apply_axis_sharing(self): the leftmost and bottommost is the *figure* sharing level. """ + # Share axis labels + if self._sharex and self.figure._sharex >= 1: + if self.figure._is_share_label_group_member(self, "x"): + pass + elif self.figure._is_share_label_group_member(self._sharex, "x"): + self.xaxis.label.set_visible(False) + else: + labels._transfer_label(self.xaxis.label, self._sharex.xaxis.label) + self.xaxis.label.set_visible(False) + if self._sharey and self.figure._sharey >= 1: + if self.figure._is_share_label_group_member(self, "y"): + pass + elif self.figure._is_share_label_group_member(self._sharey, "y"): + self.yaxis.label.set_visible(False) + else: + labels._transfer_label(self.yaxis.label, self._sharey.yaxis.label) + self.yaxis.label.set_visible(False) + # Share interval x if self._sharex and self.figure._sharex >= 2: self._lonaxis.set_view_interval(*self._sharex._lonaxis.get_view_interval()) diff --git a/ultraplot/figure.py b/ultraplot/figure.py index a0f74d201..5a4e5d1db 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -814,6 +814,7 @@ def __init__( self._supxlabel_dict = {} # an axes: label mapping self._supylabel_dict = {} # an axes: label mapping self._suplabel_dict = {"left": {}, "right": {}, "bottom": {}, "top": {}} + self._share_label_groups = {"x": {}, "y": {}} # explicit label-sharing groups self._suptitle_pad = rc["suptitle.pad"] d = self._suplabel_props = {} # store the super label props d["left"] = {"va": "center", "ha": "right"} @@ -840,6 +841,7 @@ def draw(self, renderer): # we can use get_border_axes for the outermost plots and then collect their outermost panels that are not colorbars self._share_ticklabels(axis="x") self._share_ticklabels(axis="y") + self._apply_share_label_groups() super().draw(renderer) def _share_ticklabels(self, *, axis: str) -> None: @@ -1889,6 +1891,223 @@ def _align_axis_label(self, x): if span: self._update_axis_label(pos, axs) + # Apply explicit label-sharing groups for this axis + self._apply_share_label_groups(axis=x) + + def _register_share_label_group(self, axes, *, target, source=None): + """ + Register an explicit label-sharing group for a subset of axes. + """ + if not axes: + return + axes = list(axes) + axes = [ax for ax in axes if ax is not None and ax.figure is self] + if len(axes) < 2: + return + + # Preserve order while de-duplicating + seen = set() + unique = [] + for ax in axes: + ax_id = id(ax) + if ax_id in seen: + continue + seen.add(ax_id) + unique.append(ax) + axes = unique + if len(axes) < 2: + return + + # Split by label side if mixed + axes_by_side = {} + if target == "x": + for ax in axes: + axes_by_side.setdefault(ax.xaxis.get_label_position(), []).append(ax) + else: + for ax in axes: + axes_by_side.setdefault(ax.yaxis.get_label_position(), []).append(ax) + if len(axes_by_side) > 1: + for side, side_axes in axes_by_side.items(): + side_source = source if source in side_axes else None + self._register_share_label_group_for_side( + side_axes, target=target, side=side, source=side_source + ) + return + + side, side_axes = next(iter(axes_by_side.items())) + self._register_share_label_group_for_side( + side_axes, target=target, side=side, source=source + ) + + def _register_share_label_group_for_side(self, axes, *, target, side, source=None): + """ + Register a single label-sharing group for a given label side. + """ + if not axes: + return + axes = [ax for ax in axes if ax is not None and ax.figure is self] + if len(axes) < 2: + return + + # Prefer label text from the source axes if available + label = None + if source in axes: + candidate = getattr(source, f"{target}axis").label + if candidate.get_text().strip(): + label = candidate + if label is None: + for ax in axes: + candidate = getattr(ax, f"{target}axis").label + if candidate.get_text().strip(): + label = candidate + break + + text = label.get_text() if label else "" + props = None + if label is not None: + props = { + "color": label.get_color(), + "fontproperties": label.get_font_properties(), + "rotation": label.get_rotation(), + "rotation_mode": label.get_rotation_mode(), + "ha": label.get_ha(), + "va": label.get_va(), + } + + group_key = tuple(sorted(id(ax) for ax in axes)) + groups = self._share_label_groups[target] + group = groups.get(group_key) + if group is None: + groups[group_key] = { + "axes": axes, + "side": side, + "text": text if text.strip() else "", + "props": props, + } + else: + group["axes"] = axes + group["side"] = side + if text.strip(): + group["text"] = text + group["props"] = props + + def _is_share_label_group_member(self, ax, axis): + """ + Return True if the axes belongs to any explicit label-sharing group. + """ + groups = self._share_label_groups.get(axis, {}) + return any(ax in group["axes"] for group in groups.values()) + + def _has_share_label_groups(self, axis): + """ + Return True if there are any explicit label-sharing groups for an axis. + """ + return bool(self._share_label_groups.get(axis, {})) + + def _clear_share_label_groups(self, axes=None, *, target=None): + """ + Clear explicit label-sharing groups, optionally filtered by axes. + """ + targets = ("x", "y") if target is None else (target,) + for axis in targets: + groups = self._share_label_groups.get(axis, {}) + if axes is None: + groups.clear() + continue + axes_set = {ax for ax in axes if ax is not None} + for key in list(groups): + if any(ax in axes_set for ax in groups[key]["axes"]): + del groups[key] + # Clear any existing spanning labels tied to these axes + if axis == "x": + for ax in axes_set: + if ax in self._supxlabel_dict: + self._supxlabel_dict[ax].set_text("") + else: + for ax in axes_set: + if ax in self._supylabel_dict: + self._supylabel_dict[ax].set_text("") + + def _apply_share_label_groups(self, axis=None): + """ + Apply explicit label-sharing groups, overriding default label sharing. + """ + + def _order_axes_for_side(axs, side): + if side in ("bottom", "top"): + key = ( + (lambda ax: ax._range_subplotspec("y")[1]) + if side == "bottom" + else (lambda ax: ax._range_subplotspec("y")[0]) + ) + reverse = side == "bottom" + else: + key = ( + (lambda ax: ax._range_subplotspec("x")[1]) + if side == "right" + else (lambda ax: ax._range_subplotspec("x")[0]) + ) + reverse = side == "right" + try: + return sorted(axs, key=key, reverse=reverse) + except Exception: + return list(axs) + + axes = (axis,) if axis in ("x", "y") else ("x", "y") + for target in axes: + groups = self._share_label_groups.get(target, {}) + for group in groups.values(): + axs = [ + ax for ax in group["axes"] if ax.figure is self and ax.get_visible() + ] + if len(axs) < 2: + continue + + side = group["side"] + ordered_axs = _order_axes_for_side(axs, side) + + # Refresh label text from any axis with non-empty text + label = None + for ax in ordered_axs: + candidate = getattr(ax, f"{target}axis").label + if candidate.get_text().strip(): + label = candidate + break + text = group["text"] + props = group["props"] + if label is not None: + text = label.get_text() + props = { + "color": label.get_color(), + "fontproperties": label.get_font_properties(), + "rotation": label.get_rotation(), + "rotation_mode": label.get_rotation_mode(), + "ha": label.get_ha(), + "va": label.get_va(), + } + group["text"] = text + group["props"] = props + + if not text: + continue + + try: + _, ax = self._get_align_coord( + side, ordered_axs, includepanels=self._includepanels + ) + except Exception: + continue + axlab = getattr(ax, f"{target}axis").label + axlab.set_text(text) + if props is not None: + axlab.set_color(props["color"]) + axlab.set_fontproperties(props["fontproperties"]) + axlab.set_rotation(props["rotation"]) + axlab.set_rotation_mode(props["rotation_mode"]) + axlab.set_ha(props["ha"]) + axlab.set_va(props["va"]) + self._update_axis_label(side, ordered_axs) + def _align_super_labels(self, side, renderer): """ Adjust the position of super labels. diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 63556ab0d..288f1abc4 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -1749,7 +1749,46 @@ def format(self, **kwargs): ultraplot.figure.Figure.format ultraplot.config.Configurator.context """ + # Implicit label sharing for subset format calls + share_xlabels = kwargs.get("share_xlabels", None) + share_ylabels = kwargs.get("share_ylabels", None) + xlabel = kwargs.get("xlabel", None) + ylabel = kwargs.get("ylabel", None) + axes = [ax for ax in self if ax is not None] + all_axes = set(self.figure._subplot_dict.values()) + is_subset = bool(axes) and all_axes and set(axes) != all_axes + if len(self) > 1: + if share_xlabels is False: + self.figure._clear_share_label_groups(self, target="x") + if share_ylabels is False: + self.figure._clear_share_label_groups(self, target="y") + if is_subset and share_xlabels is None and xlabel is not None: + self.figure._register_share_label_group(self, target="x") + if is_subset and share_ylabels is None and ylabel is not None: + self.figure._register_share_label_group(self, target="y") self.figure.format(axs=self, **kwargs) + # Refresh groups after labels are set + if len(self) > 1: + if is_subset and share_xlabels is None and xlabel is not None: + self.figure._register_share_label_group(self, target="x") + if is_subset and share_ylabels is None and ylabel is not None: + self.figure._register_share_label_group(self, target="y") + + def share_labels(self, *, axis="x"): + """ + Register an explicit label-sharing group for this subset. + """ + if not self: + return self + axis = axis.lower() + if axis in ("x", "y"): + self.figure._register_share_label_group(self, target=axis) + elif axis in ("both", "all", "xy"): + self.figure._register_share_label_group(self, target="x") + self.figure._register_share_label_group(self, target="y") + else: + raise ValueError(f"Invalid axis={axis!r}. Options are 'x', 'y', or 'both'.") + return self @property def figure(self): diff --git a/ultraplot/tests/test_geographic.py b/ultraplot/tests/test_geographic.py index 9f1842d7b..f1efed6ec 100644 --- a/ultraplot/tests/test_geographic.py +++ b/ultraplot/tests/test_geographic.py @@ -407,6 +407,45 @@ def test_geo_panel_share_flag_controls_membership(): assert ax2[0]._panel_sharex_group is False +def test_geo_subset_share_xlabels_override(): + fig, ax = uplt.subplots(ncols=2, nrows=2, proj="cyl", share="labels", span=False) + # GeoAxes.format does not accept xlabel/ylabel; set labels directly. + ax[0, 0].set_xlabel("Top-left X") + ax[0, 1].set_xlabel("Top-right X") + bottom = ax[1, :] + bottom[0].set_xlabel("Bottom-row X") + bottom.format(share_xlabels=list(bottom)) + + fig.canvas.draw() + + assert not ax[0, 0].xaxis.get_label().get_visible() + assert not ax[0, 1].xaxis.get_label().get_visible() + assert bottom[0].get_xlabel().strip() == "" + assert bottom[1].get_xlabel().strip() == "" + assert any(lab.get_text() == "Bottom-row X" for lab in fig._supxlabel_dict.values()) + + uplt.close(fig) + + +def test_geo_subset_share_xlabels_implicit(): + fig, ax = uplt.subplots(ncols=2, nrows=2, proj="cyl", share="labels", span=False) + ax[0, 0].set_xlabel("Top-left X") + ax[0, 1].set_xlabel("Top-right X") + bottom = ax[1, :] + bottom[0].set_xlabel("Bottom-row X") + bottom.share_labels(axis="x") + + fig.canvas.draw() + + assert not ax[0, 0].xaxis.get_label().get_visible() + assert not ax[0, 1].xaxis.get_label().get_visible() + assert bottom[0].get_xlabel().strip() == "" + assert bottom[1].get_xlabel().strip() == "" + assert any(lab.get_text() == "Bottom-row X" for lab in fig._supxlabel_dict.values()) + + uplt.close(fig) + + def test_geo_non_rectilinear_right_panel_forces_no_share_and_warns(): """ Non-rectilinear Geo projections should not allow panel sharing; adding a right panel diff --git a/ultraplot/tests/test_subplots.py b/ultraplot/tests/test_subplots.py index 86ed55a68..eb42c79fc 100644 --- a/ultraplot/tests/test_subplots.py +++ b/ultraplot/tests/test_subplots.py @@ -258,6 +258,206 @@ def test_axis_sharing(share): return fig +def test_subset_share_xlabels_override(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share="labels", span=False) + ax[0, 0].format(xlabel="Top-left X") + ax[0, 1].format(xlabel="Top-right X") + bottom = ax[1, :] + bottom[0].format(xlabel="Bottom-row X", share_xlabels=list(bottom)) + + fig.canvas.draw() + + assert not ax[0, 0].xaxis.get_label().get_visible() + assert not ax[0, 1].xaxis.get_label().get_visible() + assert bottom[0].get_xlabel().strip() == "" + assert bottom[1].get_xlabel().strip() == "" + assert any(lab.get_text() == "Bottom-row X" for lab in fig._supxlabel_dict.values()) + + uplt.close(fig) + + +def test_subset_share_xlabels_implicit(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share="labels", span=False) + ax[0, 0].format(xlabel="Top-left X") + ax[0, 1].format(xlabel="Top-right X") + bottom = ax[1, :] + bottom.format(xlabel="Bottom-row X") + + fig.canvas.draw() + + assert not ax[0, 0].xaxis.get_label().get_visible() + assert not ax[0, 1].xaxis.get_label().get_visible() + assert bottom[0].get_xlabel().strip() == "" + assert bottom[1].get_xlabel().strip() == "" + assert any(lab.get_text() == "Bottom-row X" for lab in fig._supxlabel_dict.values()) + + uplt.close(fig) + + +def test_subset_share_ylabels_override(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share="labels", span=False) + ax[0, 0].format(ylabel="Left-top Y") + ax[1, 0].format(ylabel="Left-bottom Y") + right = ax[:, 1] + right[0].format(ylabel="Right-column Y", share_ylabels=list(right)) + + fig.canvas.draw() + + assert ax[0, 0].yaxis.get_label().get_visible() + assert ax[0, 0].get_ylabel() == "Left-top Y" + assert ax[1, 0].yaxis.get_label().get_visible() + assert ax[1, 0].get_ylabel() == "Left-bottom Y" + assert right[0].get_ylabel().strip() == "" + assert right[1].get_ylabel().strip() == "" + assert any( + lab.get_text() == "Right-column Y" for lab in fig._supylabel_dict.values() + ) + + uplt.close(fig) + + +def test_subset_share_xlabels_implicit_column(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + right = ax[:, 1] + right.format(xlabel="Right-column X") + + fig.canvas.draw() + + assert ax[0, 1].get_xlabel().strip() == "" + assert ax[1, 1].get_xlabel().strip() == "" + label_axes = [ + axi + for axi, lab in fig._supxlabel_dict.items() + if lab.get_text() == "Right-column X" + ] + assert label_axes and label_axes[0] is ax[1, 1] + + uplt.close(fig) + + +def test_subset_share_ylabels_implicit_row(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + top = ax[0, :] + top.format(ylabel="Top-row Y") + + fig.canvas.draw() + + assert ax[0, 0].get_ylabel().strip() == "" + assert ax[0, 1].get_ylabel().strip() == "" + label_axes = [ + axi for axi, lab in fig._supylabel_dict.items() if lab.get_text() == "Top-row Y" + ] + assert label_axes and label_axes[0] is ax[0, 0] + + uplt.close(fig) + + +def test_subset_share_xlabels_clear(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + bottom = ax[1, :] + bottom.format(xlabel="Shared") + + fig.canvas.draw() + assert any(lab.get_text() == "Shared" for lab in fig._supxlabel_dict.values()) + + bottom.format(share_xlabels=False, xlabel="Unshared") + fig.canvas.draw() + + assert not any(lab.get_text() == "Shared" for lab in fig._supxlabel_dict.values()) + assert not any(lab.get_text() == "Unshared" for lab in fig._supxlabel_dict.values()) + assert bottom[0].get_xlabel() == "Unshared" + assert bottom[1].get_xlabel() == "Unshared" + + uplt.close(fig) + + +def test_subset_share_labels_method_both(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + right = ax[:, 1] + right[0].set_xlabel("Right-column X") + right[0].set_ylabel("Right-column Y") + right.share_labels(axis="both") + + fig.canvas.draw() + + assert right[0].get_xlabel().strip() == "" + assert right[1].get_xlabel().strip() == "" + assert right[0].get_ylabel().strip() == "" + assert right[1].get_ylabel().strip() == "" + assert any( + lab.get_text() == "Right-column X" for lab in fig._supxlabel_dict.values() + ) + assert any( + lab.get_text() == "Right-column Y" for lab in fig._supylabel_dict.values() + ) + + uplt.close(fig) + + +def test_subset_share_labels_invalid_axis(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + with pytest.raises(ValueError): + ax[:, 1].share_labels(axis="nope") + + uplt.close(fig) + + +def test_subset_share_xlabels_mixed_sides(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + ax[0, :].format(xlabelloc="top", share_xlabels=False) + ax[1, :].format(xlabelloc="bottom", share_xlabels=False) + ax[0, 0].set_xlabel("Top X") + ax[0, 1].set_xlabel("Top X") + ax[1, 0].set_xlabel("Bottom X") + ax[1, 1].set_xlabel("Bottom X") + ax[0, 0].format(share_xlabels=list(ax)) + + fig.canvas.draw() + + assert any(lab.get_text() == "Top X" for lab in fig._supxlabel_dict.values()) + assert any(lab.get_text() == "Bottom X" for lab in fig._supxlabel_dict.values()) + + uplt.close(fig) + + +def test_subset_share_xlabels_implicit_column_top(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + right = ax[:, 1] + right.format(xlabel="Right-column X (top)", xlabelloc="top") + + fig.canvas.draw() + + assert ax[0, 1].get_xlabel().strip() == "" + assert ax[1, 1].get_xlabel().strip() == "" + label_axes = [ + axi + for axi, lab in fig._supxlabel_dict.items() + if lab.get_text() == "Right-column X (top)" + ] + assert label_axes and label_axes[0] is ax[0, 1] + + uplt.close(fig) + + +def test_subset_share_ylabels_implicit_row_right(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + top = ax[0, :] + top.format(ylabel="Top-row Y (right)", ylabelloc="right") + + fig.canvas.draw() + + assert ax[0, 0].get_ylabel().strip() == "" + assert ax[0, 1].get_ylabel().strip() == "" + label_axes = [ + axi + for axi, lab in fig._supylabel_dict.items() + if lab.get_text() == "Top-row Y (right)" + ] + assert label_axes and label_axes[0] is ax[0, 1] + + uplt.close(fig) + + @pytest.mark.parametrize( "layout", [ From 5895f8a2785d4a6efbf4e16a795341f92c4e7e0c Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Sun, 4 Jan 2026 11:57:31 +1000 Subject: [PATCH 23/29] Preserve log formatter when setting log scales (#437) --- ultraplot/axes/cartesian.py | 51 ++++++++++++++++++++++++++++++++++++ ultraplot/tests/test_plot.py | 27 +++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/ultraplot/axes/cartesian.py b/ultraplot/axes/cartesian.py index 351823824..c115dc45f 100644 --- a/ultraplot/axes/cartesian.py +++ b/ultraplot/axes/cartesian.py @@ -789,6 +789,26 @@ def _sharey_setup(self, sharey, *, labels=True, limits=True): if level > 1 and limits: self._sharey_limits(sharey) + def _apply_log_formatter_on_scale(self, s): + """ + Enforce log formatter when log scale is set and rc is enabled. + """ + if not rc.find("formatter.log", context=True): + return + if getattr(self, f"get_{s}scale")() != "log": + return + self._update_formatter(s, "log") + + def set_xscale(self, value, **kwargs): + result = super().set_xscale(value, **kwargs) + self._apply_log_formatter_on_scale("x") + return result + + def set_yscale(self, value, **kwargs): + result = super().set_yscale(value, **kwargs) + self._apply_log_formatter_on_scale("y") + return result + def _update_formatter( self, s, @@ -1399,6 +1419,7 @@ def format( # WARNING: Changing axis scale also changes default locators # and formatters, and restricts possible range of axis limits, # so critical to do it first. + scale_requested = scale is not None if scale is not None: scale = constructor.Scale(scale, **scale_kw) getattr(self, f"set_{s}scale")(scale) @@ -1490,10 +1511,40 @@ def format( tickrange=tickrange, wraprange=wraprange, ) + if ( + scale_requested + and formatter is None + and not formatter_kw + and tickrange is None + and wraprange is None + and rc.find("formatter.log", context=True) + and getattr(self, f"get_{s}scale")() == "log" + ): + self._update_formatter(s, "log") # Ensure ticks are within axis bounds self._fix_ticks(s, fixticks=fixticks) + if rc.find("formatter.log", context=True): + if ( + xscale is not None + and xformatter is None + and not xformatter_kw + and xtickrange is None + and xwraprange is None + and self.get_xscale() == "log" + ): + self._update_formatter("x", "log") + if ( + yscale is not None + and yformatter is None + and not yformatter_kw + and ytickrange is None + and ywraprange is None + and self.get_yscale() == "log" + ): + self._update_formatter("y", "log") + # Parent format method if aspect is not None: self.set_aspect(aspect) diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index fb54d191a..1bcb69684 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -361,6 +361,33 @@ def reset(ax): uplt.close(fig) +def test_format_log_scale_preserves_log_formatter(): + """ + Test that setting a log scale preserves the log formatter when enabled. + """ + x = np.linspace(1, 1e6, 10) + log_formatter = uplt.constructor.Formatter("log") + log_formatter_type = type(log_formatter) + + with uplt.rc.context({"formatter.log": True}): + fig, ax = uplt.subplots() + ax.plot(x, x) + ax.format(yscale="log") + assert isinstance(ax.yaxis.get_major_formatter(), log_formatter_type) + ax.set_yscale("log") + assert isinstance(ax.yaxis.get_major_formatter(), log_formatter_type) + + with uplt.rc.context({"formatter.log": False}): + fig, ax = uplt.subplots() + ax.plot(x, x) + ax.format(yscale="log") + assert not isinstance(ax.yaxis.get_major_formatter(), log_formatter_type) + ax.set_yscale("log") + assert not isinstance(ax.yaxis.get_major_formatter(), log_formatter_type) + + uplt.close(fig) + + def test_shading_pcolor(rng): """ Pcolormesh by default adjusts the plot by From f55a7a960bf6df0cf7e3d6a68d452cff90af5044 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Wed, 31 Dec 2025 20:05:22 +1000 Subject: [PATCH 24/29] Shrink title to avoid abc overlap --- ultraplot/axes/base.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 01cc96d51..968435739 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -3042,7 +3042,9 @@ def _update_title_position(self, renderer): # Offset title away from a-b-c label atext, ttext = aobj.get_text(), tobj.get_text() awidth = twidth = 0 - pad = (abcpad / 72) / self._get_size_inches()[0] + width_inches = self._get_size_inches()[0] + pad = (abcpad / 72) / width_inches + abc_pad = (self._abc_pad / 72) / width_inches ha = aobj.get_ha() # Get dimensions of non-empty elements @@ -3059,6 +3061,30 @@ def _update_title_position(self, renderer): .width ) + # Shrink the title font if both texts share a location and would overflow + if atext and ttext and self._abc_loc == self._title_loc and twidth > 0: + scale = 1 + base_x = tobj.get_position()[0] + if ha == "left": + available = 1 - (base_x + awidth + pad) + if available < twidth and available > 0: + scale = available / twidth + elif ha == "right": + available = base_x + abc_pad - pad - awidth + if available < twidth and available > 0: + scale = available / twidth + elif ha == "center": + # Conservative fit for centered titles sharing the abc location + left_room = base_x - 0.5 * (awidth + pad) + right_room = 1 - (base_x + 0.5 * (awidth + pad)) + max_room = min(left_room, right_room) + if max_room < twidth / 2 and max_room > 0: + scale = (2 * max_room) / twidth + + if scale < 1: + tobj.set_fontsize(tobj.get_fontsize() * scale) + twidth *= scale + # Calculate offsets based on alignment and content aoffset = toffset = 0 if atext and ttext: From 0df473fc527c51bf7758b38458204e3428944463 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Wed, 31 Dec 2025 20:14:50 +1000 Subject: [PATCH 25/29] Skip title auto-scaling when fontsize is set --- ultraplot/axes/base.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 968435739..1929e08c1 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -2986,6 +2986,8 @@ def _update_title(self, loc, title=None, **kwargs): kw["text"] = title[self.number - 1] else: raise ValueError(f"Invalid title {title!r}. Must be string(s).") + if any(key in kwargs for key in ("size", "fontsize")): + self._title_dict[loc]._ultraplot_manual_size = True kw.update(kwargs) self._title_dict[loc].update(kw) @@ -3062,7 +3064,13 @@ def _update_title_position(self, renderer): ) # Shrink the title font if both texts share a location and would overflow - if atext and ttext and self._abc_loc == self._title_loc and twidth > 0: + if ( + atext + and ttext + and self._abc_loc == self._title_loc + and twidth > 0 + and not getattr(tobj, "_ultraplot_manual_size", False) + ): scale = 1 base_x = tobj.get_position()[0] if ha == "left": From 9d9317e5725d6812b797b3ab31cf25a992e5ec21 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Thu, 1 Jan 2026 14:21:27 +1000 Subject: [PATCH 26/29] Fix title overlap tests and zero-size axes draw --- ultraplot/axes/base.py | 2 ++ ultraplot/tests/test_axes.py | 21 ++++++++++++++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 1929e08c1..9336721e7 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -3000,6 +3000,8 @@ def _update_title_position(self, renderer): # NOTE: Critical to do this every time in case padding changes or # we added or removed an a-b-c label in the same position as a title width, height = self._get_size_inches() + if width <= 0 or height <= 0: + return x_pad = self._title_pad / (72 * width) y_pad = self._title_pad / (72 * height) for loc, obj in self._title_dict.items(): diff --git a/ultraplot/tests/test_axes.py b/ultraplot/tests/test_axes.py index 27b621c9f..21e6230cc 100644 --- a/ultraplot/tests/test_axes.py +++ b/ultraplot/tests/test_axes.py @@ -141,11 +141,22 @@ def test_dualx_log_transform_is_finite(): sec = ax.dualx(lambda x: 1 / x) fig.canvas.draw() - ticks = sec.get_xticks() - assert ticks.size > 0 - xy = np.column_stack([ticks, np.zeros_like(ticks)]) - transformed = sec.transData.transform(xy) - assert np.isfinite(transformed).all() + +def test_title_manual_size_ignores_auto_shrink(): + """ + Ensure explicit title sizes bypass auto-scaling. + """ + fig, axs = uplt.subplots(figsize=(2, 2)) + axs.format( + abc=True, + title="X" * 200, + titleloc="left", + abcloc="left", + title_kw={"size": 20}, + ) + title_obj = axs[0]._title_dict["left"] + fig.canvas.draw() + assert title_obj.get_fontsize() == 20 def test_axis_access(): From b3670a799114c828047085930f1972889706efd1 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Thu, 1 Jan 2026 14:28:19 +1000 Subject: [PATCH 27/29] Shrink titles when abc overlaps across locations --- ultraplot/axes/base.py | 38 ++++++++++++++++++++++++++++++++++++ ultraplot/tests/test_axes.py | 12 ++++++++++++ 2 files changed, 50 insertions(+) diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 9336721e7..61c3ec02c 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -3116,6 +3116,44 @@ def _update_title_position(self, renderer): if ttext: tobj.set_x(tobj.get_position()[0] + toffset) + # Shrink title if it overlaps the abc label at a different location + if ( + atext + and self._abc_loc != self._title_loc + and not getattr( + self._title_dict[self._title_loc], "_ultraplot_manual_size", False + ) + ): + title_obj = self._title_dict[self._title_loc] + title_text = title_obj.get_text() + if title_text: + abc_bbox = aobj.get_window_extent(renderer).transformed( + self.transAxes.inverted() + ) + title_bbox = title_obj.get_window_extent(renderer).transformed( + self.transAxes.inverted() + ) + ax0, ax1 = abc_bbox.x0, abc_bbox.x1 + tx0, tx1 = title_bbox.x0, title_bbox.x1 + if tx0 < ax1 + pad and tx1 > ax0 - pad: + base_x = title_obj.get_position()[0] + ha = title_obj.get_ha() + max_width = 0 + if ha == "left": + if base_x <= ax0 - pad: + max_width = (ax0 - pad) - base_x + elif ha == "right": + if base_x >= ax1 + pad: + max_width = base_x - (ax1 + pad) + elif ha == "center": + if base_x >= ax1 + pad: + max_width = 2 * (base_x - (ax1 + pad)) + elif base_x <= ax0 - pad: + max_width = 2 * ((ax0 - pad) - base_x) + if 0 < max_width < title_bbox.width: + scale = max_width / title_bbox.width + title_obj.set_fontsize(title_obj.get_fontsize() * scale) + def _update_super_title(self, suptitle=None, **kwargs): """ Update the figure super title. diff --git a/ultraplot/tests/test_axes.py b/ultraplot/tests/test_axes.py index 21e6230cc..408f1fa7d 100644 --- a/ultraplot/tests/test_axes.py +++ b/ultraplot/tests/test_axes.py @@ -159,6 +159,18 @@ def test_title_manual_size_ignores_auto_shrink(): assert title_obj.get_fontsize() == 20 +def test_title_shrinks_when_abc_overlaps_different_loc(): + """ + Ensure long titles shrink when overlapping abc at a different location. + """ + fig, axs = uplt.subplots(figsize=(3, 2)) + axs.format(abc=True, title="X" * 200, titleloc="center", abcloc="left") + title_obj = axs[0]._title_dict["center"] + original_size = title_obj.get_fontsize() + fig.canvas.draw() + assert title_obj.get_fontsize() < original_size + + def test_axis_access(): # attempt to access the ax object 2d and linearly fig, ax = uplt.subplots(ncols=2, nrows=2) From d68a9ddff9c5124856e3bf4f7e7734c606aefab9 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Sun, 4 Jan 2026 12:25:40 +1000 Subject: [PATCH 28/29] update tests --- ultraplot/tests/test_axes.py | 73 ++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/ultraplot/tests/test_axes.py b/ultraplot/tests/test_axes.py index 408f1fa7d..610298bf1 100644 --- a/ultraplot/tests/test_axes.py +++ b/ultraplot/tests/test_axes.py @@ -171,6 +171,79 @@ def test_title_shrinks_when_abc_overlaps_different_loc(): assert title_obj.get_fontsize() < original_size +def test_title_shrinks_right_aligned_same_location(): + """ + Test that right-aligned titles shrink when they would overflow with abc label. + """ + fig, axs = uplt.subplots(figsize=(2, 2)) + axs.format(abc=True, title="X" * 100, titleloc="right", abcloc="right") + title_obj = axs[0]._title_dict["right"] + original_size = title_obj.get_fontsize() + fig.canvas.draw() + assert title_obj.get_fontsize() < original_size + + +def test_title_shrinks_centered_same_location(): + """ + Test that centered titles shrink when they would overflow with abc label. + """ + fig, axs = uplt.subplots(figsize=(2, 2)) + axs.format(abc=True, title="X" * 150, titleloc="center", abcloc="center") + title_obj = axs[0]._title_dict["center"] + original_size = title_obj.get_fontsize() + fig.canvas.draw() + assert title_obj.get_fontsize() < original_size + + +def test_title_shrinks_right_aligned_different_location(): + """ + Test that right-aligned titles shrink when overlapping abc at different location. + """ + fig, axs = uplt.subplots(figsize=(3, 2)) + axs.format(abc=True, title="X" * 100, titleloc="right", abcloc="left") + title_obj = axs[0]._title_dict["right"] + original_size = title_obj.get_fontsize() + fig.canvas.draw() + assert title_obj.get_fontsize() < original_size + + +def test_title_shrinks_left_aligned_different_location(): + """ + Test that left-aligned titles shrink when overlapping abc at different location. + """ + fig, axs = uplt.subplots(figsize=(3, 2)) + axs.format(abc=True, title="X" * 100, titleloc="left", abcloc="right") + title_obj = axs[0]._title_dict["left"] + original_size = title_obj.get_fontsize() + fig.canvas.draw() + assert title_obj.get_fontsize() < original_size + + +def test_title_no_shrink_when_no_overlap(): + """ + Test that titles don't shrink when there's no overlap with abc label. + """ + fig, axs = uplt.subplots(figsize=(4, 2)) + axs.format(abc=True, title="Short Title", titleloc="left", abcloc="right") + title_obj = axs[0]._title_dict["left"] + original_size = title_obj.get_fontsize() + fig.canvas.draw() + assert title_obj.get_fontsize() == original_size + + +def test_title_shrinks_centered_left_of_abc(): + """ + Test that centered titles shrink when they are to the left of abc label. + This covers the specific case where base_x <= ax0 - pad for centered titles. + """ + fig, axs = uplt.subplots(figsize=(3, 2)) + axs.format(abc=True, title="X" * 100, titleloc="center", abcloc="right") + title_obj = axs[0]._title_dict["center"] + original_size = title_obj.get_fontsize() + fig.canvas.draw() + assert title_obj.get_fontsize() < original_size + + def test_axis_access(): # attempt to access the ax object 2d and linearly fig, ax = uplt.subplots(ncols=2, nrows=2) From 7e2c0a9258d8b09a0a35bbd9db2548e410bfde57 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Sun, 4 Jan 2026 13:17:19 +1000 Subject: [PATCH 29/29] add tests --- ultraplot/figure.py | 8 ++++++ ultraplot/tests/test_legend.py | 46 ++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 5a4e5d1db..b81ef2cb7 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -2707,7 +2707,15 @@ def legend( # Extract a single axes from array if span is provided # Otherwise, pass the array as-is for normal legend behavior + # Automatically collect handles and labels from spanned axes if not provided if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)): + # Auto-collect handles and labels if not explicitly provided + if handles is None and labels is None: + handles, labels = [], [] + for axi in ax: + h, l = axi.get_legend_handles_labels() + handles.extend(h) + labels.extend(l) try: ax_single = next(iter(ax)) except (TypeError, StopIteration): diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index 48a40a678..6b984a55e 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -483,3 +483,49 @@ def test_legend_multiple_sides_with_span(): assert leg_top is not None assert leg_right is not None assert leg_left is not None + + +def test_legend_auto_collect_handles_labels_with_span(): + """Test automatic collection of handles and labels from multiple axes with span parameters.""" + + fig, axs = uplt.subplots(nrows=2, ncols=2) + + # Create different plots in each subplot with labels + axs[0, 0].plot([0, 1], [0, 1], label="line1") + axs[0, 1].plot([0, 1], [1, 0], label="line2") + axs[1, 0].scatter([0.5], [0.5], label="point1") + axs[1, 1].scatter([0.5], [0.5], label="point2") + + # Test automatic collection with span parameter (no explicit handles/labels) + leg = fig.legend(ax=axs[0, :], span=(1, 2), loc="bottom") + + # Verify legend was created and contains all handles/labels from both axes + assert leg is not None + assert len(leg.get_texts()) == 2 # Should have 2 labels (line1, line2) + + # Test with rows parameter + leg2 = fig.legend(ax=axs[:, 0], rows=(1, 2), loc="right") + assert leg2 is not None + assert len(leg2.get_texts()) == 2 # Should have 2 labels (line1, point1) + + +def test_legend_explicit_handles_labels_override_auto_collection(): + """Test that explicit handles/labels override auto-collection.""" + + fig, axs = uplt.subplots(nrows=1, ncols=2) + + # Create plots with labels + (h1,) = axs[0].plot([0, 1], [0, 1], label="auto_label1") + (h2,) = axs[1].plot([0, 1], [1, 0], label="auto_label2") + + # Test with explicit handles/labels (should override auto-collection) + custom_handles = [h1] + custom_labels = ["custom_label"] + leg = fig.legend( + ax=axs, span=(1, 2), loc="bottom", handles=custom_handles, labels=custom_labels + ) + + # Verify legend uses explicit handles/labels, not auto-collected ones + assert leg is not None + assert len(leg.get_texts()) == 1 + assert leg.get_texts()[0].get_text() == "custom_label"