diff --git a/.github/workflows/build-ultraplot.yml b/.github/workflows/build-ultraplot.yml index ab19ab15c..7c3fb5252 100644 --- a/.github/workflows/build-ultraplot.yml +++ b/.github/workflows/build-ultraplot.yml @@ -22,7 +22,7 @@ jobs: run: shell: bash -el {0} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 @@ -58,7 +58,7 @@ jobs: run: shell: bash -el {0} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: mamba-org/setup-micromamba@v2.0.7 with: @@ -98,7 +98,7 @@ jobs: # Return the html output of the comparison even if failed - name: Upload comparison failures if: always() - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: failed-comparisons-${{ inputs.python-version }}-${{ inputs.matplotlib-version }}-${{ github.sha }} path: results/* diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 01d9c856f..2cc8b1b68 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -11,7 +11,7 @@ jobs: outputs: run: ${{ (github.event_name == 'push' && github.ref_name == 'main') && 'true' || steps.filter.outputs.python }} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: dorny/paths-filter@v3 id: filter with: @@ -28,7 +28,7 @@ jobs: python-versions: ${{ steps.set-versions.outputs.python-versions }} matplotlib-versions: ${{ steps.set-versions.outputs.matplotlib-versions }} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index 1eda57ccb..4128d4275 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -15,7 +15,7 @@ jobs: name: Build packages runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 @@ -54,7 +54,7 @@ jobs: shell: bash - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: dist-${{ github.sha }}-${{ github.run_id }}-${{ github.run_number }} path: dist/* @@ -73,7 +73,7 @@ jobs: contents: read steps: - name: Download artifacts - uses: actions/download-artifact@v5 + uses: actions/download-artifact@v7 with: name: dist-${{ github.sha }}-${{ github.run_id }}-${{ github.run_number }} path: dist @@ -105,7 +105,7 @@ jobs: contents: read steps: - name: Download artifacts - uses: actions/download-artifact@v5 + uses: actions/download-artifact@v7 with: name: dist-${{ github.sha }}-${{ github.run_id }}-${{ github.run_number }} path: dist diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1b19a691d..eae4604a9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,6 +11,6 @@ ci: repos: - repo: https://github.com/psf/black-pre-commit-mirror - rev: 25.9.0 + rev: 25.11.0 hooks: - id: black diff --git a/docs/2dplots.py b/docs/2dplots.py index 3f27b7b56..edc22e97c 100644 --- a/docs/2dplots.py +++ b/docs/2dplots.py @@ -77,9 +77,10 @@ # setting and the :ref:`user guide `). # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data state = np.random.RandomState(51423) x = y = np.array([-10, -5, 0, 5, 10]) @@ -110,9 +111,10 @@ axs[3].contourf(xedges, yedges, data) # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data cmap = "turku_r" state = np.random.RandomState(51423) @@ -184,9 +186,9 @@ # `~pint.UnitRegistry.setup_matplotlib` so that the axes become unit-aware. # %% -import xarray as xr import numpy as np import pandas as pd +import xarray as xr # DataArray state = np.random.RandomState(51423) @@ -261,13 +263,14 @@ # ``diverging=True``, ``cyclic=True``, or ``qualitative=True`` to any plotting # command. If the colormap type is not explicitly specified, `sequential` is # used with the default linear normalizer when data is strictly positive -# or negative, and `diverging` is used with the :ref:`diverging normalizer ` +# or negative, and `diverging` is used with the :ref:`diverging normalizer ` # when the data limits or colormap levels cross zero (see :ref:`below `). # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data N = 18 state = np.random.RandomState(51423) @@ -294,9 +297,10 @@ uplt.rc.reset() # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data N = 20 state = np.random.RandomState(51423) @@ -322,9 +326,10 @@ colorbar="b", ) -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data N = 20 state = np.random.RandomState(51423) @@ -347,7 +352,7 @@ # Special normalizers # ------------------- # -# UltraPlot includes two new :ref:`"continuous" normalizers `. The +# UltraPlot includes two new :ref:`"continuous" normalizers `. The # `~ultraplot.colors.SegmentedNorm` normalizer provides even color gradations with respect # to index for an arbitrary monotonically increasing or decreasing list of levels. This # is automatically applied if you pass unevenly spaced `levels` to a plotting command, @@ -372,9 +377,10 @@ # affect the interpretation of different datasets. # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data state = np.random.RandomState(51423) data = 11 ** (2 * state.rand(20, 20).cumsum(axis=0) / 7) @@ -395,9 +401,10 @@ ) ax.format(title=norm.title() + " normalizer") # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data state = np.random.RandomState(51423) data1 = (state.rand(20, 20) - 0.485).cumsum(axis=1).cumsum(axis=0) @@ -434,7 +441,7 @@ # commands (e.g., :func:`~ultraplot.axes.PlotAxes.contourf`, :func:`~ultraplot.axes.PlotAxes.pcolor`). # This is analogous to `matplotlib.colors.BoundaryNorm`, except # `~ultraplot.colors.DiscreteNorm` can be paired with arbitrary -# continuous normalizers specified by `norm` (see :ref:`above `). +# continuous normalizers specified by `norm` (see :ref:`above `). # Discrete color levels can help readers discern exact numeric values and # tend to reveal qualitative structure in the data. `~ultraplot.colors.DiscreteNorm` # also repairs the colormap end-colors by ensuring the following conditions are met: @@ -463,9 +470,10 @@ # the zero level (useful for single-color :func:`~ultraplot.axes.PlotAxes.contour` plots). # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data state = np.random.RandomState(51423) data = 10 + state.normal(0, 1, size=(33, 33)).cumsum(axis=0).cumsum(axis=1) @@ -485,9 +493,10 @@ axs[2].format(title="Imshow plot\ndiscrete=False (default)", yformatter="auto") # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + # Sample data state = np.random.RandomState(51423) data = (20 * (state.rand(20, 20) - 0.4).cumsum(axis=0).cumsum(axis=1)) % 360 @@ -547,7 +556,7 @@ # the 2D :class:`~ultraplot.axes.PlotAxes` commands will apply the diverging colormap # :rc:`cmap.diverging` (rather than :rc:`cmap.sequential`) and the diverging # normalizer `~ultraplot.colors.DivergingNorm` (rather than :class:`~matplotlib.colors.Normalize` -# -- see :ref:`above `) if the following conditions are met: +# -- see :ref:`above `) if the following conditions are met: # # #. If discrete levels are enabled (see :ref:`above `) and the # level list includes at least 2 negative and 2 positive values. @@ -560,9 +569,10 @@ # setting :rcraw:`cmap.autodiverging` to ``False``. # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + N = 20 state = np.random.RandomState(51423) data = N * 2 + (state.rand(N, N) - 0.45).cumsum(axis=0).cumsum(axis=1) * 10 @@ -605,9 +615,10 @@ # command documentation for details. # %% -import ultraplot as uplt -import pandas as pd import numpy as np +import pandas as pd + +import ultraplot as uplt # Sample data state = np.random.RandomState(51423) @@ -663,10 +674,11 @@ # `~ultraplot.axes.CartesianAxes`. # %% -import ultraplot as uplt import numpy as np import pandas as pd +import ultraplot as uplt + # Covariance data state = np.random.RandomState(51423) data = state.normal(size=(10, 10)).cumsum(axis=0) diff --git a/docs/colorbars_legends.py b/docs/colorbars_legends.py index 2b2d58bca..10a4099c8 100644 --- a/docs/colorbars_legends.py +++ b/docs/colorbars_legends.py @@ -78,9 +78,10 @@ # complex arrangements of subplots, colorbars, and legends. # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + state = np.random.RandomState(51423) fig = uplt.figure(share=False, refwidth=2.3) @@ -183,9 +184,10 @@ ) # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + N = 10 state = np.random.RandomState(51423) fig, axs = uplt.subplots( @@ -232,9 +234,10 @@ # and the tight layout padding can be controlled with the `pad` keyword. # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + state = np.random.RandomState(51423) fig, axs = uplt.subplots(ncols=3, nrows=3, refwidth=1.4) for ax in axs: @@ -254,9 +257,10 @@ fig.colorbar(m, label="colorbar with length <1", ticks=0.1, loc="r", length=0.7) # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + state = np.random.RandomState(51423) fig, axs = uplt.subplots( ncols=2, nrows=2, order="F", refwidth=1.7, wspace=2.5, share=False @@ -299,7 +303,7 @@ # will build the required `~matplotlib.cm.ScalarMappable` on-the-fly. Lists # of :class:`~matplotlib.artist.Artists`\ s are used when you use the `colorbar` # keyword with :ref:`1D commands ` like :func:`~ultraplot.axes.PlotAxes.plot`. -# * The associated :ref:`colormap normalizer ` can be specified with the +# * The associated :ref:`colormap normalizer ` can be specified with the # `vmin`, `vmax`, `norm`, and `norm_kw` keywords. The `~ultraplot.colors.DiscreteNorm` # levels can be specified with `values`, or UltraPlot will infer them from the # :class:`~matplotlib.artist.Artist` labels (non-numeric labels will be applied to @@ -332,9 +336,10 @@ # See :func:`~ultraplot.axes.Axes.colorbar` for details. # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + fig = uplt.figure(share=False, refwidth=2) # Colorbars from lines @@ -427,9 +432,10 @@ # (or use the `handle_kw` keyword). See `ultraplot.axes.Axes.legend` for details. # %% -import ultraplot as uplt import numpy as np +import ultraplot as uplt + uplt.rc.cycle = "538" fig, axs = uplt.subplots(ncols=2, span=False, share="labels", refwidth=2.3) labels = ["a", "bb", "ccc", "dddd", "eeeee"] diff --git a/docs/why.rst b/docs/why.rst index 392a5616d..74fc644c4 100644 --- a/docs/why.rst +++ b/docs/why.rst @@ -499,9 +499,9 @@ like :func:`~ultraplot.axes.PlotAxes.pcolor` and :func:`~ultraplot.axes.PlotAxes plots. This can be disabled by setting :rcraw:`cmap.discrete` to ``False`` or by passing ``discrete=False`` to :class:`~ultraplot.axes.PlotAxes` commands. * The :class:`~ultraplot.colors.DivergingNorm` normalizer is perfect for data with a - :ref:`natural midpoint ` and offers both "fair" and "unfair" scaling. + :ref:`natural midpoint ` and offers both "fair" and "unfair" scaling. The :class:`~ultraplot.colors.SegmentedNorm` normalizer can generate - uneven color gradations useful for :ref:`unusual data distributions `. + uneven color gradations useful for :ref:`unusual data distributions `. * The :func:`~ultraplot.axes.PlotAxes.heatmap` command invokes :func:`~ultraplot.axes.PlotAxes.pcolormesh` then applies an `equal axes apect ratio `__, @@ -882,7 +882,7 @@ Limitation ---------- Matplotlib :obj:`~matplotlib.rcParams` can be changed persistently by placing -`matplotlibrc` :ref:`ug_mplrc` files in the same directory as your python script. +ref:`matplotlibrc ` files in the same directory as your python script. But it can be difficult to design and store your own colormaps and color cycles for future use. It is also difficult to get matplotlib to use custom ``.ttf`` and ``.otf`` font files, which may be desirable when you are working on diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index b7e6631be..61c3ec02c 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -6,10 +6,11 @@ import copy import inspect import re +import sys import types -from numbers import Integral, Number -from typing import Union, Iterable, MutableMapping, Optional, Tuple from collections.abc import Iterable as IterableType +from numbers import Integral, Number +from typing import Iterable, MutableMapping, Optional, Tuple, Union try: # From python 3.12 @@ -34,12 +35,11 @@ from matplotlib import cbook from packaging import version -from .. import legend as plegend from .. import colors as pcolors from .. import constructor +from .. import legend as plegend from .. import ticker as pticker from ..config import rc -from ..internals import ic # noqa: F401 from ..internals import ( _kwargs_to_args, _not_none, @@ -51,6 +51,7 @@ _version_mpl, docstring, guides, + ic, # noqa: F401 labels, rcsetup, warnings, @@ -700,7 +701,52 @@ def __call__(self, ax, renderer): # noqa: U100 return bbox -class Axes(maxes.Axes): +class _ExternalModeMixin: + """ + Mixin providing explicit external-mode control and a context manager. + """ + + def set_external(self, value=True): + """ + Set explicit external-mode override for this axes. + + value: + - True: force external behavior (defer on-the-fly guides, etc.) + - False: force UltraPlot behavior + """ + if value not in (True, False): + raise ValueError("set_external expects True or False") + setattr(self, "_integration_external", value) + return self + + class _ExternalContext: + def __init__(self, ax, value=True): + self._ax = ax + self._value = True if value is None else value + self._prev = getattr(ax, "_integration_external", None) + + def __enter__(self): + self._ax._integration_external = self._value + return self._ax + + def __exit__(self, exc_type, exc, tb): + self._ax._integration_external = self._prev + + def external(self, value=True): + """ + Context manager toggling external mode during the block. + """ + return _ExternalModeMixin._ExternalContext(self, value) + + def _in_external_context(self): + """ + Return True if UltraPlot helper behaviors should be suppressed. + """ + mode = getattr(self, "_integration_external", None) + return mode is True + + +class Axes(_ExternalModeMixin, maxes.Axes): """ The lowest-level `~matplotlib.axes.Axes` subclass used by ultraplot. Implements basic universal features. @@ -822,6 +868,7 @@ def __init__(self, *args, **kwargs): self._panel_sharey_group = False # see _apply_auto_share self._panel_side = None self._tight_bbox = None # bounding boxes are saved + self._integration_external = None # explicit external-mode override (None=auto) self.xaxis.isDefault_minloc = True # ensure enabled at start (needed for dual) self.yaxis.isDefault_minloc = True @@ -1409,6 +1456,11 @@ def _add_legend( titlefontcolor=None, handle_kw=None, handler_map=None, + span: Optional[Union[int, Tuple[int, int]]] = None, + row: Optional[int] = None, + col: Optional[int] = None, + rows: Optional[Union[int, Tuple[int, int]]] = None, + cols: Optional[Union[int, Tuple[int, int]]] = None, **kwargs, ): """ @@ -1446,7 +1498,18 @@ def _add_legend( # Generate and prepare the legend axes if loc in ("fill", "left", "right", "top", "bottom"): - lax = self._add_guide_panel(loc, align, width=width, space=space, pad=pad) + lax = self._add_guide_panel( + loc, + align, + width=width, + space=space, + pad=pad, + span=span, + row=row, + col=col, + rows=rows, + cols=cols, + ) kwargs.setdefault("borderaxespad", 0) if not frameon: kwargs.setdefault("borderpad", 0) @@ -1739,6 +1802,7 @@ def _get_legend_handles(self, handler_map=None): handler_map_full = plegend.Legend.get_default_handler_map() handler_map_full = handler_map_full.copy() handler_map_full.update(handler_map or {}) + # Prefer synthetic tagging to exclude helper artists; see _ultraplot_synthetic flag on artists. for ax in axs: for attr in ("lines", "patches", "collections", "containers"): for handle in getattr(ax, attr, []): # guard against API changes @@ -1746,7 +1810,12 @@ def _get_legend_handles(self, handler_map=None): handler = plegend.Legend.get_legend_handler( handler_map_full, handle ) # noqa: E501 - if handler and label and label[0] != "_": + if ( + handler + and label + and label[0] != "_" + and not getattr(handle, "_ultraplot_synthetic", False) + ): handles.append(handle) return handles @@ -1897,11 +1966,17 @@ def _update_guide( if legend: align = legend_kw.pop("align", None) queue = legend_kw.pop("queue", queue_legend) - self.legend(objs, loc=legend, align=align, queue=queue, **legend_kw) + # Avoid immediate legend creation in external context + if not self._in_external_context(): + self.legend(objs, loc=legend, align=align, queue=queue, **legend_kw) if colorbar: align = colorbar_kw.pop("align", None) queue = colorbar_kw.pop("queue", queue_colorbar) - self.colorbar(objs, loc=colorbar, align=align, queue=queue, **colorbar_kw) + # Avoid immediate colorbar creation in external context + if not self._in_external_context(): + self.colorbar( + objs, loc=colorbar, align=align, queue=queue, **colorbar_kw + ) @staticmethod def _parse_frame(guide, fancybox=None, shadow=None, **kwargs): @@ -2423,6 +2498,8 @@ def _legend_label(*objs): # noqa: E301 labs = [] for obj in objs: if hasattr(obj, "get_label"): # e.g. silent list + if getattr(obj, "_ultraplot_synthetic", False): + continue lab = obj.get_label() if lab is not None and not str(lab).startswith("_"): labs.append(lab) @@ -2453,10 +2530,15 @@ def _legend_tuple(*objs): # noqa: E306 if hs: handles.extend(hs) elif obj: # fallback to first element - handles.append(obj[0]) + # Skip synthetic helpers and fill_between collections + if not getattr(obj[0], "_ultraplot_synthetic", False): + handles.append(obj[0]) else: handles.append(obj) elif hasattr(obj, "get_label"): + # Skip synthetic helpers and fill_between collections + if getattr(obj, "_ultraplot_synthetic", False): + continue handles.append(obj) else: warnings._warn_ultraplot(f"Ignoring invalid legend handle {obj!r}.") @@ -2626,9 +2708,19 @@ def _unshare(self, *, which: str): setattr(sibling, f"_share{which}", None) this_ax = getattr(self, f"{which}axis") sib_ax = getattr(sibling, f"{which}axis") - # Reset formatters - this_ax.major = copy.deepcopy(this_ax.major) - this_ax.minor = copy.deepcopy(this_ax.minor) + # Reset formatters by creating new Ticker objects. + # A deepcopy can trigger redraws. + new_major = maxis.Ticker() + if this_ax.major: + new_major.locator = copy.copy(this_ax.major.locator) + new_major.formatter = copy.copy(this_ax.major.formatter) + this_ax.major = new_major + + new_minor = maxis.Ticker() + if this_ax.minor: + new_minor.locator = copy.copy(this_ax.minor.locator) + new_minor.formatter = copy.copy(this_ax.minor.formatter) + this_ax.minor = new_minor def _sharex_setup(self, sharex, **kwargs): """ @@ -2894,6 +2986,8 @@ def _update_title(self, loc, title=None, **kwargs): kw["text"] = title[self.number - 1] else: raise ValueError(f"Invalid title {title!r}. Must be string(s).") + if any(key in kwargs for key in ("size", "fontsize")): + self._title_dict[loc]._ultraplot_manual_size = True kw.update(kwargs) self._title_dict[loc].update(kw) @@ -2906,6 +3000,8 @@ def _update_title_position(self, renderer): # NOTE: Critical to do this every time in case padding changes or # we added or removed an a-b-c label in the same position as a title width, height = self._get_size_inches() + if width <= 0 or height <= 0: + return x_pad = self._title_pad / (72 * width) y_pad = self._title_pad / (72 * height) for loc, obj in self._title_dict.items(): @@ -2950,7 +3046,9 @@ def _update_title_position(self, renderer): # Offset title away from a-b-c label atext, ttext = aobj.get_text(), tobj.get_text() awidth = twidth = 0 - pad = (abcpad / 72) / self._get_size_inches()[0] + width_inches = self._get_size_inches()[0] + pad = (abcpad / 72) / width_inches + abc_pad = (self._abc_pad / 72) / width_inches ha = aobj.get_ha() # Get dimensions of non-empty elements @@ -2967,6 +3065,36 @@ def _update_title_position(self, renderer): .width ) + # Shrink the title font if both texts share a location and would overflow + if ( + atext + and ttext + and self._abc_loc == self._title_loc + and twidth > 0 + and not getattr(tobj, "_ultraplot_manual_size", False) + ): + scale = 1 + base_x = tobj.get_position()[0] + if ha == "left": + available = 1 - (base_x + awidth + pad) + if available < twidth and available > 0: + scale = available / twidth + elif ha == "right": + available = base_x + abc_pad - pad - awidth + if available < twidth and available > 0: + scale = available / twidth + elif ha == "center": + # Conservative fit for centered titles sharing the abc location + left_room = base_x - 0.5 * (awidth + pad) + right_room = 1 - (base_x + 0.5 * (awidth + pad)) + max_room = min(left_room, right_room) + if max_room < twidth / 2 and max_room > 0: + scale = (2 * max_room) / twidth + + if scale < 1: + tobj.set_fontsize(tobj.get_fontsize() * scale) + twidth *= scale + # Calculate offsets based on alignment and content aoffset = toffset = 0 if atext and ttext: @@ -2988,6 +3116,44 @@ def _update_title_position(self, renderer): if ttext: tobj.set_x(tobj.get_position()[0] + toffset) + # Shrink title if it overlaps the abc label at a different location + if ( + atext + and self._abc_loc != self._title_loc + and not getattr( + self._title_dict[self._title_loc], "_ultraplot_manual_size", False + ) + ): + title_obj = self._title_dict[self._title_loc] + title_text = title_obj.get_text() + if title_text: + abc_bbox = aobj.get_window_extent(renderer).transformed( + self.transAxes.inverted() + ) + title_bbox = title_obj.get_window_extent(renderer).transformed( + self.transAxes.inverted() + ) + ax0, ax1 = abc_bbox.x0, abc_bbox.x1 + tx0, tx1 = title_bbox.x0, title_bbox.x1 + if tx0 < ax1 + pad and tx1 > ax0 - pad: + base_x = title_obj.get_position()[0] + ha = title_obj.get_ha() + max_width = 0 + if ha == "left": + if base_x <= ax0 - pad: + max_width = (ax0 - pad) - base_x + elif ha == "right": + if base_x >= ax1 + pad: + max_width = base_x - (ax1 + pad) + elif ha == "center": + if base_x >= ax1 + pad: + max_width = 2 * (base_x - (ax1 + pad)) + elif base_x <= ax0 - pad: + max_width = 2 * ((ax0 - pad) - base_x) + if 0 < max_width < title_bbox.width: + scale = max_width / title_bbox.width + title_obj.set_fontsize(title_obj.get_fontsize() * scale) + def _update_super_title(self, suptitle=None, **kwargs): """ Update the figure super title. @@ -3056,12 +3222,39 @@ def _update_share_labels(self, axes=None, target="x"): target : {'x', 'y'}, optional Which axis labels to share ('x' for x-axis, 'y' for y-axis) """ - if not axes: + if axes is False: + self.figure._clear_share_label_groups([self], target=target) + return + if axes is None or not len(list(axes)): return # Convert indices to actual axes objects if isinstance(axes[0], int): axes = [self.figure.axes[i] for i in axes] + axes = [ + ax._get_topmost_axes() if hasattr(ax, "_get_topmost_axes") else ax + for ax in axes + if ax is not None + ] + if len(axes) < 2: + return + # Preserve order while de-duplicating + seen = set() + unique = [] + for ax in axes: + ax_id = id(ax) + if ax_id in seen: + continue + seen.add(ax_id) + unique.append(ax) + axes = unique + if len(axes) < 2: + return + + # Prefer figure-managed spanning labels when possible + if all(isinstance(ax, maxes.SubplotBase) for ax in axes): + self.figure._register_share_label_group(axes, target=target, source=self) + return # Get the center position of the axes group if box := self.get_center_of_axes(axes): @@ -3322,6 +3515,7 @@ def _label_key(self, side: str) -> str: labelright/labelleft respectively. """ from packaging import version + from ..internals import _version_mpl # TODO: internal deprecation warning when we drop 3.9, we need to remove this @@ -3483,7 +3677,19 @@ def colorbar(self, mappable, values=None, loc=None, location=None, **kwargs): @docstring._concatenate_inherited # also obfuscates params @docstring._snippet_manager - def legend(self, handles=None, labels=None, loc=None, location=None, **kwargs): + def legend( + self, + handles=None, + labels=None, + loc=None, + location=None, + span: Optional[Union[int, Tuple[int, int]]] = None, + row: Optional[int] = None, + col: Optional[int] = None, + rows: Optional[Union[int, Tuple[int, int]]] = None, + cols: Optional[Union[int, Tuple[int, int]]] = None, + **kwargs, + ): """ Add an inset legend or outer legend along the edge of the axes. @@ -3545,7 +3751,18 @@ def legend(self, handles=None, labels=None, loc=None, location=None, **kwargs): if queue: self._register_guide("legend", (handles, labels), (loc, align), **kwargs) else: - return self._add_legend(handles, labels, loc=loc, align=align, **kwargs) + return self._add_legend( + handles, + labels, + loc=loc, + align=align, + span=span, + row=row, + col=col, + rows=rows, + cols=cols, + **kwargs, + ) @docstring._concatenate_inherited @docstring._snippet_manager diff --git a/ultraplot/axes/cartesian.py b/ultraplot/axes/cartesian.py index 46685b5df..c115dc45f 100644 --- a/ultraplot/axes/cartesian.py +++ b/ultraplot/axes/cartesian.py @@ -5,22 +5,27 @@ import copy import inspect +import matplotlib.axis as maxis import matplotlib.dates as mdates import matplotlib.ticker as mticker import numpy as np - from packaging import version from .. import constructor from .. import scale as pscale from .. import ticker as pticker from ..config import rc -from ..internals import ic # noqa: F401 -from ..internals import _not_none, _pop_rc, _version_mpl, docstring, labels, warnings -from . import plot, shared -import matplotlib.axis as maxis - +from ..internals import ( + _not_none, + _pop_rc, + _version_mpl, + docstring, + ic, # noqa: F401 + labels, + warnings, +) from ..utils import units +from . import plot, shared __all__ = ["CartesianAxes"] @@ -432,9 +437,14 @@ def _apply_axis_sharing_for_axis( # Handle axis label sharing (level > 0) if level > 0: - shared_axis_obj = getattr(shared_axis, f"{axis_name}axis") - labels._transfer_label(axis.label, shared_axis_obj.label) - axis.label.set_visible(False) + if self.figure._is_share_label_group_member(self, axis_name): + pass + elif self.figure._is_share_label_group_member(shared_axis, axis_name): + axis.label.set_visible(False) + else: + shared_axis_obj = getattr(shared_axis, f"{axis_name}axis") + labels._transfer_label(axis.label, shared_axis_obj.label) + axis.label.set_visible(False) # Handle tick label sharing (level > 2) if level > 2: @@ -779,6 +789,26 @@ def _sharey_setup(self, sharey, *, labels=True, limits=True): if level > 1 and limits: self._sharey_limits(sharey) + def _apply_log_formatter_on_scale(self, s): + """ + Enforce log formatter when log scale is set and rc is enabled. + """ + if not rc.find("formatter.log", context=True): + return + if getattr(self, f"get_{s}scale")() != "log": + return + self._update_formatter(s, "log") + + def set_xscale(self, value, **kwargs): + result = super().set_xscale(value, **kwargs) + self._apply_log_formatter_on_scale("x") + return result + + def set_yscale(self, value, **kwargs): + result = super().set_yscale(value, **kwargs) + self._apply_log_formatter_on_scale("y") + return result + def _update_formatter( self, s, @@ -1389,6 +1419,7 @@ def format( # WARNING: Changing axis scale also changes default locators # and formatters, and restricts possible range of axis limits, # so critical to do it first. + scale_requested = scale is not None if scale is not None: scale = constructor.Scale(scale, **scale_kw) getattr(self, f"set_{s}scale")(scale) @@ -1480,10 +1511,40 @@ def format( tickrange=tickrange, wraprange=wraprange, ) + if ( + scale_requested + and formatter is None + and not formatter_kw + and tickrange is None + and wraprange is None + and rc.find("formatter.log", context=True) + and getattr(self, f"get_{s}scale")() == "log" + ): + self._update_formatter(s, "log") # Ensure ticks are within axis bounds self._fix_ticks(s, fixticks=fixticks) + if rc.find("formatter.log", context=True): + if ( + xscale is not None + and xformatter is None + and not xformatter_kw + and xtickrange is None + and xwraprange is None + and self.get_xscale() == "log" + ): + self._update_formatter("x", "log") + if ( + yscale is not None + and yformatter is None + and not yformatter_kw + and ytickrange is None + and ywraprange is None + and self.get_yscale() == "log" + ): + self._update_formatter("y", "log") + # Parent format method if aspect is not None: self.set_aspect(aspect) diff --git a/ultraplot/axes/geo.py b/ultraplot/axes/geo.py index 15c5f9a43..267acb206 100644 --- a/ultraplot/axes/geo.py +++ b/ultraplot/axes/geo.py @@ -24,19 +24,19 @@ from .. import constructor from .. import proj as pproj +from .. import ticker as pticker from ..config import rc -from ..internals import ic # noqa: F401 from ..internals import ( _not_none, _pop_rc, _version_cartopy, docstring, + ic, # noqa: F401 + labels, warnings, ) -from .. import ticker as pticker from ..utils import units -from . import plot -from . import shared +from . import plot, shared try: import cartopy.crs as ccrs @@ -148,6 +148,15 @@ *For cartopy axes only.* Whether to rotate non-inline gridline labels so that they automatically follow the map boundary curvature. +labelrotation : float, optional + The rotation angle in degrees for both longitude and latitude tick labels. + Use `lonlabelrotation` and `latlabelrotation` to set them separately. +lonlabelrotation : float, optional + The rotation angle in degrees for longitude tick labels. + Works for both cartopy and basemap backends. +latlabelrotation : float, optional + The rotation angle in degrees for latitude tick labels. + Works for both cartopy and basemap backends. labelpad : unit-spec, default: :rc:`grid.labelpad` *For cartopy axes only.* The padding between non-inline gridline labels and the map boundary. @@ -653,6 +662,24 @@ def _apply_axis_sharing(self): the leftmost and bottommost is the *figure* sharing level. """ + # Share axis labels + if self._sharex and self.figure._sharex >= 1: + if self.figure._is_share_label_group_member(self, "x"): + pass + elif self.figure._is_share_label_group_member(self._sharex, "x"): + self.xaxis.label.set_visible(False) + else: + labels._transfer_label(self.xaxis.label, self._sharex.xaxis.label) + self.xaxis.label.set_visible(False) + if self._sharey and self.figure._sharey >= 1: + if self.figure._is_share_label_group_member(self, "y"): + pass + elif self.figure._is_share_label_group_member(self._sharey, "y"): + self.yaxis.label.set_visible(False) + else: + labels._transfer_label(self.yaxis.label, self._sharey.yaxis.label) + self.yaxis.label.set_visible(False) + # Share interval x if self._sharex and self.figure._sharex >= 2: self._lonaxis.set_view_interval(*self._sharex._lonaxis.get_view_interval()) @@ -663,6 +690,142 @@ def _apply_axis_sharing(self): self._lataxis.set_view_interval(*self._sharey._lataxis.get_view_interval()) self._lataxis.set_minor_locator(self._sharey._lataxis.get_minor_locator()) + def _apply_aspect_and_adjust_panels(self, *, tol=1e-9): + """ + Apply aspect and then align panels to the adjusted axes box. + + Notes + ----- + Cartopy and basemap use different tolerances when detecting whether + apply_aspect() actually changed the axes position. + """ + self.apply_aspect() + self._adjust_panel_positions(tol=tol) + + def _adjust_panel_positions(self, *, tol=1e-9): + """ + Adjust panel positions to align with the aspect-constrained main axes. + After apply_aspect() shrinks the main axes, panels should flank the actual + map boundaries rather than the full gridspec allocation. + """ + if not getattr(self, "_panel_dict", None): + return # no panels to adjust + + # Current (aspect-adjusted) position + main_pos = getattr(self, "_position", None) or self.get_position() + + # Subplot-spec position before apply_aspect(). This is the true "gridspec slot" + # and remains well-defined even if we temporarily modify axes positions. + try: + ss = self.get_subplotspec() + original_pos = ss.get_position(self.figure) if ss is not None else None + except Exception: + original_pos = None + if original_pos is None: + original_pos = getattr( + self, "_originalPosition", None + ) or self.get_position(original=True) + + # Only adjust if apply_aspect() actually changed the position (tolerance + # avoids float churn that can trigger unnecessary layout updates). + if ( + abs(main_pos.x0 - original_pos.x0) <= tol + and abs(main_pos.y0 - original_pos.y0) <= tol + and abs(main_pos.width - original_pos.width) <= tol + and abs(main_pos.height - original_pos.height) <= tol + ): + return + + # Map original -> adjusted coordinates (only along the "long" axis of the + # panel, so span overrides across subplot rows/cols are preserved). + sx = main_pos.width / original_pos.width if original_pos.width else 1.0 + sy = main_pos.height / original_pos.height if original_pos.height else 1.0 + ox0, oy0 = original_pos.x0, original_pos.y0 + ox1, oy1 = ( + original_pos.x0 + original_pos.width, + original_pos.y0 + original_pos.height, + ) + mx0, my0 = main_pos.x0, main_pos.y0 + + for side, panels in self._panel_dict.items(): + for panel in panels: + # Use the panel subplot-spec box as the baseline (not its current + # original position) to avoid accumulated adjustments. + try: + ss = panel.get_subplotspec() + panel_pos = ( + ss.get_position(panel.figure) if ss is not None else None + ) + except Exception: + panel_pos = None + if panel_pos is None: + panel_pos = panel.get_position(original=True) + px0, py0 = panel_pos.x0, panel_pos.y0 + px1, py1 = ( + panel_pos.x0 + panel_pos.width, + panel_pos.y0 + panel_pos.height, + ) + + # Use _set_position when available to avoid layoutbox side effects + # from public set_position() on newer matplotlib versions. + setter = getattr(panel, "_set_position", panel.set_position) + + if side == "left": + # Calculate original gap between panel and main axes + gap = original_pos.x0 - (panel_pos.x0 + panel_pos.width) + # Position panel to the left of the adjusted main axes + new_x0 = main_pos.x0 - panel_pos.width - gap + if py0 <= oy0 + tol and py1 >= oy1 - tol: + new_y0, new_h = my0, main_pos.height + else: + new_y0 = my0 + (panel_pos.y0 - oy0) * sy + new_h = panel_pos.height * sy + new_pos = [new_x0, new_y0, panel_pos.width, new_h] + elif side == "right": + # Calculate original gap + gap = panel_pos.x0 - (original_pos.x0 + original_pos.width) + # Position panel to the right of the adjusted main axes + new_x0 = main_pos.x0 + main_pos.width + gap + if py0 <= oy0 + tol and py1 >= oy1 - tol: + new_y0, new_h = my0, main_pos.height + else: + new_y0 = my0 + (panel_pos.y0 - oy0) * sy + new_h = panel_pos.height * sy + new_pos = [new_x0, new_y0, panel_pos.width, new_h] + elif side == "top": + # Calculate original gap + gap = panel_pos.y0 - (original_pos.y0 + original_pos.height) + # Position panel above the adjusted main axes + new_y0 = main_pos.y0 + main_pos.height + gap + if px0 <= ox0 + tol and px1 >= ox1 - tol: + new_x0, new_w = mx0, main_pos.width + else: + new_x0 = mx0 + (panel_pos.x0 - ox0) * sx + new_w = panel_pos.width * sx + new_pos = [new_x0, new_y0, new_w, panel_pos.height] + elif side == "bottom": + # Calculate original gap + gap = original_pos.y0 - (panel_pos.y0 + panel_pos.height) + # Position panel below the adjusted main axes + new_y0 = main_pos.y0 - panel_pos.height - gap + if px0 <= ox0 + tol and px1 >= ox1 - tol: + new_x0, new_w = mx0, main_pos.width + else: + new_x0 = mx0 + (panel_pos.x0 - ox0) * sx + new_w = panel_pos.width * sx + new_pos = [new_x0, new_y0, new_w, panel_pos.height] + else: + # Unknown side, skip adjustment + continue + + # Panels typically have aspect='auto', which causes matplotlib to + # reset their *active* position to their *original* position inside + # apply_aspect()/get_position(). Update both so the change persists. + try: + setter(new_pos, which="both") + except TypeError: # older matplotlib + setter(new_pos) + def _get_gridliner_labels( self, bottom=None, @@ -850,6 +1013,9 @@ def format( latlabels=None, lonlabels=None, rotatelabels=None, + labelrotation=None, + lonlabelrotation=None, + latlabelrotation=None, loninline=None, latinline=None, inlinelabels=None, @@ -996,6 +1162,8 @@ def format( rotatelabels = _not_none( rotatelabels, rc.find("grid.rotatelabels", context=True) ) # noqa: E501 + lonlabelrotation = _not_none(lonlabelrotation, labelrotation) + latlabelrotation = _not_none(latlabelrotation, labelrotation) labelpad = _not_none(labelpad, rc.find("grid.labelpad", context=True)) dms = _not_none(dms, rc.find("grid.dmslabels", context=True)) nsteps = _not_none(nsteps, rc.find("grid.nsteps", context=True)) @@ -1028,6 +1196,8 @@ def format( loninline=loninline, latinline=latinline, rotatelabels=rotatelabels, + lonlabelrotation=lonlabelrotation, + latlabelrotation=latlabelrotation, labelpad=labelpad, nsteps=nsteps, ) @@ -1281,6 +1451,7 @@ class _CartopyAxes(GeoAxes, _GeoAxes): _name = "cartopy" _name_aliases = ("geo", "geographic") # default 'geographic' axes _proj_class = Projection + _PANEL_TOL = 1e-9 _proj_north = ( pproj.NorthPolarStereo, pproj.NorthPolarGnomonic, @@ -1544,7 +1715,8 @@ def _update_extent(self, lonlim=None, latlim=None, boundinglat=None): # WARNING: The set_extent method tries to set a *rectangle* between the *4* # (x, y) coordinate pairs (each corner), so something like (-180, 180, -90, 90) # will result in *line*, causing error! We correct this here. - eps = 1e-10 # bug with full -180, 180 range when lon_0 != 0 + eps_small = 1e-10 # bug with full -180, 180 range when lon_0 != 0 + eps_label = 0.5 # larger epsilon to ensure boundary labels are included lon0 = self._get_lon0() proj = type(self.projection).__name__ north = isinstance(self.projection, self._proj_north) @@ -1560,7 +1732,12 @@ def _update_extent(self, lonlim=None, latlim=None, boundinglat=None): if boundinglat is not None and boundinglat != self._boundinglat: lat0 = 90 if north else -90 lon0 = self._get_lon0() - extent = [lon0 - 180 + eps, lon0 + 180 - eps, boundinglat, lat0] + extent = [ + lon0 - 180 + eps_small, + lon0 + 180 - eps_small, + boundinglat, + lat0, + ] self.set_extent(extent, crs=ccrs.PlateCarree()) self._boundinglat = boundinglat @@ -1577,12 +1754,18 @@ def _update_extent(self, lonlim=None, latlim=None, boundinglat=None): lonlim[0] = lon0 - 180 if lonlim[1] is None: lonlim[1] = lon0 + 180 - lonlim[0] += eps + # Expand limits slightly to ensure boundary labels are included + # NOTE: We expand symmetrically (subtract from min, add to max) rather + # than just shifting to avoid excluding boundary gridlines + lonlim[0] -= eps_label + lonlim[1] += eps_label latlim = list(latlim) if latlim[0] is None: latlim[0] = -90 if latlim[1] is None: latlim[1] = 90 + latlim[0] -= eps_label + latlim[1] += eps_label extent = lonlim + latlim self.set_extent(extent, crs=ccrs.PlateCarree()) @@ -1663,9 +1846,9 @@ def _update_gridlines( # NOTE: This will re-apply existing gridline locations if unchanged. if nsteps is not None: gl.n_steps = nsteps - latmax = self._lataxis.get_latmax() - if _version_cartopy >= "0.19": - gl.ylim = (-latmax, latmax) + # Set xlim and ylim for cartopy >= 0.19 to control which labels are displayed + # NOTE: Don't set xlim/ylim here - let cartopy determine from the axes extent + # The extent expansion in _update_extent should be sufficient to include boundary labels longrid = rc._get_gridline_bool(longrid, axis="x", which=which, native=False) if longrid is not None: gl.xlines = longrid @@ -1690,6 +1873,8 @@ def _update_major_gridlines( latinline=None, labelpad=None, rotatelabels=None, + lonlabelrotation=None, + latlabelrotation=None, nsteps=None, ): """ @@ -1729,6 +1914,10 @@ def _update_major_gridlines( gl.y_inline = bool(latinline) if rotatelabels is not None: gl.rotate_labels = bool(rotatelabels) # ignored in cartopy < 0.18 + if lonlabelrotation is not None: + gl.xlabel_style["rotation"] = lonlabelrotation + if latlabelrotation is not None: + gl.ylabel_style["rotation"] = latlabelrotation if latinline is not None or loninline is not None: lon, lat = loninline, latinline b = True if lon and lat else "x" if lon else "y" if lat else None @@ -1797,6 +1986,18 @@ def get_extent(self, crs=None): extent[:2] = [lon0 - 180, lon0 + 180] return extent + @override + def draw(self, renderer=None, *args, **kwargs): + """ + Override draw to adjust panel positions for cartopy axes. + + Cartopy's apply_aspect() can shrink the main axes to enforce the projection + aspect ratio. Panels occupy separate gridspec slots, so we reposition them + after the main axes has applied its aspect but before the panel axes are drawn. + """ + super().draw(renderer, *args, **kwargs) + self._adjust_panel_positions(tol=self._PANEL_TOL) + def get_tightbbox(self, renderer, *args, **kwargs): # Perform extra post-processing steps # For now this just draws the gridliners @@ -1814,8 +2015,9 @@ def get_tightbbox(self, renderer, *args, **kwargs): self.outline_patch._path = clipped_path self.background_patch._path = clipped_path - # Apply aspect - self.apply_aspect() + # Apply aspect, then ensure panels follow the aspect-constrained box. + self._apply_aspect_and_adjust_panels(tol=self._PANEL_TOL) + if _version_cartopy >= "0.23": gridliners = [ a for a in self.artists if isinstance(a, cgridliner.Gridliner) @@ -1891,6 +2093,7 @@ class _BasemapAxes(GeoAxes): "sinu", "vandg", ) + _PANEL_TOL = 1e-6 def __init__(self, *args, map_projection=None, **kwargs): """ @@ -1941,6 +2144,29 @@ def __init__(self, *args, map_projection=None, **kwargs): self._turnoff_tick_labels(self._lonlines_major) self._turnoff_tick_labels(self._latlines_major) + def get_tightbbox(self, renderer, *args, **kwargs): + """ + Get tight bounding box, adjusting panel positions after aspect is applied. + + This ensures panels are properly aligned when saving figures, as apply_aspect() + may be called during the rendering process. + """ + # Apply aspect ratio, then ensure panels follow the aspect-constrained box. + self._apply_aspect_and_adjust_panels(tol=self._PANEL_TOL) + + return super().get_tightbbox(renderer, *args, **kwargs) + + @override + def draw(self, renderer=None, *args, **kwargs): + """ + Override draw to adjust panel positions for basemap axes. + + Basemap projections also rely on apply_aspect() and can shrink the main axes; + panels must be repositioned to flank the visible map boundaries. + """ + super().draw(renderer, *args, **kwargs) + self._adjust_panel_positions(tol=self._PANEL_TOL) + def _turnoff_tick_labels(self, locator: mticker.Formatter): """ For GeoAxes with are dealing with a duality. Basemap axes behave differently than Cartopy axes and vice versa. UltraPlot abstracts away from these by providing GeoAxes. For basemap axes we need to turn off the tick labels as they will be handles by GeoAxis @@ -2108,17 +2334,20 @@ def _update_gridlines( latgrid=None, lonarray=None, latarray=None, + lonlabelrotation=None, + latlabelrotation=None, ): """ Apply changes to the basemap axes. """ latmax = self._lataxis.get_latmax() - for axis, name, grid, array, method in zip( + for axis, name, grid, array, method, rotation in zip( ("x", "y"), ("lon", "lat"), (longrid, latgrid), (lonarray, latarray), ("drawmeridians", "drawparallels"), + (lonlabelrotation, latlabelrotation), ): # Correct lonarray and latarray label toggles by changing from lrbt to lrtb. # Then update the cahced toggle array. This lets us change gridline locs @@ -2173,6 +2402,9 @@ def _update_gridlines( for obj in self._iter_gridlines(objs): if isinstance(obj, mtext.Text): obj.update(kwtext) + # Apply rotation if specified + if rotation is not None: + obj.set_rotation(rotation) else: obj.update(kwlines) @@ -2191,6 +2423,8 @@ def _update_major_gridlines( loninline=None, latinline=None, rotatelabels=None, + lonlabelrotation=None, + latlabelrotation=None, labelpad=None, nsteps=None, ): @@ -2204,6 +2438,8 @@ def _update_major_gridlines( latgrid=latgrid, lonarray=lonarray, latarray=latarray, + lonlabelrotation=lonlabelrotation, + latlabelrotation=latlabelrotation, ) sides = {} for side, lonon, laton in zip( @@ -2226,6 +2462,8 @@ def _update_minor_gridlines(self, longrid=None, latgrid=None, nsteps=None): latgrid=latgrid, lonarray=array, latarray=array, + lonlabelrotation=None, + latlabelrotation=None, ) # Set isDefault_majloc, etc. to True for both axes # NOTE: This cannot be done inside _update_gridlines or minor gridlines diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index dc7ff4f27..526e6ffac 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -8,50 +8,46 @@ import itertools import re import sys +from collections.abc import Callable, Iterable from numbers import Integral, Number +from typing import Any, Iterable, Optional, Union -from typing import Any, Union, Iterable, Optional - -from collections.abc import Callable -from collections.abc import Iterable - -from ..utils import units +import matplotlib as mpl import matplotlib.artist as martist import matplotlib.axes as maxes import matplotlib.cbook as cbook import matplotlib.cm as mcm import matplotlib.collections as mcollections import matplotlib.colors as mcolors -import matplotlib.contour as mcontour import matplotlib.container as mcontainer +import matplotlib.contour as mcontour import matplotlib.image as mimage import matplotlib.lines as mlines import matplotlib.patches as mpatches -import matplotlib.ticker as mticker import matplotlib.pyplot as mplt -import matplotlib as mpl -from packaging import version +import matplotlib.ticker as mticker import numpy as np -from typing import Optional, Union, Any import numpy.ma as ma +from packaging import version from .. import colors as pcolors from .. import constructor, utils from ..config import rc -from ..internals import ic # noqa: F401 from ..internals import ( _get_aliases, _not_none, _pop_kwargs, _pop_params, _pop_props, + _version_mpl, context, docstring, guides, + ic, # noqa: F401 inputs, warnings, - _version_mpl, ) +from ..utils import units from . import base try: @@ -1512,25 +1508,6 @@ def _parse_vert( return kwargs -def _inside_seaborn_call(): - """ - Try to detect `seaborn` calls to `scatter` and `bar` and then automatically - apply `absolute_size` and `absolute_width`. - """ - frame = sys._getframe() - absolute_names = ( - "seaborn.distributions", - "seaborn.categorical", - "seaborn.relational", - "seaborn.regression", - ) - while frame is not None: - if frame.f_globals.get("__name__", "") in absolute_names: - return True - frame = frame.f_back - return False - - class PlotAxes(base.Axes): """ The second lowest-level `~matplotlib.axes.Axes` subclass used by ultraplot. @@ -1566,7 +1543,7 @@ def curved_quiver( The implementation of this function is based on the `dfm_tools` repository. Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py """ - from .plot_types.curved_quiver import CurvedQuiverSolver, CurvedQuiverSet + from .plot_types.curved_quiver import CurvedQuiverSet, CurvedQuiverSolver # Parse inputs arrowsize = _not_none(arrowsize, rc["curved_quiver.arrowsize"]) @@ -2087,7 +2064,7 @@ def _add_error_bars( ): # ugly kludge to check for shading if all(_ is None for _ in (bardata, barstds, barpctiles)): barstds, barpctiles = default_barstds, default_barpctiles - if all(_ is None for _ in (boxdata, boxstds, boxpctile)): + if all(_ is None for _ in (boxdata, boxstds, boxpctiles)): boxstds, boxpctiles = default_boxstds, default_boxpctiles showbars = any( _ is not None and _ is not False for _ in (barstds, barpctiles, bardata) @@ -2237,6 +2214,7 @@ def _add_error_shading( # Draw dark and light shading from distributions or explicit errdata eobjs = [] fill = self.fill_between if vert else self.fill_betweenx + if drawfade: edata, label = inputs._dist_range( y, @@ -2250,7 +2228,29 @@ def _add_error_shading( absolute=True, ) if edata is not None: - eobj = fill(x, *edata, label=label, **fadeprops) + synthetic = False + eff_label = label + if self._in_external_context() and ( + eff_label is None or str(eff_label) in ("y", "ymin", "ymax") + ): + eff_label = "_ultraplot_fade" + synthetic = True + + eobj = fill(x, *edata, label=eff_label, **fadeprops) + if synthetic: + try: + setattr(eobj, "_ultraplot_synthetic", True) + if hasattr(eobj, "set_label"): + eobj.set_label("_ultraplot_fade") + except Exception: + pass + for _obj in guides._iter_iterables(eobj): + try: + setattr(_obj, "_ultraplot_synthetic", True) + if hasattr(_obj, "set_label"): + _obj.set_label("_ultraplot_fade") + except Exception: + pass eobjs.append(eobj) if drawshade: edata, label = inputs._dist_range( @@ -2265,7 +2265,29 @@ def _add_error_shading( absolute=True, ) if edata is not None: - eobj = fill(x, *edata, label=label, **shadeprops) + synthetic = False + eff_label = label + if self._in_external_context() and ( + eff_label is None or str(eff_label) in ("y", "ymin", "ymax") + ): + eff_label = "_ultraplot_shade" + synthetic = True + + eobj = fill(x, *edata, label=eff_label, **shadeprops) + if synthetic: + try: + setattr(eobj, "_ultraplot_synthetic", True) + if hasattr(eobj, "set_label"): + eobj.set_label("_ultraplot_shade") + except Exception: + pass + for _obj in guides._iter_iterables(eobj): + try: + setattr(_obj, "_ultraplot_synthetic", True) + if hasattr(_obj, "set_label"): + _obj.set_label("_ultraplot_shade") + except Exception: + pass eobjs.append(eobj) kwargs["distribution"] = distribution @@ -2547,6 +2569,19 @@ def _parse_1d_format( colorbar_kw_labels = _not_none( kwargs.get("colorbar_kw", {}).pop("values", None), ) + # Track whether the user explicitly provided labels/values so we can + # preserve them even when autolabels is disabled. + _user_labels_explicit = any( + v is not None + for v in ( + label, + labels, + value, + values, + legend_kw_labels, + colorbar_kw_labels, + ) + ) labels = _not_none( label=label, @@ -2586,9 +2621,9 @@ def _parse_1d_format( # Apply the labels or values if labels is not None: - if autovalues: + if autovalues or (value is not None or values is not None): kwargs["values"] = inputs._to_numpy_array(labels) - elif autolabels: + elif autolabels or _user_labels_explicit: kwargs["labels"] = inputs._to_numpy_array(labels) # Apply title for legend or colorbar that uses the labels or values @@ -3054,7 +3089,9 @@ def _parse_cycle( resolved_cycle = constructor.Cycle(cycle, **cycle_kw) case str() if cycle.lower() == "none": resolved_cycle = None - case str() | int() | Iterable(): + case str() | int(): + resolved_cycle = constructor.Cycle(cycle, **cycle_kw) + case _ if isinstance(cycle, Iterable): resolved_cycle = constructor.Cycle(cycle, **cycle_kw) case _: resolved_cycle = None @@ -3626,6 +3663,9 @@ def _apply_plot(self, *pairs, vert=True, **kwargs): objs, xsides = [], [] kws = kwargs.copy() kws.update(_pop_props(kws, "line")) + # Disable auto label inference when in external context + if self._in_external_context(): + kws["autolabels"] = False kws, extents = self._inbounds_extent(**kws) for xs, ys, fmt in self._iter_arg_pairs(*pairs): xs, ys, kw = self._parse_1d_args(xs, ys, vert=vert, **kws) @@ -3775,7 +3815,7 @@ def _apply_beeswarm( orientation: str = "horizontal", n_bins: int = 50, **kwargs, - ) -> "Collection": + ) -> mcollections.Collection: # Parse input parameters ss, _ = self._parse_markersize(ss, **kwargs) @@ -4237,7 +4277,7 @@ def _parse_markersize( if s is not None: s = inputs._to_numpy_array(s) if absolute_size is None: - absolute_size = s.size == 1 or _inside_seaborn_call() + absolute_size = s.size == 1 if not absolute_size or smin is not None or smax is not None: smin = _not_none(smin, 1) smax = _not_none(smax, rc["lines.markersize"] ** (1, 2)[area_size]) @@ -4362,8 +4402,45 @@ def _apply_fill( stacked=None, **kwargs, ): - """ - Apply area shading. + """Apply area shading using `fill_between` or `fill_betweenx`. + + This is the internal implementation for `fill_between`, `fill_betweenx`, + `area`, and `areax`. + + Parameters + ---------- + xs, ys1, ys2 : array-like + The x and y coordinates for the shaded regions. + where : array-like, optional + A boolean mask for the points that should be shaded. + vert : bool, optional + The orientation of the shading. If `True` (default), `fill_between` + is used. If `False`, `fill_betweenx` is used. + negpos : bool, optional + Whether to use different colors for positive and negative shades. + stack : bool, optional + Whether to stack shaded regions. + **kwargs + Additional keyword arguments passed to the matplotlib fill function. + + Notes + ----- + Special handling for plots from external packages (e.g., seaborn): + + When this method is used in a context where plots are generated by + an external library like seaborn, it tags the resulting polygons + (e.g., confidence intervals) as "synthetic". This is done unless a + user explicitly provides a label. + + Synthetic artists are marked with `_ultraplot_synthetic=True` and given + a label starting with an underscore (e.g., `_ultraplot_fill`). This + prevents them from being automatically included in legends, keeping the + legend clean and focused on user-specified elements. + + Seaborn internally generates tags like "y", "ymin", and "ymax" for + vertical fills, and "x", "xmin", "xmax" for horizontal fills. UltraPlot + recognizes these and treats them as synthetic unless a different label + is provided. """ # Parse input arguments kw = kwargs.copy() @@ -4373,34 +4450,73 @@ def _apply_fill( stack = _not_none(stack=stack, stacked=stacked) xs, ys1, ys2, kw = self._parse_1d_args(xs, ys1, ys2, vert=vert, **kw) edgefix_kw = _pop_params(kw, self._fix_patch_edges) + guide_kw = _pop_params(kw, self._update_guide) + + # External override only; no seaborn-based tagging - # Draw patches with default edge width zero + # Draw patches y0 = 0 objs, xsides, ysides = [], [], [] - guide_kw = _pop_params(kw, self._update_guide) for _, n, x, y1, y2, w, kw in self._iter_arg_cols(xs, ys1, ys2, where, **kw): kw = self._parse_cycle(n, **kw) + + # If stacking requested, adjust y arrays if stack: - y1 = y1 + y0 # avoid in-place modification + y1 = y1 + y0 y2 = y2 + y0 - y0 = y0 + y2 - y1 # irrelevant that we added y0 to both - if negpos: # NOTE: if user passes 'where' will issue a warning + y0 = y0 + y2 - y1 + + # External override: if in external mode and no explicit label was provided, + # mark fill as synthetic so it is ignored by legend parsing unless explicitly labeled. + synthetic = False + if self._in_external_context() and ( + kw.get("label", None) is None + or str(kw.get("label")) in ("y", "ymin", "ymax") + ): + kw["label"] = "_ultraplot_fill" + synthetic = True + + # Draw object (negpos splits into two silent_list items) + if negpos: obj = self._call_negpos(name, x, y1, y2, where=w, use_where=True, **kw) else: obj = self._call_native(name, x, y1, y2, where=w, **kw) + + if synthetic: + try: + setattr(obj, "_ultraplot_synthetic", True) + if hasattr(obj, "set_label"): + obj.set_label("_ultraplot_fill") + except Exception: + pass + for art in guides._iter_iterables(obj): + try: + setattr(art, "_ultraplot_synthetic", True) + if hasattr(art, "set_label"): + art.set_label("_ultraplot_fill") + except Exception: + pass + + # No synthetic tagging or seaborn-based label overrides + + # Patch edge fixes self._fix_patch_edges(obj, **edgefix_kw, **kw) + + # Track sides for sticky edges xsides.append(x) for y in (y1, y2): self._inbounds_xylim(extents, x, y, vert=vert) - if y.size == 1: # add sticky edges if bounds are scalar + if y.size == 1: ysides.append(y) objs.append(obj) + # Draw guide and add sticky edges # Draw guide and add sticky edges self._update_guide(objs, **guide_kw) for axis, sides in zip("xy" if vert else "yx", (xsides, ysides)): self._fix_sticky_edges(objs, axis, *sides) return objs[0] if len(objs) == 1 else cbook.silent_list("PolyCollection", objs) + return objs[0] if len(objs) == 1 else cbook.silent_list("PolyCollection", objs) @docstring._snippet_manager def area(self, *args, **kwargs): @@ -4621,7 +4737,7 @@ def _apply_bar( xs, hs, kw = self._parse_1d_args(xs, hs, orientation=orientation, **kw) edgefix_kw = _pop_params(kw, self._fix_patch_edges) if absolute_width is None: - absolute_width = _inside_seaborn_call() + absolute_width = False or self._in_external_context() # Call func after converting bar width b0 = 0 @@ -4705,6 +4821,7 @@ def _add_bar_labels( # Find the maximum extent of text + bar position max_extent = current_lim[1] # Start with current upper limit + w = 0 for label, bar in zip(bar_labels, container): # Get text bounding box bbox = label.get_window_extent(renderer=self.figure.canvas.get_renderer()) @@ -4715,21 +4832,25 @@ def _add_bar_labels( bar_end = bar.get_width() + bar.get_x() text_end = bar_end + bbox_data.width max_extent = max(max_extent, text_end) + w = max(w, bar.get_height()) else: # For vertical bars, check if text extends beyond top edge bar_end = bar.get_height() + bar.get_y() text_end = bar_end + bbox_data.height max_extent = max(max_extent, text_end) + w = max(w, bar.get_width()) # Only adjust limits if text extends beyond current range if max_extent > current_lim[1]: padding = (max_extent - current_lim[1]) * 1.25 # Add a bit of padding new_lim = (current_lim[0], max_extent + padding) getattr(self, f"set_{which}lim")(new_lim) + lim = [getattr(self.dataLim, f"{other_which}{idx}") for idx in range(0, 2)] + lim = (lim[0] - w / 4, lim[1] + w / 4) - # Keep the other axis unchanged - getattr(self, f"set_{other_which}lim")(other_lim) - + current_lim = getattr(self, f"get_{other_which}lim")() + new_lim = (min(lim[0], current_lim[0]), max(lim[1], current_lim[1])) + getattr(self, f"set_{other_which}lim")(new_lim) return bar_labels @inputs._preprocess_or_redirect("x", "height", "width", "bottom") @@ -5252,6 +5373,11 @@ def hexbin(self, x, y, weights, **kwargs): center_levels=center_levels, **kw, ) + # Change the default behavior for weights/C to compute + # the total of the weights, not their average. + reduce_C_function = kw.get("reduce_C_function", None) + if reduce_C_function is None: + kw["reduce_C_function"] = np.sum norm = kw.get("norm", None) if norm is not None and not isinstance(norm, pcolors.DiscreteNorm): norm.vmin = norm.vmax = None # remove nonsense values diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 7c2cd454b..b81ef2cb7 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -6,12 +6,13 @@ import inspect import os from numbers import Integral + from packaging import version try: - from typing import List, Optional, Union, Tuple + from typing import List, Optional, Tuple, Union except ImportError: - from typing_extensions import List, Optional, Union, Tuple + from typing_extensions import List, Optional, Tuple, Union import matplotlib.axes as maxes import matplotlib.figure as mfigure @@ -30,7 +31,6 @@ from . import constructor from . import gridspec as pgridspec from .config import rc, rc_matplotlib -from .internals import ic # noqa: F401 from .internals import ( _not_none, _pop_params, @@ -38,10 +38,11 @@ _translate_loc, context, docstring, + ic, # noqa: F401 labels, warnings, ) -from .utils import units, _get_subplot_layout, _Crawler +from .utils import _Crawler, units __all__ = [ "Figure", @@ -61,6 +62,8 @@ "ams2": 4.5, "ams3": 5.5, "ams4": 6.5, + "cop1": "8.3cm", + "cop2": "12cm", "nat1": "89mm", "nat2": "183mm", "pnas1": "8.7cm", @@ -161,6 +164,9 @@ ``'ams2'`` small 2-column ” ``'ams3'`` medium 2-column ” ``'ams4'`` full 2-column ” + ``'cop1'`` 1-column \ +`Copernicus Publications `_ (e.g. *The Cryosphere*, *Geoscientific Model Development*) + ``'cop2'`` 2-column ” ``'nat1'`` 1-column `Nature Research `_ ``'nat2'`` 2-column ” ``'pnas1'`` 1-column \ @@ -176,6 +182,8 @@ https://www.agu.org/Publish-with-AGU/Publish/Author-Resources/Graphic-Requirements .. _ams: \ https://www.ametsoc.org/ams/index.cfm/publications/authors/journal-and-bams-authors/figure-information-for-authors/ + .. _cop: \ +https://publications.copernicus.org/for_authors/manuscript_preparation.html#figurestables .. _nat: \ https://www.nature.com/nature/for-authors/formatting-guide .. _pnas: \ @@ -806,6 +814,7 @@ def __init__( self._supxlabel_dict = {} # an axes: label mapping self._supylabel_dict = {} # an axes: label mapping self._suplabel_dict = {"left": {}, "right": {}, "bottom": {}, "top": {}} + self._share_label_groups = {"x": {}, "y": {}} # explicit label-sharing groups self._suptitle_pad = rc["suptitle.pad"] d = self._suplabel_props = {} # store the super label props d["left"] = {"va": "center", "ha": "right"} @@ -832,6 +841,7 @@ def draw(self, renderer): # we can use get_border_axes for the outermost plots and then collect their outermost panels that are not colorbars self._share_ticklabels(axis="x") self._share_ticklabels(axis="y") + self._apply_share_label_groups() super().draw(renderer) def _share_ticklabels(self, *, axis: str) -> None: @@ -1385,12 +1395,12 @@ def _add_axes_panel( # Vertical panels: should use rows parameter, not cols if _not_none(cols, col) is not None and _not_none(rows, row) is None: raise ValueError( - f"For {side!r} colorbars (vertical), use 'rows=' or 'row=' " + f"For {side!r} panels (vertical), use 'rows=' or 'row=' " "to specify span, not 'cols=' or 'col='." ) if span is not None and _not_none(rows, row) is None: warnings._warn_ultraplot( - f"For {side!r} colorbars (vertical), prefer 'rows=' over 'span=' " + f"For {side!r} panels (vertical), prefer 'rows=' over 'span=' " "for clarity. Using 'span' as rows." ) span_override = _not_none(rows, row, span) @@ -1398,7 +1408,7 @@ def _add_axes_panel( # Horizontal panels: should use cols parameter, not rows if _not_none(rows, row) is not None and _not_none(cols, col, span) is None: raise ValueError( - f"For {side!r} colorbars (horizontal), use 'cols=' or 'span=' " + f"For {side!r} panels (horizontal), use 'cols=' or 'span=' " "to specify span, not 'rows=' or 'row='." ) span_override = _not_none(cols, col, span) @@ -1881,6 +1891,223 @@ def _align_axis_label(self, x): if span: self._update_axis_label(pos, axs) + # Apply explicit label-sharing groups for this axis + self._apply_share_label_groups(axis=x) + + def _register_share_label_group(self, axes, *, target, source=None): + """ + Register an explicit label-sharing group for a subset of axes. + """ + if not axes: + return + axes = list(axes) + axes = [ax for ax in axes if ax is not None and ax.figure is self] + if len(axes) < 2: + return + + # Preserve order while de-duplicating + seen = set() + unique = [] + for ax in axes: + ax_id = id(ax) + if ax_id in seen: + continue + seen.add(ax_id) + unique.append(ax) + axes = unique + if len(axes) < 2: + return + + # Split by label side if mixed + axes_by_side = {} + if target == "x": + for ax in axes: + axes_by_side.setdefault(ax.xaxis.get_label_position(), []).append(ax) + else: + for ax in axes: + axes_by_side.setdefault(ax.yaxis.get_label_position(), []).append(ax) + if len(axes_by_side) > 1: + for side, side_axes in axes_by_side.items(): + side_source = source if source in side_axes else None + self._register_share_label_group_for_side( + side_axes, target=target, side=side, source=side_source + ) + return + + side, side_axes = next(iter(axes_by_side.items())) + self._register_share_label_group_for_side( + side_axes, target=target, side=side, source=source + ) + + def _register_share_label_group_for_side(self, axes, *, target, side, source=None): + """ + Register a single label-sharing group for a given label side. + """ + if not axes: + return + axes = [ax for ax in axes if ax is not None and ax.figure is self] + if len(axes) < 2: + return + + # Prefer label text from the source axes if available + label = None + if source in axes: + candidate = getattr(source, f"{target}axis").label + if candidate.get_text().strip(): + label = candidate + if label is None: + for ax in axes: + candidate = getattr(ax, f"{target}axis").label + if candidate.get_text().strip(): + label = candidate + break + + text = label.get_text() if label else "" + props = None + if label is not None: + props = { + "color": label.get_color(), + "fontproperties": label.get_font_properties(), + "rotation": label.get_rotation(), + "rotation_mode": label.get_rotation_mode(), + "ha": label.get_ha(), + "va": label.get_va(), + } + + group_key = tuple(sorted(id(ax) for ax in axes)) + groups = self._share_label_groups[target] + group = groups.get(group_key) + if group is None: + groups[group_key] = { + "axes": axes, + "side": side, + "text": text if text.strip() else "", + "props": props, + } + else: + group["axes"] = axes + group["side"] = side + if text.strip(): + group["text"] = text + group["props"] = props + + def _is_share_label_group_member(self, ax, axis): + """ + Return True if the axes belongs to any explicit label-sharing group. + """ + groups = self._share_label_groups.get(axis, {}) + return any(ax in group["axes"] for group in groups.values()) + + def _has_share_label_groups(self, axis): + """ + Return True if there are any explicit label-sharing groups for an axis. + """ + return bool(self._share_label_groups.get(axis, {})) + + def _clear_share_label_groups(self, axes=None, *, target=None): + """ + Clear explicit label-sharing groups, optionally filtered by axes. + """ + targets = ("x", "y") if target is None else (target,) + for axis in targets: + groups = self._share_label_groups.get(axis, {}) + if axes is None: + groups.clear() + continue + axes_set = {ax for ax in axes if ax is not None} + for key in list(groups): + if any(ax in axes_set for ax in groups[key]["axes"]): + del groups[key] + # Clear any existing spanning labels tied to these axes + if axis == "x": + for ax in axes_set: + if ax in self._supxlabel_dict: + self._supxlabel_dict[ax].set_text("") + else: + for ax in axes_set: + if ax in self._supylabel_dict: + self._supylabel_dict[ax].set_text("") + + def _apply_share_label_groups(self, axis=None): + """ + Apply explicit label-sharing groups, overriding default label sharing. + """ + + def _order_axes_for_side(axs, side): + if side in ("bottom", "top"): + key = ( + (lambda ax: ax._range_subplotspec("y")[1]) + if side == "bottom" + else (lambda ax: ax._range_subplotspec("y")[0]) + ) + reverse = side == "bottom" + else: + key = ( + (lambda ax: ax._range_subplotspec("x")[1]) + if side == "right" + else (lambda ax: ax._range_subplotspec("x")[0]) + ) + reverse = side == "right" + try: + return sorted(axs, key=key, reverse=reverse) + except Exception: + return list(axs) + + axes = (axis,) if axis in ("x", "y") else ("x", "y") + for target in axes: + groups = self._share_label_groups.get(target, {}) + for group in groups.values(): + axs = [ + ax for ax in group["axes"] if ax.figure is self and ax.get_visible() + ] + if len(axs) < 2: + continue + + side = group["side"] + ordered_axs = _order_axes_for_side(axs, side) + + # Refresh label text from any axis with non-empty text + label = None + for ax in ordered_axs: + candidate = getattr(ax, f"{target}axis").label + if candidate.get_text().strip(): + label = candidate + break + text = group["text"] + props = group["props"] + if label is not None: + text = label.get_text() + props = { + "color": label.get_color(), + "fontproperties": label.get_font_properties(), + "rotation": label.get_rotation(), + "rotation_mode": label.get_rotation_mode(), + "ha": label.get_ha(), + "va": label.get_va(), + } + group["text"] = text + group["props"] = props + + if not text: + continue + + try: + _, ax = self._get_align_coord( + side, ordered_axs, includepanels=self._includepanels + ) + except Exception: + continue + axlab = getattr(ax, f"{target}axis").label + axlab.set_text(text) + if props is not None: + axlab.set_color(props["color"]) + axlab.set_fontproperties(props["fontproperties"]) + axlab.set_rotation(props["rotation"]) + axlab.set_rotation_mode(props["rotation_mode"]) + axlab.set_ha(props["ha"]) + axlab.set_va(props["va"]) + self._update_axis_label(side, ordered_axs) + def _align_super_labels(self, side, renderer): """ Adjust the position of super labels. @@ -2395,6 +2622,7 @@ def colorbar( if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)): try: ax_single = next(iter(ax)) + except (TypeError, StopIteration): ax_single = ax else: @@ -2474,8 +2702,39 @@ def legend( ax = kwargs.pop("ax", None) # Axes panel legend if ax is not None: - leg = ax.legend( - handles, labels, space=space, pad=pad, width=width, **kwargs + # Check if span parameters are provided + has_span = _not_none(span, row, col, rows, cols) is not None + + # Extract a single axes from array if span is provided + # Otherwise, pass the array as-is for normal legend behavior + # Automatically collect handles and labels from spanned axes if not provided + if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)): + # Auto-collect handles and labels if not explicitly provided + if handles is None and labels is None: + handles, labels = [], [] + for axi in ax: + h, l = axi.get_legend_handles_labels() + handles.extend(h) + labels.extend(l) + try: + ax_single = next(iter(ax)) + except (TypeError, StopIteration): + ax_single = ax + else: + ax_single = ax + leg = ax_single.legend( + handles, + labels, + loc=loc, + space=space, + pad=pad, + width=width, + span=span, + row=row, + col=col, + rows=rows, + cols=cols, + **kwargs, ) # Figure panel legend else: diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 59de0f04c..288f1abc4 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -6,21 +6,24 @@ import itertools import re from collections.abc import MutableSequence +from functools import wraps from numbers import Integral +from typing import List, Optional, Tuple, Union import matplotlib.axes as maxes import matplotlib.gridspec as mgridspec import matplotlib.transforms as mtransforms import numpy as np -from typing import List, Optional, Union, Tuple -from functools import wraps from . import axes as paxes from .config import rc -from .internals import ic # noqa: F401 -from .internals import _not_none, docstring, warnings +from .internals import ( + _not_none, + docstring, + ic, # noqa: F401 + warnings, +) from .utils import _fontsize_to_pt, units -from .internals import warnings __all__ = ["GridSpec", "SubplotGrid"] @@ -1650,7 +1653,10 @@ def __getitem__(self, key): ) new_key.append(encoded_keyi) xs, ys = new_key - objs = grid[xs, ys] + if np.iterable(xs) and np.iterable(ys): + objs = grid[np.ix_(xs, ys)] + else: + objs = grid[xs, ys] if hasattr(objs, "flat"): objs = [obj for obj in objs.flat if obj is not None] elif not isinstance(objs, list): @@ -1743,7 +1749,46 @@ def format(self, **kwargs): ultraplot.figure.Figure.format ultraplot.config.Configurator.context """ + # Implicit label sharing for subset format calls + share_xlabels = kwargs.get("share_xlabels", None) + share_ylabels = kwargs.get("share_ylabels", None) + xlabel = kwargs.get("xlabel", None) + ylabel = kwargs.get("ylabel", None) + axes = [ax for ax in self if ax is not None] + all_axes = set(self.figure._subplot_dict.values()) + is_subset = bool(axes) and all_axes and set(axes) != all_axes + if len(self) > 1: + if share_xlabels is False: + self.figure._clear_share_label_groups(self, target="x") + if share_ylabels is False: + self.figure._clear_share_label_groups(self, target="y") + if is_subset and share_xlabels is None and xlabel is not None: + self.figure._register_share_label_group(self, target="x") + if is_subset and share_ylabels is None and ylabel is not None: + self.figure._register_share_label_group(self, target="y") self.figure.format(axs=self, **kwargs) + # Refresh groups after labels are set + if len(self) > 1: + if is_subset and share_xlabels is None and xlabel is not None: + self.figure._register_share_label_group(self, target="x") + if is_subset and share_ylabels is None and ylabel is not None: + self.figure._register_share_label_group(self, target="y") + + def share_labels(self, *, axis="x"): + """ + Register an explicit label-sharing group for this subset. + """ + if not self: + return self + axis = axis.lower() + if axis in ("x", "y"): + self.figure._register_share_label_group(self, target=axis) + elif axis in ("both", "all", "xy"): + self.figure._register_share_label_group(self, target="x") + self.figure._register_share_label_group(self, target="y") + else: + raise ValueError(f"Invalid axis={axis!r}. Options are 'x', 'y', or 'both'.") + return self @property def figure(self): diff --git a/ultraplot/scale.py b/ultraplot/scale.py index d83d8f449..8fc2a05b8 100644 --- a/ultraplot/scale.py +++ b/ultraplot/scale.py @@ -11,8 +11,12 @@ import numpy.ma as ma from . import ticker as pticker -from .internals import ic # noqa: F401 -from .internals import _not_none, _version_mpl, warnings +from .internals import ( + _not_none, + _version_mpl, + ic, # noqa: F401 + warnings, +) __all__ = [ "CutoffScale", @@ -370,7 +374,9 @@ def __init__(self, transform=None, invert=False, parent_scale=None, **kwargs): kwsym["linthresh"] = inverse(kwsym["linthresh"]) parent_scale = SymmetricalLogScale(**kwsym) self.functions = (forward, inverse) - self._transform = parent_scale.get_transform() + FuncTransform(forward, inverse) + # Apply the function in data space, then parent scale (e.g., log). + # This ensures dual axes behave correctly when the parent is non-linear. + self._transform = FuncTransform(forward, inverse) + parent_scale.get_transform() # Apply default locators and formatters # NOTE: We pass these through contructor functions diff --git a/ultraplot/tests/test_1dplots.py b/ultraplot/tests/test_1dplots.py index eee2178bb..50bfdc75b 100644 --- a/ultraplot/tests/test_1dplots.py +++ b/ultraplot/tests/test_1dplots.py @@ -5,8 +5,50 @@ import numpy as np import numpy.ma as ma import pandas as pd +import pytest import ultraplot as uplt + + +def test_bar_relative_width_by_default_external_and_internal(): + """ + Bars use relative widths by default regardless of external mode. + """ + x = [0, 10] + h = [1, 2] + + # Internal (external=False): relative width scales with step size + fig, ax = uplt.subplots() + ax.set_external(False) + bars_int = ax.bar(x, h) + w_int = [r.get_width() for r in bars_int.patches] + + # External (external=True): same default relative behavior + fig, ax = uplt.subplots() + ax.set_external(True) + bars_ext = ax.bar(x, h) + w_ext = [r.get_width() for r in bars_ext.patches] + + # With step=10, expect ~ 0.8 * 10 = 8 + assert pytest.approx(w_int[0], rel=1e-6) == 8.0 + assert pytest.approx(w_ext[0], rel=1e-6) == 0.8 + + +def test_bar_absolute_width_manual_override(): + """ + Users can force absolute width by passing absolute_width=True. + """ + x = [0, 10] + h = [1, 2] + + fig, ax = uplt.subplots() + bars_abs = ax.bar(x, h, absolute_width=True) + w_abs = [r.get_width() for r in bars_abs.patches] + + # Absolute width should be the raw width (default 0.8) in data units + assert pytest.approx(w_abs[0], rel=1e-6) == 0.8 + + import pytest diff --git a/ultraplot/tests/test_axes.py b/ultraplot/tests/test_axes.py index 370f2c520..610298bf1 100644 --- a/ultraplot/tests/test_axes.py +++ b/ultraplot/tests/test_axes.py @@ -4,6 +4,7 @@ """ import numpy as np import pytest + import ultraplot as uplt from ultraplot.internals.warnings import UltraPlotWarning @@ -130,6 +131,119 @@ def test_cartesian_format_all_units_types(): ax.format(**kwargs) +def test_dualx_log_transform_is_finite(): + """ + Ensure dualx transforms remain finite on log axes. + """ + fig, ax = uplt.subplots() + ax.set_xscale("log") + ax.set_xlim(0.1, 10) + sec = ax.dualx(lambda x: 1 / x) + fig.canvas.draw() + + +def test_title_manual_size_ignores_auto_shrink(): + """ + Ensure explicit title sizes bypass auto-scaling. + """ + fig, axs = uplt.subplots(figsize=(2, 2)) + axs.format( + abc=True, + title="X" * 200, + titleloc="left", + abcloc="left", + title_kw={"size": 20}, + ) + title_obj = axs[0]._title_dict["left"] + fig.canvas.draw() + assert title_obj.get_fontsize() == 20 + + +def test_title_shrinks_when_abc_overlaps_different_loc(): + """ + Ensure long titles shrink when overlapping abc at a different location. + """ + fig, axs = uplt.subplots(figsize=(3, 2)) + axs.format(abc=True, title="X" * 200, titleloc="center", abcloc="left") + title_obj = axs[0]._title_dict["center"] + original_size = title_obj.get_fontsize() + fig.canvas.draw() + assert title_obj.get_fontsize() < original_size + + +def test_title_shrinks_right_aligned_same_location(): + """ + Test that right-aligned titles shrink when they would overflow with abc label. + """ + fig, axs = uplt.subplots(figsize=(2, 2)) + axs.format(abc=True, title="X" * 100, titleloc="right", abcloc="right") + title_obj = axs[0]._title_dict["right"] + original_size = title_obj.get_fontsize() + fig.canvas.draw() + assert title_obj.get_fontsize() < original_size + + +def test_title_shrinks_centered_same_location(): + """ + Test that centered titles shrink when they would overflow with abc label. + """ + fig, axs = uplt.subplots(figsize=(2, 2)) + axs.format(abc=True, title="X" * 150, titleloc="center", abcloc="center") + title_obj = axs[0]._title_dict["center"] + original_size = title_obj.get_fontsize() + fig.canvas.draw() + assert title_obj.get_fontsize() < original_size + + +def test_title_shrinks_right_aligned_different_location(): + """ + Test that right-aligned titles shrink when overlapping abc at different location. + """ + fig, axs = uplt.subplots(figsize=(3, 2)) + axs.format(abc=True, title="X" * 100, titleloc="right", abcloc="left") + title_obj = axs[0]._title_dict["right"] + original_size = title_obj.get_fontsize() + fig.canvas.draw() + assert title_obj.get_fontsize() < original_size + + +def test_title_shrinks_left_aligned_different_location(): + """ + Test that left-aligned titles shrink when overlapping abc at different location. + """ + fig, axs = uplt.subplots(figsize=(3, 2)) + axs.format(abc=True, title="X" * 100, titleloc="left", abcloc="right") + title_obj = axs[0]._title_dict["left"] + original_size = title_obj.get_fontsize() + fig.canvas.draw() + assert title_obj.get_fontsize() < original_size + + +def test_title_no_shrink_when_no_overlap(): + """ + Test that titles don't shrink when there's no overlap with abc label. + """ + fig, axs = uplt.subplots(figsize=(4, 2)) + axs.format(abc=True, title="Short Title", titleloc="left", abcloc="right") + title_obj = axs[0]._title_dict["left"] + original_size = title_obj.get_fontsize() + fig.canvas.draw() + assert title_obj.get_fontsize() == original_size + + +def test_title_shrinks_centered_left_of_abc(): + """ + Test that centered titles shrink when they are to the left of abc label. + This covers the specific case where base_x <= ax0 - pad for centered titles. + """ + fig, axs = uplt.subplots(figsize=(3, 2)) + axs.format(abc=True, title="X" * 100, titleloc="center", abcloc="right") + title_obj = axs[0]._title_dict["center"] + original_size = title_obj.get_fontsize() + fig.canvas.draw() + assert title_obj.get_fontsize() < original_size + + def test_axis_access(): # attempt to access the ax object 2d and linearly fig, ax = uplt.subplots(ncols=2, nrows=2) diff --git a/ultraplot/tests/test_colorbar.py b/ultraplot/tests/test_colorbar.py index f16a6f13a..b4e42eb40 100644 --- a/ultraplot/tests/test_colorbar.py +++ b/ultraplot/tests/test_colorbar.py @@ -4,7 +4,48 @@ """ import numpy as np import pytest + import ultraplot as uplt + + +def test_colorbar_defers_external_mode(): + """ + External mode should defer on-the-fly colorbar creation until explicitly requested. + """ + import numpy as np + + fig, ax = uplt.subplots() + ax.set_external(True) + m = ax.pcolor(np.random.random((5, 5)), colorbar="b") + + # No colorbar should have been registered/created yet + assert isinstance(ax[0]._colorbar_dict, dict) + assert len(ax[0]._colorbar_dict) == 0 + + # Explicit colorbar creation should register the colorbar at the requested loc + cb = ax.colorbar(m, loc="b") + assert ("bottom", "center") in ax[0]._colorbar_dict + assert ax[0]._colorbar_dict[("bottom", "center")] is cb + + +def test_explicit_legend_with_handles_under_external_mode(): + """ + Under external mode, legend auto-creation is deferred. Passing explicit handles + to legend() must work immediately. + """ + fig, ax = uplt.subplots() + ax.set_external(True) + (h,) = ax.plot([0, 1], label="LegendLabel", legend="b") + + # No legend queued/created yet + assert ("bottom", "center") not in ax[0]._legend_dict + + # Explicit legend with handle should contain our label + leg = ax.legend(h, loc="b") + labels = [t.get_text() for t in leg.get_texts()] + assert "LegendLabel" in labels + + from itertools import product diff --git a/ultraplot/tests/test_geographic.py b/ultraplot/tests/test_geographic.py index 30911c176..f1efed6ec 100644 --- a/ultraplot/tests/test_geographic.py +++ b/ultraplot/tests/test_geographic.py @@ -1,7 +1,11 @@ -import ultraplot as uplt, numpy as np, warnings -import pytest +import warnings from unittest import mock +import numpy as np +import pytest + +import ultraplot as uplt + @pytest.mark.mpl_image_compare def test_geographic_single_projection(): @@ -403,6 +407,45 @@ def test_geo_panel_share_flag_controls_membership(): assert ax2[0]._panel_sharex_group is False +def test_geo_subset_share_xlabels_override(): + fig, ax = uplt.subplots(ncols=2, nrows=2, proj="cyl", share="labels", span=False) + # GeoAxes.format does not accept xlabel/ylabel; set labels directly. + ax[0, 0].set_xlabel("Top-left X") + ax[0, 1].set_xlabel("Top-right X") + bottom = ax[1, :] + bottom[0].set_xlabel("Bottom-row X") + bottom.format(share_xlabels=list(bottom)) + + fig.canvas.draw() + + assert not ax[0, 0].xaxis.get_label().get_visible() + assert not ax[0, 1].xaxis.get_label().get_visible() + assert bottom[0].get_xlabel().strip() == "" + assert bottom[1].get_xlabel().strip() == "" + assert any(lab.get_text() == "Bottom-row X" for lab in fig._supxlabel_dict.values()) + + uplt.close(fig) + + +def test_geo_subset_share_xlabels_implicit(): + fig, ax = uplt.subplots(ncols=2, nrows=2, proj="cyl", share="labels", span=False) + ax[0, 0].set_xlabel("Top-left X") + ax[0, 1].set_xlabel("Top-right X") + bottom = ax[1, :] + bottom[0].set_xlabel("Bottom-row X") + bottom.share_labels(axis="x") + + fig.canvas.draw() + + assert not ax[0, 0].xaxis.get_label().get_visible() + assert not ax[0, 1].xaxis.get_label().get_visible() + assert bottom[0].get_xlabel().strip() == "" + assert bottom[1].get_xlabel().strip() == "" + assert any(lab.get_text() == "Bottom-row X" for lab in fig._supxlabel_dict.values()) + + uplt.close(fig) + + def test_geo_non_rectilinear_right_panel_forces_no_share_and_warns(): """ Non-rectilinear Geo projections should not allow panel sharing; adding a right panel @@ -456,7 +499,10 @@ def test_sharing_geo_limits(): after_lat = ax[1]._lataxis.get_view_interval() # We are sharing y which is the latitude axis - assert all([np.allclose(i, j) for i, j in zip(expectation["latlim"], after_lat)]) + # Account for small epsilon expansion in extent (0.5 degrees per side) + assert all( + [np.allclose(i, j, atol=1.0) for i, j in zip(expectation["latlim"], after_lat)] + ) # We are not sharing longitude yet assert all( [ @@ -470,7 +516,10 @@ def test_sharing_geo_limits(): after_lon = ax[1]._lonaxis.get_view_interval() assert all([not np.allclose(i, j) for i, j in zip(before_lon, after_lon)]) - assert all([np.allclose(i, j) for i, j in zip(after_lon, expectation["lonlim"])]) + # Account for small epsilon expansion in extent (0.5 degrees per side) + assert all( + [np.allclose(i, j, atol=1.0) for i, j in zip(after_lon, expectation["lonlim"])] + ) uplt.close(fig) @@ -945,8 +994,9 @@ def test_consistent_range(): lonview = np.array(a._lonaxis.get_view_interval()) latview = np.array(a._lataxis.get_view_interval()) - assert np.allclose(lonview, lonlim) - assert np.allclose(latview, latlim) + # Account for small epsilon expansion in extent (0.5 degrees per side) + assert np.allclose(lonview, lonlim, atol=1.0) + assert np.allclose(latview, latlim, atol=1.0) @pytest.mark.mpl_image_compare @@ -1010,3 +1060,538 @@ def test_grid_indexing_formatting(rng): axs[-1, :].format(lonlabels=True) axs[:, 0].format(latlabels=True) return fig + + +@pytest.mark.parametrize( + "backend", + [ + "cartopy", + "basemap", + ], +) +def test_label_rotation(backend): + """ + Test label rotation parameters for both Cartopy and Basemap backends. + Tests labelrotation, lonlabelrotation, and latlabelrotation parameters. + """ + fig, axs = uplt.subplots(ncols=2, proj="cyl", backend=backend, share=0) + + # Test 1: labelrotation applies to both axes + axs[0].format( + title="Both rotated 45°", + lonlabels="b", + latlabels="l", + labelrotation=45, + lonlines=30, + latlines=30, + ) + + # Test 2: Different rotations for lon and lat + axs[1].format( + title="Lon: 90°, Lat: 0°", + lonlabels="b", + latlabels="l", + lonlabelrotation=90, + latlabelrotation=0, + lonlines=30, + latlines=30, + ) + + # Verify that rotation was applied based on actual backend + if axs[0]._name == "cartopy": + # For Cartopy, check gridliner xlabel_style and ylabel_style + gl0 = axs[0].gridlines_major + assert gl0.xlabel_style.get("rotation") == 45 + assert gl0.ylabel_style.get("rotation") == 45 + + gl1 = axs[1].gridlines_major + assert gl1.xlabel_style.get("rotation") == 90 + assert gl1.ylabel_style.get("rotation") == 0 + + else: # basemap + # For Basemap, check Text object rotation + from matplotlib import text as mtext + + def get_text_rotations(gridlines_dict): + """Extract rotation angles from Text objects in gridlines.""" + rotations = [] + for line_dict in gridlines_dict.values(): + for obj_list in line_dict: + for obj in obj_list: + if isinstance(obj, mtext.Text): + rotations.append(obj.get_rotation()) + return rotations + + # Check first axes (both 45°) + lonlines_0, latlines_0 = axs[0].gridlines_major + lon_rotations_0 = get_text_rotations(lonlines_0) + lat_rotations_0 = get_text_rotations(latlines_0) + if lon_rotations_0: # Only check if labels exist + assert all(r == 45 for r in lon_rotations_0) + if lat_rotations_0: + assert all(r == 45 for r in lat_rotations_0) + + # Check second axes (lon: 90°, lat: 0°) + lonlines_1, latlines_1 = axs[1].gridlines_major + lon_rotations_1 = get_text_rotations(lonlines_1) + lat_rotations_1 = get_text_rotations(latlines_1) + if lon_rotations_1: + assert all(r == 90 for r in lon_rotations_1) + if lat_rotations_1: + assert all(r == 0 for r in lat_rotations_1) + + uplt.close(fig) + + +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_label_rotation_precedence(backend): + """ + Test that specific rotation parameters take precedence over general labelrotation. + """ + fig, ax = uplt.subplots(proj="cyl", backend=backend) + + # lonlabelrotation should override labelrotation for lon axis + # latlabelrotation not specified, so should use labelrotation + ax.format( + lonlabels="b", + latlabels="l", + labelrotation=30, + lonlabelrotation=60, # This should override for lon + lonlines=30, + latlines=30, + ) + + if ax[0]._name == "cartopy": + gl = ax[0].gridlines_major + assert gl.xlabel_style.get("rotation") == 60 # Override value + assert gl.ylabel_style.get("rotation") == 30 # Fallback value + else: # basemap + from matplotlib import text as mtext + + def get_text_rotations(gridlines_dict): + rotations = [] + for line_dict in gridlines_dict.values(): + for obj_list in line_dict: + for obj in obj_list: + if isinstance(obj, mtext.Text): + rotations.append(obj.get_rotation()) + return rotations + + lonlines, latlines = ax[0].gridlines_major + lon_rotations = get_text_rotations(lonlines) + lat_rotations = get_text_rotations(latlines) + + if lon_rotations: + assert all(r == 60 for r in lon_rotations) + if lat_rotations: + assert all(r == 30 for r in lat_rotations) + + uplt.close(fig) + + +def test_label_rotation_backward_compatibility(): + """ + Test that existing code without rotation parameters still works. + """ + fig, ax = uplt.subplots(proj="cyl") + + # Should work without any rotation parameters + ax.format( + lonlabels="b", + latlabels="l", + lonlines=30, + latlines=30, + ) + + # Verify no rotation was applied (should be default or None) + gl = ax[0]._gridlines_major + # If rotation key doesn't exist or is None/0, that's expected + lon_rotation = gl.xlabel_style.get("rotation") + lat_rotation = gl.ylabel_style.get("rotation") + + # Default rotation should be None or 0 (no rotation) + assert lon_rotation is None or lon_rotation == 0 + assert lat_rotation is None or lat_rotation == 0 + + uplt.close(fig) + + +@pytest.mark.parametrize("rotation_angle", [0, 45, 90, -30, 180]) +def test_label_rotation_angles(rotation_angle): + """ + Test various rotation angles to ensure they're applied correctly. + """ + fig, ax = uplt.subplots(proj="cyl") + + ax.format( + lonlabels="b", + latlabels="l", + labelrotation=rotation_angle, + lonlines=60, + latlines=30, + ) + + gl = ax[0]._gridlines_major + assert gl.xlabel_style.get("rotation") == rotation_angle + assert gl.ylabel_style.get("rotation") == rotation_angle + + uplt.close(fig) + + +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_label_rotation_only_lon(backend): + """ + Test rotation applied only to longitude labels. + """ + fig, ax = uplt.subplots(proj="cyl", backend=backend) + + # Only rotate longitude labels + ax.format( + lonlabels="b", + latlabels="l", + lonlabelrotation=45, + lonlines=30, + latlines=30, + ) + + if ax[0]._name == "cartopy": + gl = ax[0].gridlines_major + assert gl.xlabel_style.get("rotation") == 45 + assert gl.ylabel_style.get("rotation") is None + else: # basemap + from matplotlib import text as mtext + + def get_text_rotations(gridlines_dict): + rotations = [] + for line_dict in gridlines_dict.values(): + for obj_list in line_dict: + for obj in obj_list: + if isinstance(obj, mtext.Text): + rotations.append(obj.get_rotation()) + return rotations + + lonlines, latlines = ax[0].gridlines_major + lon_rotations = get_text_rotations(lonlines) + lat_rotations = get_text_rotations(latlines) + + if lon_rotations: + assert all(r == 45 for r in lon_rotations) + if lat_rotations: + # Default rotation should be 0 + assert all(r == 0 for r in lat_rotations) + + uplt.close(fig) + + +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_label_rotation_only_lat(backend): + """ + Test rotation applied only to latitude labels. + """ + fig, ax = uplt.subplots(proj="cyl", backend=backend) + + # Only rotate latitude labels + ax.format( + lonlabels="b", + latlabels="l", + latlabelrotation=60, + lonlines=30, + latlines=30, + ) + + if ax[0]._name == "cartopy": + gl = ax[0].gridlines_major + assert gl.xlabel_style.get("rotation") is None + assert gl.ylabel_style.get("rotation") == 60 + else: # basemap + from matplotlib import text as mtext + + def get_text_rotations(gridlines_dict): + rotations = [] + for line_dict in gridlines_dict.values(): + for obj_list in line_dict: + for obj in obj_list: + if isinstance(obj, mtext.Text): + rotations.append(obj.get_rotation()) + return rotations + + lonlines, latlines = ax[0].gridlines_major + lon_rotations = get_text_rotations(lonlines) + lat_rotations = get_text_rotations(latlines) + + if lon_rotations: + # Default rotation should be 0 + assert all(r == 0 for r in lon_rotations) + if lat_rotations: + assert all(r == 60 for r in lat_rotations) + + uplt.close(fig) + + +def test_label_rotation_with_different_projections(): + """ + Test label rotation with various projections. + """ + projections = ["cyl", "robin", "moll"] + + for proj in projections: + fig, ax = uplt.subplots(proj=proj) + + ax.format( + lonlabels="b", + latlabels="l", + labelrotation=30, + lonlines=60, + latlines=30, + ) + + # For cartopy, verify rotation was set + if ax[0]._name == "cartopy": + gl = ax[0]._gridlines_major + if gl is not None: # Some projections might not support gridlines + assert gl.xlabel_style.get("rotation") == 30 + assert gl.ylabel_style.get("rotation") == 30 + + uplt.close(fig) + + +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_label_rotation_with_format_options(backend): + """ + Test label rotation combined with other format options. + """ + fig, ax = uplt.subplots(proj="cyl", backend=backend) + + # Combine rotation with other formatting + ax.format( + lonlabels="b", + latlabels="l", + lonlabelrotation=45, + latlabelrotation=30, + lonlines=30, + latlines=30, + coast=True, + land=True, + ) + + # Verify rotation was applied + if ax[0]._name == "cartopy": + gl = ax[0].gridlines_major + assert gl.xlabel_style.get("rotation") == 45 + assert gl.ylabel_style.get("rotation") == 30 + else: # basemap + from matplotlib import text as mtext + + def get_text_rotations(gridlines_dict): + rotations = [] + for line_dict in gridlines_dict.values(): + for obj_list in line_dict: + for obj in obj_list: + if isinstance(obj, mtext.Text): + rotations.append(obj.get_rotation()) + return rotations + + lonlines, latlines = ax[0].gridlines_major + lon_rotations = get_text_rotations(lonlines) + lat_rotations = get_text_rotations(latlines) + + if lon_rotations: + assert all(r == 45 for r in lon_rotations) + if lat_rotations: + assert all(r == 30 for r in lat_rotations) + + uplt.close(fig) + + +def test_label_rotation_none_values(): + """ + Test that None values for rotation work correctly. + """ + fig, ax = uplt.subplots(proj="cyl") + + # Explicitly set None for rotations + ax.format( + lonlabels="b", + latlabels="l", + lonlabelrotation=None, + latlabelrotation=None, + lonlines=30, + latlines=30, + ) + + gl = ax[0]._gridlines_major + # None should result in no rotation being set + lon_rotation = gl.xlabel_style.get("rotation") + lat_rotation = gl.ylabel_style.get("rotation") + + assert lon_rotation is None or lon_rotation == 0 + assert lat_rotation is None or lat_rotation == 0 + + uplt.close(fig) + + +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_label_rotation_update_existing(backend): + """ + Test updating rotation on axes that already have labels. + """ + fig, ax = uplt.subplots(proj="cyl", backend=backend) + + # First format without rotation + ax.format( + lonlabels="b", + latlabels="l", + lonlines=30, + latlines=30, + ) + + # Then update with rotation + ax.format( + lonlabelrotation=45, + latlabelrotation=90, + ) + + # Verify rotation was applied + if ax[0]._name == "cartopy": + gl = ax[0].gridlines_major + assert gl.xlabel_style.get("rotation") == 45 + assert gl.ylabel_style.get("rotation") == 90 + else: # basemap + from matplotlib import text as mtext + + def get_text_rotations(gridlines_dict): + rotations = [] + for line_dict in gridlines_dict.values(): + for obj_list in line_dict: + for obj in obj_list: + if isinstance(obj, mtext.Text): + rotations.append(obj.get_rotation()) + return rotations + + lonlines, latlines = ax[0].gridlines_major + lon_rotations = get_text_rotations(lonlines) + lat_rotations = get_text_rotations(latlines) + + if lon_rotations: + assert all(r == 45 for r in lon_rotations) + if lat_rotations: + assert all(r == 90 for r in lat_rotations) + + uplt.close(fig) + + +def test_label_rotation_negative_angles(): + """ + Test various negative rotation angles. + """ + fig, ax = uplt.subplots(proj="cyl") + + negative_angles = [-15, -45, -90, -120, -180] + + for angle in negative_angles: + ax.format( + lonlabels="b", + latlabels="l", + labelrotation=angle, + lonlines=60, + latlines=30, + ) + + gl = ax[0]._gridlines_major + assert gl.xlabel_style.get("rotation") == angle + assert gl.ylabel_style.get("rotation") == angle + + uplt.close(fig) + + +def _check_boundary_labels(ax, expected_lon_labels, expected_lat_labels): + """Helper to check that boundary labels are created and visible.""" + gl = ax._gridlines_major + assert gl is not None, "Gridliner should exist" + + # Check xlim/ylim are expanded beyond actual limits + assert hasattr(gl, "xlim") and hasattr(gl, "ylim") + + # Check longitude labels - only verify the visible ones match expected + lon_texts = [ + label.get_text() for label in gl.bottom_label_artists if label.get_visible() + ] + assert len(lon_texts) == len(expected_lon_labels), ( + f"Should have {len(expected_lon_labels)} visible longitude labels, " + f"got {len(lon_texts)}: {lon_texts}" + ) + for expected in expected_lon_labels: + assert any( + expected in text for text in lon_texts + ), f"{expected} label should be visible, got: {lon_texts}" + + # Check latitude labels - only verify the visible ones match expected + lat_texts = [ + label.get_text() for label in gl.left_label_artists if label.get_visible() + ] + assert len(lat_texts) == len(expected_lat_labels), ( + f"Should have {len(expected_lat_labels)} visible latitude labels, " + f"got {len(lat_texts)}: {lat_texts}" + ) + for expected in expected_lat_labels: + assert any( + expected in text for text in lat_texts + ), f"{expected} label should be visible, got: {lat_texts}" + + +def test_boundary_labels_positive_longitude(): + """ + Test that boundary labels are visible with positive longitude limits. + + This tests the fix for the issue where setting lonlim/latlim would hide + the outermost labels because cartopy's gridliner was filtering them out. + """ + fig, ax = uplt.subplots(proj="pcarree") + ax.format( + lonlim=(120, 130), + latlim=(10, 20), + lonlocator=[120, 125, 130], + latlocator=[10, 15, 20], + labels=True, + grid=False, + ) + fig.canvas.draw() + _check_boundary_labels(ax[0], ["120°E", "125°E", "130°E"], ["10°N", "15°N", "20°N"]) + uplt.close(fig) + + +def test_boundary_labels_negative_longitude(): + """ + Test that boundary labels are visible with negative longitude limits. + """ + fig, ax = uplt.subplots(proj="pcarree") + ax.format( + lonlim=(-120, -60), + latlim=(20, 50), + lonlocator=[-120, -90, -60], + latlocator=[20, 35, 50], + labels=True, + grid=False, + ) + fig.canvas.draw() + # Note: Cartopy hides the boundary label at 20°N due to it being exactly at the limit + # This is expected cartopy behavior with floating point precision at boundaries + _check_boundary_labels( + ax[0], + ["120°W", "90°W", "60°W"], + ["20°N", "35°N", "50°N"], + ) + uplt.close(fig) + + +def test_boundary_labels_view_intervals(): + """ + Test that view intervals match requested limits after setting lonlim/latlim. + """ + fig, ax = uplt.subplots(proj="pcarree") + ax.format(lonlim=(0, 60), latlim=(-20, 40), lonlines=30, latlines=20, labels=True) + loninterval = ax[0]._lonaxis.get_view_interval() + latinterval = ax[0]._lataxis.get_view_interval() + assert abs(loninterval[0] - 0) < 1 and abs(loninterval[1] - 60) < 1 + assert abs(latinterval[0] - (-20)) < 1 and abs(latinterval[1] - 40) < 1 + uplt.close(fig) diff --git a/ultraplot/tests/test_integration.py b/ultraplot/tests/test_integration.py index fc1d48b90..7429fafc0 100644 --- a/ultraplot/tests/test_integration.py +++ b/ultraplot/tests/test_integration.py @@ -2,10 +2,73 @@ """ Test xarray, pandas, pint, seaborn integration. """ -import numpy as np, pandas as pd, seaborn as sns -import xarray as xr -import ultraplot as uplt, pytest +import numpy as np +import pandas as pd import pint +import pytest +import seaborn as sns +import xarray as xr + +import ultraplot as uplt + + +def test_seaborn_helpers_filtered_from_legend(): + """ + Seaborn-generated helper artists (e.g., CI bands) must be synthetic-tagged and + filtered out of legends so that only hue categories appear. + """ + fig, ax = uplt.subplots() + + # Create simple dataset with two hue levels + df = pd.DataFrame( + { + "x": np.concatenate([np.arange(10)] * 2), + "y": np.concatenate([np.arange(10), np.arange(10) + 1]), + "hue": ["h1"] * 10 + ["h2"] * 10, + } + ) + + # Use explicit external mode to engage UL's integration behavior for helper artists + with ax.external(): + sns.lineplot(data=df, x="x", y="y", hue="hue", ax=ax) + + # Explicitly create legend and verify labels + leg = ax.legend() + labels = {t.get_text() for t in leg.get_texts()} + + # Only hue labels should be present + assert {"h1", "h2"}.issubset(labels) + + # Spurious or synthetic labels must not appear + for bad in ( + "y", + "ymin", + "ymax", + "_ultraplot_fill", + "_ultraplot_shade", + "_ultraplot_fade", + ): + assert bad not in labels + + +def test_user_labeled_shading_appears_in_legend(): + """ + User-labeled shading (fill_between) must appear in legend even after seaborn plotting. + """ + fig, ax = uplt.subplots() + + # Seaborn plot first (to ensure seaborn context was present earlier) + df = pd.DataFrame({"x": np.arange(10), "y": np.arange(10)}) + sns.lineplot(data=df, x="x", y="y", ax=ax, label="line") + + # Add explicit user-labeled shading on the same axes + x = np.arange(10) + ax.fill_between(x, x - 0.2, x + 0.2, alpha=0.2, label="CI band") + + # Legend must include both the seaborn line label and our shaded band label + leg = ax.legend() + labels = {t.get_text() for t in leg.get_texts()} + assert "CI band" in labels @pytest.mark.mpl_image_compare @@ -96,18 +159,35 @@ def test_seaborn_swarmplot(): @pytest.mark.mpl_image_compare def test_seaborn_hist(rng): """ - Test seaborn histograms. + Test seaborn histograms (smoke test using external mode contexts). """ fig, axs = uplt.subplots(ncols=2, nrows=2) - sns.histplot(rng.normal(size=100), ax=axs[0]) - sns.kdeplot(x=rng.random(100), y=rng.random(100), ax=axs[1]) + + with axs[0].external(): + sns.histplot(rng.normal(size=100), ax=axs[0]) + + with axs[1].external(): + sns.kdeplot(x=rng.random(100), y=rng.random(100), ax=axs[1]) + penguins = sns.load_dataset("penguins") - sns.histplot( - data=penguins, x="flipper_length_mm", hue="species", multiple="stack", ax=axs[2] - ) - sns.kdeplot( - data=penguins, x="flipper_length_mm", hue="species", multiple="stack", ax=axs[3] - ) + + with axs[2].external(): + sns.histplot( + data=penguins, + x="flipper_length_mm", + hue="species", + multiple="stack", + ax=axs[2], + ) + + with axs[3].external(): + sns.kdeplot( + data=penguins, + x="flipper_length_mm", + hue="species", + multiple="stack", + ax=axs[3], + ) return fig diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index 096b10729..6b984a55e 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -2,7 +2,11 @@ """ Test legends. """ -import numpy as np, pandas as pd, ultraplot as uplt, pytest +import numpy as np +import pandas as pd +import pytest + +import ultraplot as uplt @pytest.mark.mpl_image_compare @@ -219,3 +223,309 @@ def test_sync_label_dict(rng): 0 ]._legend_dict, "Old legend not removed from dict" uplt.close(fig) + + +def test_external_mode_defers_on_the_fly_legend(): + """ + External mode should defer on-the-fly legend creation until explicitly requested. + """ + fig, ax = uplt.subplots() + ax.set_external(True) + (h,) = ax.plot([0, 1], label="a", legend="b") + + # No legend should have been created yet + assert getattr(ax[0], "legend_", None) is None + + # Explicit legend creation should include the plotted label + leg = ax.legend(h, loc="b") + labels = [t.get_text() for t in leg.get_texts()] + assert "a" in labels + uplt.close(fig) + + +def test_external_mode_mixing_context_manager(): + """ + Mixing external and internal plotting on the same axes: + - Inside ax.external(): on-the-fly legend is deferred + - Outside: UltraPlot-native plotting resumes as normal + - Final explicit ax.legend() consolidates both kinds of artists + """ + fig, ax = uplt.subplots() + + with ax.external(): + (ext,) = ax.plot([0, 1], label="ext", legend="b") # deferred + + (intr,) = ax.line([0, 1], label="int") # normal UL behavior + + leg = ax.legend([ext, intr], loc="b") + labels = {t.get_text() for t in leg.get_texts()} + assert {"ext", "int"}.issubset(labels) + uplt.close(fig) + + +def test_external_mode_toggle_enables_auto(): + """ + Toggling external mode back off should resume on-the-fly guide creation. + """ + fig, ax = uplt.subplots() + + ax.set_external(True) + (ha,) = ax.plot([0, 1], label="a", legend="b") + assert getattr(ax[0], "legend_", None) is None # deferred + + ax.set_external(False) + (hb,) = ax.plot([0, 1], label="b", legend="b") + # Now legend is queued for creation; verify it is registered in the outer legend dict + assert ("bottom", "center") in ax[0]._legend_dict + + # Ensure final legend contains both entries + leg = ax.legend([ha, hb], loc="b") + labels = {t.get_text() for t in leg.get_texts()} + assert {"a", "b"}.issubset(labels) + uplt.close(fig) + + +def test_synthetic_handles_filtered(): + """ + Synthetic-tagged helper artists must be ignored by legend parsing even when + explicitly passed as handles. + """ + fig, ax = uplt.subplots() + (h1,) = ax.plot([0, 1], label="visible") + (h2,) = ax.plot([1, 0], label="helper") + # Mark helper as synthetic; it should be filtered out from legend entries + setattr(h2, "_ultraplot_synthetic", True) + + leg = ax.legend([h1, h2], loc="best") + labels = [t.get_text() for t in leg.get_texts()] + assert "visible" in labels + assert "helper" not in labels + uplt.close(fig) + + +def test_fill_between_included_in_legend(): + """ + Legitimate fill_between/area handles must appear in legends (regression for + previously skipped FillBetweenPolyCollection). + """ + fig, ax = uplt.subplots() + x = np.arange(5) + y1 = np.zeros(5) + y2 = np.ones(5) + ax.fill_between(x, y1, y2, label="band") + + leg = ax.legend(loc="best") + labels = [t.get_text() for t in leg.get_texts()] + assert "band" in labels + uplt.close(fig) + + +def test_legend_span_bottom(): + """Test bottom legend with span parameter.""" + + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Legend below row 1, spanning columns 1-2 + leg = fig.legend(ax=axs[0, :], span=(1, 2), loc="bottom") + + # Verify legend was created + assert leg is not None + + +def test_legend_span_top(): + """Test top legend with span parameter.""" + + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Legend above row 2, spanning columns 2-3 + leg = fig.legend(ax=axs[1, :], cols=(2, 3), loc="top") + + assert leg is not None + + +def test_legend_span_right(): + """Test right legend with rows parameter.""" + + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + # Legend right of column 1, spanning rows 1-2 + leg = fig.legend(ax=axs[:, 0], rows=(1, 2), loc="right") + + assert leg is not None + + +def test_legend_span_left(): + """Test left legend with rows parameter.""" + + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + # Legend left of column 2, spanning rows 2-3 + leg = fig.legend(ax=axs[:, 1], rows=(2, 3), loc="left") + + assert leg is not None + + +def test_legend_span_validation_left_with_cols_error(): + """Test that LEFT legend raises error with cols parameter.""" + + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + with pytest.raises(ValueError, match="left.*vertical.*use 'rows='.*not 'cols='"): + fig.legend(ax=axs[0, 0], cols=(1, 2), loc="left") + + +def test_legend_span_validation_right_with_cols_error(): + """Test that RIGHT legend raises error with cols parameter.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + with pytest.raises(ValueError, match="right.*vertical.*use 'rows='.*not 'cols='"): + fig.legend(ax=axs[0, 0], cols=(1, 2), loc="right") + + +def test_legend_span_validation_top_with_rows_error(): + """Test that TOP legend raises error with rows parameter.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + with pytest.raises(ValueError, match="top.*horizontal.*use 'cols='.*not 'rows='"): + fig.legend(ax=axs[0, 0], rows=(1, 2), loc="top") + + +def test_legend_span_validation_bottom_with_rows_error(): + """Test that BOTTOM legend raises error with rows parameter.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + with pytest.raises( + ValueError, match="bottom.*horizontal.*use 'cols='.*not 'rows='" + ): + fig.legend(ax=axs[0, 0], rows=(1, 2), loc="bottom") + + +def test_legend_span_validation_left_with_span_warns(): + """Test that LEFT legend with span parameter issues warning.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + with pytest.warns(match="left.*vertical.*prefer 'rows='"): + leg = fig.legend(ax=axs[0, 0], span=(1, 2), loc="left") + assert leg is not None + + +def test_legend_span_validation_right_with_span_warns(): + """Test that RIGHT legend with span parameter issues warning.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + with pytest.warns(match="right.*vertical.*prefer 'rows='"): + leg = fig.legend(ax=axs[0, 0], span=(1, 2), loc="right") + assert leg is not None + + +def test_legend_array_without_span(): + """Test that legend on array without span preserves original behavior.""" + fig, axs = uplt.subplots(nrows=2, ncols=2) + axs[0, 0].plot([], [], label="test") + + # Should create legend for all axes in the array + leg = fig.legend(ax=axs[:], loc="right") + assert leg is not None + + +def test_legend_array_with_span(): + """Test that legend on array with span uses first axis + span extent.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Should use first axis position with span extent + leg = fig.legend(ax=axs[0, :], span=(1, 2), loc="bottom") + assert leg is not None + + +def test_legend_row_without_span(): + """Test that legend on row without span spans entire row.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Should span all 3 columns + leg = fig.legend(ax=axs[0, :], loc="bottom") + assert leg is not None + + +def test_legend_column_without_span(): + """Test that legend on column without span spans entire column.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + # Should span all 3 rows + leg = fig.legend(ax=axs[:, 0], loc="right") + assert leg is not None + + +def test_legend_multiple_sides_with_span(): + """Test multiple legends on different sides with span control.""" + fig, axs = uplt.subplots(nrows=3, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Create legends on all 4 sides with different spans + leg_bottom = fig.legend(ax=axs[0, 0], span=(1, 2), loc="bottom") + leg_top = fig.legend(ax=axs[1, 0], span=(2, 3), loc="top") + leg_right = fig.legend(ax=axs[0, 0], rows=(1, 2), loc="right") + leg_left = fig.legend(ax=axs[0, 1], rows=(2, 3), loc="left") + + assert leg_bottom is not None + assert leg_top is not None + assert leg_right is not None + assert leg_left is not None + + +def test_legend_auto_collect_handles_labels_with_span(): + """Test automatic collection of handles and labels from multiple axes with span parameters.""" + + fig, axs = uplt.subplots(nrows=2, ncols=2) + + # Create different plots in each subplot with labels + axs[0, 0].plot([0, 1], [0, 1], label="line1") + axs[0, 1].plot([0, 1], [1, 0], label="line2") + axs[1, 0].scatter([0.5], [0.5], label="point1") + axs[1, 1].scatter([0.5], [0.5], label="point2") + + # Test automatic collection with span parameter (no explicit handles/labels) + leg = fig.legend(ax=axs[0, :], span=(1, 2), loc="bottom") + + # Verify legend was created and contains all handles/labels from both axes + assert leg is not None + assert len(leg.get_texts()) == 2 # Should have 2 labels (line1, line2) + + # Test with rows parameter + leg2 = fig.legend(ax=axs[:, 0], rows=(1, 2), loc="right") + assert leg2 is not None + assert len(leg2.get_texts()) == 2 # Should have 2 labels (line1, point1) + + +def test_legend_explicit_handles_labels_override_auto_collection(): + """Test that explicit handles/labels override auto-collection.""" + + fig, axs = uplt.subplots(nrows=1, ncols=2) + + # Create plots with labels + (h1,) = axs[0].plot([0, 1], [0, 1], label="auto_label1") + (h2,) = axs[1].plot([0, 1], [1, 0], label="auto_label2") + + # Test with explicit handles/labels (should override auto-collection) + custom_handles = [h1] + custom_labels = ["custom_label"] + leg = fig.legend( + ax=axs, span=(1, 2), loc="bottom", handles=custom_handles, labels=custom_labels + ) + + # Verify legend uses explicit handles/labels, not auto-collected ones + assert leg is not None + assert len(leg.get_texts()) == 1 + assert leg.get_texts()[0].get_text() == "custom_label" diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index e3eb9455d..1bcb69684 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -1,17 +1,93 @@ -from cycler import V -import pandas as pd -from pandas.core.arrays.arrow.accessors import pa -import ultraplot as uplt, pytest, numpy as np from unittest import mock from unittest.mock import patch +import numpy as np +import pandas as pd +import pytest +from cycler import V +from pandas.core.arrays.arrow.accessors import pa + +import ultraplot as uplt from ultraplot.internals.warnings import UltraPlotWarning + +@pytest.mark.mpl_image_compare +def test_seaborn_lineplot_legend_hue_only(): + """ + Regression test: seaborn lineplot on UltraPlot axes should not add spurious + legend entries like 'y'/'ymin'. Only hue categories should appear unless the user + explicitly labels helper bands. + """ + import seaborn as sns + + fig, ax = uplt.subplots() + df = pd.DataFrame( + { + "xcol": np.concatenate([np.arange(10)] * 2), + "ycol": np.concatenate([np.arange(10), 1.5 * np.arange(10)]), + "hcol": ["h1"] * 10 + ["h2"] * 10, + } + ) + + with ax.external(): + sns.lineplot(data=df, x="xcol", y="ycol", hue="hcol", ax=ax) + + # Create (or refresh) legend and collect labels + leg = ax.legend() + labels = {t.get_text() for t in leg.get_texts()} + + # Should contain only hue levels; must not contain inferred 'y' or CI helpers + assert "y" not in labels + assert "ymin" not in labels + assert {"h1", "h2"}.issubset(labels) + return fig + + """ This file is used to test base properties of ultraplot.axes.plot. For higher order plotting related functions, please use 1d and 2plots """ +def test_external_preserves_explicit_label(): + """ + In external mode, explicit labels must still be respected even when autolabels are disabled. + """ + fig, ax = uplt.subplots() + ax.set_external(True) + (h,) = ax.plot([0, 1, 2], [0, 1, 0], label="X") + leg = ax.legend(h, loc="best") + labels = [t.get_text() for t in leg.get_texts()] + assert "X" in labels + + +def test_external_disables_autolabels_no_label(): + """ + In external mode, if no labels are provided, autolabels are disabled and a placeholder is used. + """ + fig, ax = uplt.subplots() + ax.set_external(True) + (h,) = ax.plot([0, 1, 2], [0, 1, 0]) + # Explicitly pass the handle so we test the label on that artist + leg = ax.legend(h, loc="best") + labels = [t.get_text() for t in leg.get_texts()] + # With no explicit labels and autolabels disabled, a placeholder is used + assert (not labels) or (labels[0] in ("_no_label", "")) + + +def test_error_shading_explicit_label_external(): + """ + Explicit label on fill_between should be preserved in legend entries. + """ + fig, ax = uplt.subplots() + ax.set_external(True) + x = np.linspace(0, 2 * np.pi, 50) + y = np.sin(x) + patch = ax.fill_between(x, y - 0.5, y + 0.5, alpha=0.3, label="Band") + leg = ax.legend([patch], loc="best") + labels = [t.get_text() for t in leg.get_texts()] + assert "Band" in labels + + def test_graph_nodes_kw(): """Test the graph method by setting keywords for nodes""" import networkx as nx @@ -285,6 +361,33 @@ def reset(ax): uplt.close(fig) +def test_format_log_scale_preserves_log_formatter(): + """ + Test that setting a log scale preserves the log formatter when enabled. + """ + x = np.linspace(1, 1e6, 10) + log_formatter = uplt.constructor.Formatter("log") + log_formatter_type = type(log_formatter) + + with uplt.rc.context({"formatter.log": True}): + fig, ax = uplt.subplots() + ax.plot(x, x) + ax.format(yscale="log") + assert isinstance(ax.yaxis.get_major_formatter(), log_formatter_type) + ax.set_yscale("log") + assert isinstance(ax.yaxis.get_major_formatter(), log_formatter_type) + + with uplt.rc.context({"formatter.log": False}): + fig, ax = uplt.subplots() + ax.plot(x, x) + ax.format(yscale="log") + assert not isinstance(ax.yaxis.get_major_formatter(), log_formatter_type) + ax.set_yscale("log") + assert not isinstance(ax.yaxis.get_major_formatter(), log_formatter_type) + + uplt.close(fig) + + def test_shading_pcolor(rng): """ Pcolormesh by default adjusts the plot by @@ -617,3 +720,38 @@ def test_curved_quiver_color_and_cmap(rng, cmap): fig, ax = uplt.subplots() ax.curved_quiver(X, Y, U, V, color=color, cmap=cmap) return fig + + +def test_histogram_norms(): + """ + Check that all histograms-like plotting functions + use the sum of the weights. + """ + rng = np.random.default_rng(seed=100) + x, y = rng.normal(size=(2, 100)) + w = rng.uniform(size=100) + + fig, axs = uplt.subplots() + _, _, bars = axs.hist(x, weights=w, bins=5) + tot_weights = np.sum([bar.get_height() for bar in bars]) + np.testing.assert_allclose(tot_weights, np.sum(w)) + + fig, axs = uplt.subplots() + _, _, _, qm = axs.hist2d(x, y, weights=w, bins=5) + tot_weights = np.sum(qm.get_array()) + np.testing.assert_allclose(tot_weights, np.sum(w)) + + fig, axs = uplt.subplots() + pc = axs.hexbin(x, y, weights=w, gridsize=5) + tot_weights = np.sum(pc.get_array()) + np.testing.assert_allclose(tot_weights, np.sum(w)) + + # check that a different reduce_C_function produces + # a different result + fig, axs = uplt.subplots() + pc = axs.hexbin(x, y, weights=w, gridsize=5, reduce_C_function=np.max) + tot_weights = np.sum(pc.get_array()) + # check they are not equal and that the different is not + # due to floating point errors + assert tot_weights != np.sum(w) + assert not np.allclose(tot_weights, np.sum(w)) diff --git a/ultraplot/tests/test_statistical_plotting.py b/ultraplot/tests/test_statistical_plotting.py index c65f25245..d1aff89c3 100644 --- a/ultraplot/tests/test_statistical_plotting.py +++ b/ultraplot/tests/test_statistical_plotting.py @@ -71,3 +71,25 @@ def test_panel_dist(rng): px.hist(x, bins, color=color, fill=True, ec="k") px.format(grid=False, ylocator=[], title=title, titleloc="l") return fig + + +@pytest.mark.mpl_image_compare +def test_input_violin_box_options(): + """ + Test various box options in violin plots. + """ + data = np.array([0, 1, 2, 3]).reshape(-1, 1) + + fig, axes = uplt.subplots(ncols=4) + axes[0].bar(data, median=True, boxpctiles=True, bars=False) + axes[0].format(title="boxpctiles") + + axes[1].bar(data, median=True, boxpctile=True, bars=False) + axes[1].format(title="boxpctile") + + axes[2].bar(data, median=True, boxstd=True, bars=False) + axes[2].format(title="boxstd") + + axes[3].bar(data, median=True, boxstds=True, bars=False) + axes[3].format(title="boxstds") + return fig diff --git a/ultraplot/tests/test_subplots.py b/ultraplot/tests/test_subplots.py index 3ebe5f37d..eb42c79fc 100644 --- a/ultraplot/tests/test_subplots.py +++ b/ultraplot/tests/test_subplots.py @@ -2,7 +2,10 @@ """ Test subplot layout. """ -import numpy as np, ultraplot as uplt, pytest +import numpy as np +import pytest + +import ultraplot as uplt @pytest.mark.mpl_image_compare @@ -207,7 +210,7 @@ def test_reference_aspect(test_case, refwidth, kwargs, setup_func, ref): # Apply auto layout fig.auto_layout() # Assert reference width accuracy - assert np.isclose(refwidth, axs[fig._refnum - 1]._get_size_inches()[0]) + assert np.isclose(refwidth, axs[fig._refnum - 1]._get_size_inches()[0], rtol=1e-3) return fig @@ -255,6 +258,206 @@ def test_axis_sharing(share): return fig +def test_subset_share_xlabels_override(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share="labels", span=False) + ax[0, 0].format(xlabel="Top-left X") + ax[0, 1].format(xlabel="Top-right X") + bottom = ax[1, :] + bottom[0].format(xlabel="Bottom-row X", share_xlabels=list(bottom)) + + fig.canvas.draw() + + assert not ax[0, 0].xaxis.get_label().get_visible() + assert not ax[0, 1].xaxis.get_label().get_visible() + assert bottom[0].get_xlabel().strip() == "" + assert bottom[1].get_xlabel().strip() == "" + assert any(lab.get_text() == "Bottom-row X" for lab in fig._supxlabel_dict.values()) + + uplt.close(fig) + + +def test_subset_share_xlabels_implicit(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share="labels", span=False) + ax[0, 0].format(xlabel="Top-left X") + ax[0, 1].format(xlabel="Top-right X") + bottom = ax[1, :] + bottom.format(xlabel="Bottom-row X") + + fig.canvas.draw() + + assert not ax[0, 0].xaxis.get_label().get_visible() + assert not ax[0, 1].xaxis.get_label().get_visible() + assert bottom[0].get_xlabel().strip() == "" + assert bottom[1].get_xlabel().strip() == "" + assert any(lab.get_text() == "Bottom-row X" for lab in fig._supxlabel_dict.values()) + + uplt.close(fig) + + +def test_subset_share_ylabels_override(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share="labels", span=False) + ax[0, 0].format(ylabel="Left-top Y") + ax[1, 0].format(ylabel="Left-bottom Y") + right = ax[:, 1] + right[0].format(ylabel="Right-column Y", share_ylabels=list(right)) + + fig.canvas.draw() + + assert ax[0, 0].yaxis.get_label().get_visible() + assert ax[0, 0].get_ylabel() == "Left-top Y" + assert ax[1, 0].yaxis.get_label().get_visible() + assert ax[1, 0].get_ylabel() == "Left-bottom Y" + assert right[0].get_ylabel().strip() == "" + assert right[1].get_ylabel().strip() == "" + assert any( + lab.get_text() == "Right-column Y" for lab in fig._supylabel_dict.values() + ) + + uplt.close(fig) + + +def test_subset_share_xlabels_implicit_column(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + right = ax[:, 1] + right.format(xlabel="Right-column X") + + fig.canvas.draw() + + assert ax[0, 1].get_xlabel().strip() == "" + assert ax[1, 1].get_xlabel().strip() == "" + label_axes = [ + axi + for axi, lab in fig._supxlabel_dict.items() + if lab.get_text() == "Right-column X" + ] + assert label_axes and label_axes[0] is ax[1, 1] + + uplt.close(fig) + + +def test_subset_share_ylabels_implicit_row(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + top = ax[0, :] + top.format(ylabel="Top-row Y") + + fig.canvas.draw() + + assert ax[0, 0].get_ylabel().strip() == "" + assert ax[0, 1].get_ylabel().strip() == "" + label_axes = [ + axi for axi, lab in fig._supylabel_dict.items() if lab.get_text() == "Top-row Y" + ] + assert label_axes and label_axes[0] is ax[0, 0] + + uplt.close(fig) + + +def test_subset_share_xlabels_clear(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + bottom = ax[1, :] + bottom.format(xlabel="Shared") + + fig.canvas.draw() + assert any(lab.get_text() == "Shared" for lab in fig._supxlabel_dict.values()) + + bottom.format(share_xlabels=False, xlabel="Unshared") + fig.canvas.draw() + + assert not any(lab.get_text() == "Shared" for lab in fig._supxlabel_dict.values()) + assert not any(lab.get_text() == "Unshared" for lab in fig._supxlabel_dict.values()) + assert bottom[0].get_xlabel() == "Unshared" + assert bottom[1].get_xlabel() == "Unshared" + + uplt.close(fig) + + +def test_subset_share_labels_method_both(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + right = ax[:, 1] + right[0].set_xlabel("Right-column X") + right[0].set_ylabel("Right-column Y") + right.share_labels(axis="both") + + fig.canvas.draw() + + assert right[0].get_xlabel().strip() == "" + assert right[1].get_xlabel().strip() == "" + assert right[0].get_ylabel().strip() == "" + assert right[1].get_ylabel().strip() == "" + assert any( + lab.get_text() == "Right-column X" for lab in fig._supxlabel_dict.values() + ) + assert any( + lab.get_text() == "Right-column Y" for lab in fig._supylabel_dict.values() + ) + + uplt.close(fig) + + +def test_subset_share_labels_invalid_axis(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + with pytest.raises(ValueError): + ax[:, 1].share_labels(axis="nope") + + uplt.close(fig) + + +def test_subset_share_xlabels_mixed_sides(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + ax[0, :].format(xlabelloc="top", share_xlabels=False) + ax[1, :].format(xlabelloc="bottom", share_xlabels=False) + ax[0, 0].set_xlabel("Top X") + ax[0, 1].set_xlabel("Top X") + ax[1, 0].set_xlabel("Bottom X") + ax[1, 1].set_xlabel("Bottom X") + ax[0, 0].format(share_xlabels=list(ax)) + + fig.canvas.draw() + + assert any(lab.get_text() == "Top X" for lab in fig._supxlabel_dict.values()) + assert any(lab.get_text() == "Bottom X" for lab in fig._supxlabel_dict.values()) + + uplt.close(fig) + + +def test_subset_share_xlabels_implicit_column_top(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + right = ax[:, 1] + right.format(xlabel="Right-column X (top)", xlabelloc="top") + + fig.canvas.draw() + + assert ax[0, 1].get_xlabel().strip() == "" + assert ax[1, 1].get_xlabel().strip() == "" + label_axes = [ + axi + for axi, lab in fig._supxlabel_dict.items() + if lab.get_text() == "Right-column X (top)" + ] + assert label_axes and label_axes[0] is ax[0, 1] + + uplt.close(fig) + + +def test_subset_share_ylabels_implicit_row_right(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + top = ax[0, :] + top.format(ylabel="Top-row Y (right)", ylabelloc="right") + + fig.canvas.draw() + + assert ax[0, 0].get_ylabel().strip() == "" + assert ax[0, 1].get_ylabel().strip() == "" + label_axes = [ + axi + for axi, lab in fig._supylabel_dict.items() + if lab.get_text() == "Top-row Y (right)" + ] + assert label_axes and label_axes[0] is ax[0, 1] + + uplt.close(fig) + + @pytest.mark.parametrize( "layout", [ diff --git a/ultraplot/ticker.py b/ultraplot/ticker.py index 4dfa7bfc7..ad1da9519 100644 --- a/ultraplot/ticker.py +++ b/ultraplot/ticker.py @@ -64,7 +64,7 @@ when `zerotrim` is ``True`` and ``2`` otherwise. """ _zerotrim_docstring = """ -zerotrim : bool, default: :rc:`format.zerotrim` +zerotrim : bool, default: :rc:`formatter.zerotrim` Whether to trim trailing decimal zeros. """ _auto_docstring = """