From 52b0cb886a542a5789fce14130edff00be0ce0a3 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sat, 17 Jan 2026 20:20:54 +0100 Subject: [PATCH] fix issues in plotting tests --- tests/test_plot.py | 96 ++++++++++++++++++++++------------------------ 1 file changed, 46 insertions(+), 50 deletions(-) diff --git a/tests/test_plot.py b/tests/test_plot.py index 2bafa42b..11d31828 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -190,8 +190,8 @@ def test_compare_models_with_different_versions(matplotlib_version): # minimum version of matplotlib minimum_version = "3.6" - if packaging.version.parse( - matplotlib_version) < packaging.version.parse(minimum_version): + if packaging.version.parse(matplotlib_version) < packaging.version.parse( + minimum_version): with pytest.raises(ImportError): cebra_plot.compare_models(models=fitted_models, patched_version=matplotlib_version) @@ -359,9 +359,6 @@ def test_plot_consistency(): dataset_ids = ["achilles", "buddy", "cicero", "gatsby"] - figure = plt.figure(figsize=(5, 5)) - ax = figure.add_subplot() - scores_subs, pairs_subs, datasets_subs = cebra_sklearn_metrics.consistency_score( embeddings_datasets, labels=labels_datasets, @@ -371,17 +368,21 @@ def test_plot_consistency(): scores_runs, pairs_runs, datasets_runs = cebra_sklearn_metrics.consistency_score( embeddings_runs, between="runs") + # ------------------------------------------------------------ + # between datasets - fig = cebra_plot.plot_consistency(scores_subs, - pairs=pairs_subs, - datasets=datasets_subs) - assert isinstance(fig, matplotlib.axes.Axes) + ax = cebra_plot.plot_consistency(scores_subs, + pairs=pairs_subs, + datasets=datasets_subs) + assert isinstance(ax, matplotlib.axes.Axes) plt.close() + ax = cebra_plot.plot_consistency(scores_subs, pairs=pairs_subs, - datasets=datasets_subs, - ax=ax) + datasets=datasets_subs) assert isinstance(ax, matplotlib.axes.Axes) + plt.close() + ax = cebra_plot.plot_consistency( torch.from_numpy(scores_subs), pairs=pairs_subs, @@ -390,119 +391,114 @@ def test_plot_consistency(): title="Test", text_color=None, colorbar_label=None, - ax=ax, ) - assert isinstance(fig, matplotlib.axes.Axes) + assert isinstance(ax, matplotlib.axes.Axes) + plt.close() ax = cebra_plot.plot_consistency(torch.from_numpy(scores_subs), pairs=pairs_subs, - datasets=datasets_subs, - ax=ax) - assert isinstance(fig, matplotlib.axes.Axes) + datasets=datasets_subs) + assert isinstance(ax, matplotlib.axes.Axes) + plt.close() ax = cebra_plot.plot_consistency( scores_subs.tolist(), pairs=pairs_subs.tolist(), datasets=datasets_subs.tolist(), - ax=ax, ) - assert isinstance(fig, matplotlib.axes.Axes) + assert isinstance(ax, matplotlib.axes.Axes) + plt.close() with pytest.raises(ValueError, match="Missing.*datasets.*pairs"): - _ = cebra_plot.plot_consistency(scores_subs, ax=ax) + _ = cebra_plot.plot_consistency(scores_subs) with pytest.raises(ValueError, match="Missing.*datasets.*pairs"): - _ = cebra_plot.plot_consistency(scores_subs, pairs=pairs_subs, ax=ax) + _ = cebra_plot.plot_consistency(scores_subs, pairs=pairs_subs) with pytest.raises(ValueError, match="Missing.*datasets.*pairs"): - _ = cebra_plot.plot_consistency(scores_subs, - datasets=datasets_subs, - ax=ax) + _ = cebra_plot.plot_consistency(scores_subs, datasets=datasets_subs) with pytest.raises(ValueError, match="Shape.*pairs"): _ = cebra_plot.plot_consistency( scores_subs, pairs=np.random.uniform(0, 1, (10, 2)), datasets=datasets_subs, - ax=ax, ) with pytest.raises(ValueError, match="Shape.*datasets"): _ = cebra_plot.plot_consistency( scores_subs, pairs=np.random.uniform(0, 1, (10, 2)), datasets=np.random.uniform(0, 1, (2,)), - ax=ax, ) with pytest.raises(ValueError, match="Invalid.*scores"): _ = cebra_plot.plot_consistency( np.random.uniform(0, 1, (12, 2, 2)), pairs=pairs_subs, datasets=datasets_subs, - ax=ax, ) + plt.close("all") # between runs - fig = cebra_plot.plot_consistency(scores_runs, - pairs=pairs_runs, - datasets=datasets_runs) - assert isinstance(fig, matplotlib.axes.Axes) - plt.close() ax = cebra_plot.plot_consistency(scores_runs, pairs=pairs_runs, - datasets=datasets_runs, - ax=ax) + datasets=datasets_runs) assert isinstance(ax, matplotlib.axes.Axes) + plt.close() + ax = cebra_plot.plot_consistency( scores_runs, pairs=pairs_runs, datasets=datasets_runs, - cmap="viridis", - title="Test", - text_color=None, - colorbar_label=None, - ax=ax, ) assert isinstance(ax, matplotlib.axes.Axes) - ax = cebra_plot.plot_consistency(torch.from_numpy(scores_runs), + plt.close() + + ax = cebra_plot.plot_consistency(scores_runs, pairs=pairs_runs, datasets=datasets_runs, - ax=ax) + cmap="viridis", + title="Test", + text_color=None, + colorbar_label=None) assert isinstance(ax, matplotlib.axes.Axes) + plt.close() + + ax = cebra_plot.plot_consistency(torch.from_numpy(scores_runs), + pairs=pairs_runs, + datasets=datasets_runs) + assert isinstance(ax, matplotlib.axes.Axes) + plt.close() + ax = cebra_plot.plot_consistency( scores_runs.tolist(), pairs=pairs_runs.tolist(), datasets=datasets_runs.tolist(), - ax=ax, ) assert isinstance(ax, matplotlib.axes.Axes) + plt.close() with pytest.raises(ValueError, match="Missing.*datasets.*pairs"): - _ = cebra_plot.plot_consistency(scores_runs, ax=ax) + _ = cebra_plot.plot_consistency(scores_runs) with pytest.raises(ValueError, match="Missing.*datasets.*pairs"): - _ = cebra_plot.plot_consistency(scores_runs, pairs=pairs_runs, ax=ax) + _ = cebra_plot.plot_consistency(scores_runs, pairs=pairs_runs) with pytest.raises(ValueError, match="Missing.*datasets.*pairs"): - _ = cebra_plot.plot_consistency(scores_runs, - datasets=datasets_runs, - ax=ax) + _ = cebra_plot.plot_consistency(scores_runs, datasets=datasets_runs) with pytest.raises(ValueError, match="Shape.*datasets"): _ = cebra_plot.plot_consistency( scores_runs, pairs=np.random.uniform(0, 1, (10, 2)), datasets=np.random.uniform(0, 1, (4,)), - ax=ax, ) with pytest.raises(ValueError, match="Shape.*pairs"): _ = cebra_plot.plot_consistency( scores_runs, pairs=np.random.uniform(0, 1, (10, 2)), datasets=datasets_runs, - ax=ax, ) with pytest.raises(ValueError, match="Invalid.*dimensions"): _ = cebra_plot.plot_consistency( np.random.uniform(0, 1, (12, 2, 2)), pairs=pairs_runs, datasets=datasets_runs, - ax=ax, ) - plt.close() + plt.close("all") @pytest.mark.parametrize("seed", [None, 42, 1024, 454545])