Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ultraplot/axes/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def _sharex_limits(self, sharex):
if ax1.get_autoscalex_on() and not ax2.get_autoscalex_on():
ax1.set_xlim(ax2.get_xlim()) # non-default limits
# Copy non-default locators and formatters
self.get_shared_x_axes().joined(self, sharex) # share limit/scale changes
self.sharex(sharex)
if sharex.xaxis.isDefault_majloc and not self.xaxis.isDefault_majloc:
sharex.xaxis.set_major_locator(self.xaxis.get_major_locator())
if sharex.xaxis.isDefault_minloc and not self.xaxis.isDefault_minloc:
Expand All @@ -598,7 +598,7 @@ def _sharey_limits(self, sharey):
ax1.set_yscale(ax2.get_yscale())
if ax1.get_autoscaley_on() and not ax2.get_autoscaley_on():
ax1.set_ylim(ax2.get_ylim())
self.get_shared_y_axes().joined(self, sharey) # share limit/scale changes
self.sharey(sharey)
if sharey.yaxis.isDefault_majloc and not self.yaxis.isDefault_majloc:
sharey.yaxis.set_major_locator(self.yaxis.get_major_locator())
if sharey.yaxis.isDefault_minloc and not self.yaxis.isDefault_minloc:
Expand Down
38 changes: 38 additions & 0 deletions ultraplot/axes/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..internals import ic # noqa: F401
from ..internals import _pop_kwargs
from ..utils import _fontsize_to_pt, _not_none, units
from ..axes import Axes


class _SharedAxes(object):
Expand Down Expand Up @@ -184,3 +185,40 @@ def _update_ticks(
if kwtext_extra:
for lab in obj.get_ticklabels():
lab.update(kwtext_extra)

# Override matplotlib defaults to handle multiple axis sharing
def sharex(self, other):
return self._share_axis(which="x", other=other)

def sharey(self, other):
self._share_axis(which="y", other=other)

# Ultraplot internal function to share axes
def _share_axis(self, which, other):
if not isinstance(other, Axes):
return TypeError(
f"Cannot share axes with {type(other).__name__}.\n"
f"Expected: ultraplot.base.Axes instance\n"
f"Received: {type(other).__name__}\n"
"Please provide a valid Axes instance to share with."
)

self._shared_axes[which].join(self, other)

# Get axis objects
this_axis = getattr(self, f"{which}axis")
other_axis = getattr(other, f"{which}axis")

# Set minor ticker
this_axis.minor = other_axis.minor

# Get and set limits
limits = getattr(other, f"get_{which}lim")()
set_lim = getattr(self, f"set_{which}lim")
get_autoscale = getattr(other, f"get_autoscale{which}_on")

lim0, lim1 = limits
set_lim(lim0, lim1, emit=False, auto=get_autoscale())

# Set scale
this_axis._scale = other_axis._scale
1 change: 0 additions & 1 deletion ultraplot/tests/test_1dplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,6 @@ def test_scatter_sizes():
from matplotlib import tri


@pytest.mark.mpl_image_compare
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize(
"x, y, z, triangles, use_triangulation, use_datadict",
Expand Down
44 changes: 44 additions & 0 deletions ultraplot/tests/test_subplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,47 @@ def test_reference_aspect():
fig.auto_layout()
assert np.isclose(refwidth, axs[fig._refnum - 1]._get_size_inches()[0])
return fig


@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("share", ["limits", "labels"])
def test_axis_sharing(share):
fig, ax = uplt.subplots(ncols=2, nrows=2, share=share, span=False)
labels = ["A", "B", "C", "D"]
for idx, axi in enumerate(ax):
axi.scatter(idx, idx)
axi.set_xlabel(labels[idx])
axi.set_ylabel(labels[idx])

# TODO: the labels are handled in a funky way. The plot looks fine but
# the label are not "shared" that is the labels still exist but they
# are not visible and instead there are new labels created. Need to
# figure this out.
# test left hand side
if share != "labels":
assert all([i == j for i, j in zip(ax[0].get_xlim(), ax[2].get_xlim())])
assert all([i == j for i, j in zip(ax[0].get_ylim(), ax[1].get_ylim())])
assert all([i == j for i, j in zip(ax[1].get_xlim(), ax[3].get_xlim())])
elif share == "labels":
ax.draw(
fig.canvas.get_renderer()
) # forcing a draw to ensure the labels are shared
# columns shares x label; top row should be empty
assert ax[0].xaxis.get_label().get_visible() == False
assert ax[1].xaxis.get_label().get_visible() == False

assert ax[2].xaxis.get_label().get_visible() == True
assert ax[2].get_xlabel() == "A"
assert ax[3].xaxis.get_label().get_visible() == True
assert ax[3].get_xlabel() == "B"

# rows share ylabel
assert ax[3].yaxis.get_label().get_visible() == False
assert ax[1].yaxis.get_label().get_visible() == False

assert ax[0].yaxis.get_label().get_visible() == True
assert ax[2].yaxis.get_label().get_visible() == True
assert ax[0].get_ylabel() == "B"
assert ax[2].get_ylabel() == "D"

return fig
Loading