diff --git a/ultraplot/axes/cartesian.py b/ultraplot/axes/cartesian.py index 7ce4f3b6c..2cf34bdc3 100644 --- a/ultraplot/axes/cartesian.py +++ b/ultraplot/axes/cartesian.py @@ -576,7 +576,7 @@ def _sharex_limits(self, sharex): if ax1.get_autoscalex_on() and not ax2.get_autoscalex_on(): ax1.set_xlim(ax2.get_xlim()) # non-default limits # Copy non-default locators and formatters - self.get_shared_x_axes().joined(self, sharex) # share limit/scale changes + self.sharex(sharex) if sharex.xaxis.isDefault_majloc and not self.xaxis.isDefault_majloc: sharex.xaxis.set_major_locator(self.xaxis.get_major_locator()) if sharex.xaxis.isDefault_minloc and not self.xaxis.isDefault_minloc: @@ -598,7 +598,7 @@ def _sharey_limits(self, sharey): ax1.set_yscale(ax2.get_yscale()) if ax1.get_autoscaley_on() and not ax2.get_autoscaley_on(): ax1.set_ylim(ax2.get_ylim()) - self.get_shared_y_axes().joined(self, sharey) # share limit/scale changes + self.sharey(sharey) if sharey.yaxis.isDefault_majloc and not self.yaxis.isDefault_majloc: sharey.yaxis.set_major_locator(self.yaxis.get_major_locator()) if sharey.yaxis.isDefault_minloc and not self.yaxis.isDefault_minloc: diff --git a/ultraplot/axes/shared.py b/ultraplot/axes/shared.py index 95b685744..d589e5c6e 100644 --- a/ultraplot/axes/shared.py +++ b/ultraplot/axes/shared.py @@ -10,6 +10,7 @@ from ..internals import ic # noqa: F401 from ..internals import _pop_kwargs from ..utils import _fontsize_to_pt, _not_none, units +from ..axes import Axes class _SharedAxes(object): @@ -184,3 +185,40 @@ def _update_ticks( if kwtext_extra: for lab in obj.get_ticklabels(): lab.update(kwtext_extra) + + # Override matplotlib defaults to handle multiple axis sharing + def sharex(self, other): + return self._share_axis(which="x", other=other) + + def sharey(self, other): + self._share_axis(which="y", other=other) + + # Ultraplot internal function to share axes + def _share_axis(self, which, other): + if not isinstance(other, Axes): + return TypeError( + f"Cannot share axes with {type(other).__name__}.\n" + f"Expected: ultraplot.base.Axes instance\n" + f"Received: {type(other).__name__}\n" + "Please provide a valid Axes instance to share with." + ) + + self._shared_axes[which].join(self, other) + + # Get axis objects + this_axis = getattr(self, f"{which}axis") + other_axis = getattr(other, f"{which}axis") + + # Set minor ticker + this_axis.minor = other_axis.minor + + # Get and set limits + limits = getattr(other, f"get_{which}lim")() + set_lim = getattr(self, f"set_{which}lim") + get_autoscale = getattr(other, f"get_autoscale{which}_on") + + lim0, lim1 = limits + set_lim(lim0, lim1, emit=False, auto=get_autoscale()) + + # Set scale + this_axis._scale = other_axis._scale diff --git a/ultraplot/tests/test_1dplots.py b/ultraplot/tests/test_1dplots.py index 60eed6d04..338d631e2 100644 --- a/ultraplot/tests/test_1dplots.py +++ b/ultraplot/tests/test_1dplots.py @@ -378,7 +378,6 @@ def test_scatter_sizes(): from matplotlib import tri -@pytest.mark.mpl_image_compare @pytest.mark.mpl_image_compare @pytest.mark.parametrize( "x, y, z, triangles, use_triangulation, use_datadict", diff --git a/ultraplot/tests/test_subplots.py b/ultraplot/tests/test_subplots.py index 487920ce8..797068d3a 100644 --- a/ultraplot/tests/test_subplots.py +++ b/ultraplot/tests/test_subplots.py @@ -172,3 +172,47 @@ def test_reference_aspect(): fig.auto_layout() assert np.isclose(refwidth, axs[fig._refnum - 1]._get_size_inches()[0]) return fig + + +@pytest.mark.mpl_image_compare +@pytest.mark.parametrize("share", ["limits", "labels"]) +def test_axis_sharing(share): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=share, span=False) + labels = ["A", "B", "C", "D"] + for idx, axi in enumerate(ax): + axi.scatter(idx, idx) + axi.set_xlabel(labels[idx]) + axi.set_ylabel(labels[idx]) + + # TODO: the labels are handled in a funky way. The plot looks fine but + # the label are not "shared" that is the labels still exist but they + # are not visible and instead there are new labels created. Need to + # figure this out. + # test left hand side + if share != "labels": + assert all([i == j for i, j in zip(ax[0].get_xlim(), ax[2].get_xlim())]) + assert all([i == j for i, j in zip(ax[0].get_ylim(), ax[1].get_ylim())]) + assert all([i == j for i, j in zip(ax[1].get_xlim(), ax[3].get_xlim())]) + elif share == "labels": + ax.draw( + fig.canvas.get_renderer() + ) # forcing a draw to ensure the labels are shared + # columns shares x label; top row should be empty + assert ax[0].xaxis.get_label().get_visible() == False + assert ax[1].xaxis.get_label().get_visible() == False + + assert ax[2].xaxis.get_label().get_visible() == True + assert ax[2].get_xlabel() == "A" + assert ax[3].xaxis.get_label().get_visible() == True + assert ax[3].get_xlabel() == "B" + + # rows share ylabel + assert ax[3].yaxis.get_label().get_visible() == False + assert ax[1].yaxis.get_label().get_visible() == False + + assert ax[0].yaxis.get_label().get_visible() == True + assert ax[2].yaxis.get_label().get_visible() == True + assert ax[0].get_ylabel() == "B" + assert ax[2].get_ylabel() == "D" + + return fig