diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 5a4e5d1d..5d302f31 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -2704,10 +2704,17 @@ def legend( if ax is not None: # Check if span parameters are provided has_span = _not_none(span, row, col, rows, cols) is not None - # Extract a single axes from array if span is provided # Otherwise, pass the array as-is for normal legend behavior + # Automatically collect handles and labels from spanned axes if not provided if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)): + # Auto-collect handles and labels if not explicitly provided + if handles is None and labels is None: + handles, labels = [], [] + for axi in ax: + h, l = axi.get_legend_handles_labels() + handles.extend(h) + labels.extend(l) try: ax_single = next(iter(ax)) except (TypeError, StopIteration): diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index 48a40a67..6b984a55 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -483,3 +483,49 @@ def test_legend_multiple_sides_with_span(): assert leg_top is not None assert leg_right is not None assert leg_left is not None + + +def test_legend_auto_collect_handles_labels_with_span(): + """Test automatic collection of handles and labels from multiple axes with span parameters.""" + + fig, axs = uplt.subplots(nrows=2, ncols=2) + + # Create different plots in each subplot with labels + axs[0, 0].plot([0, 1], [0, 1], label="line1") + axs[0, 1].plot([0, 1], [1, 0], label="line2") + axs[1, 0].scatter([0.5], [0.5], label="point1") + axs[1, 1].scatter([0.5], [0.5], label="point2") + + # Test automatic collection with span parameter (no explicit handles/labels) + leg = fig.legend(ax=axs[0, :], span=(1, 2), loc="bottom") + + # Verify legend was created and contains all handles/labels from both axes + assert leg is not None + assert len(leg.get_texts()) == 2 # Should have 2 labels (line1, line2) + + # Test with rows parameter + leg2 = fig.legend(ax=axs[:, 0], rows=(1, 2), loc="right") + assert leg2 is not None + assert len(leg2.get_texts()) == 2 # Should have 2 labels (line1, point1) + + +def test_legend_explicit_handles_labels_override_auto_collection(): + """Test that explicit handles/labels override auto-collection.""" + + fig, axs = uplt.subplots(nrows=1, ncols=2) + + # Create plots with labels + (h1,) = axs[0].plot([0, 1], [0, 1], label="auto_label1") + (h2,) = axs[1].plot([0, 1], [1, 0], label="auto_label2") + + # Test with explicit handles/labels (should override auto-collection) + custom_handles = [h1] + custom_labels = ["custom_label"] + leg = fig.legend( + ax=axs, span=(1, 2), loc="bottom", handles=custom_handles, labels=custom_labels + ) + + # Verify legend uses explicit handles/labels, not auto-collected ones + assert leg is not None + assert len(leg.get_texts()) == 1 + assert leg.get_texts()[0].get_text() == "custom_label"