Skip to content

Commit 1e99975

Browse files
BenjaminCharmesml-evs
authored andcommitted
Fix XRD legend issues with duplicate labels and point/line coupling
1 parent 60e25bf commit 1e99975

File tree

1 file changed

+37
-8
lines changed

1 file changed

+37
-8
lines changed

pydatalab/src/pydatalab/bokeh_plots.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,28 @@ def selectable_axes_plot(
335335
if isinstance(df, dict):
336336
labels = list(df.keys())
337337

338+
label_counts: dict[str, int] = {}
339+
original_labels_list = []
340+
341+
for ind, df_ in enumerate(df):
342+
if isinstance(df, dict):
343+
df_temp = df[df_]
344+
else:
345+
df_temp = df_
346+
347+
if labels:
348+
orig = labels[ind]
349+
else:
350+
if hasattr(df_temp, "attrs") and "original_filename" in df_temp.attrs:
351+
orig = df_temp.attrs["original_filename"] if len(df) > 1 else ""
352+
else:
353+
orig = df_temp.index.name if len(df) > 1 else ""
354+
355+
original_labels_list.append(orig)
356+
shrunk = shrink_label(orig)
357+
label_counts[shrunk] = label_counts.get(shrunk, 0) + 1
358+
359+
label_counter: dict[str, int] = {}
338360
plot_columns = []
339361

340362
for ind, df_ in enumerate(df):
@@ -344,14 +366,23 @@ def selectable_axes_plot(
344366
if isinstance(df, dict):
345367
df_ = df[df_]
346368

347-
if labels:
348-
label = labels[ind]
349-
else:
350-
label = df_.index.name if len(df) > 1 else ""
369+
df_with_metadata = df_.copy()
351370

352-
label = shrink_label(label)
371+
original_label = original_labels_list[ind]
372+
label = shrink_label(original_label)
353373

354-
source = ColumnDataSource(df_)
374+
if label and len(df) > 1 and label_counts.get(label, 0) > 1:
375+
if label not in label_counter:
376+
label_counter[label] = 0
377+
label_counter[label] += 1
378+
label = f"{label} [{label_counter[label]:02d}]"
379+
380+
if hasattr(df_, "attrs"):
381+
for attr in ["item_id", "original_filename", "wavelength"]:
382+
if attr in df_.attrs:
383+
df_with_metadata[attr] = df_.attrs[attr]
384+
385+
source = ColumnDataSource(df_with_metadata)
355386

356387
if color_options:
357388
color = {"field": color_options[0], "transform": color_mapper}
@@ -383,7 +414,6 @@ def selectable_axes_plot(
383414
size=point_size,
384415
line_color=color,
385416
fill_color=fill_color,
386-
legend_label=label,
387417
hatch_pattern=hatch_patterns[ind % len(hatch_patterns)],
388418
hatch_color=color,
389419
)
@@ -412,7 +442,6 @@ def selectable_axes_plot(
412442
y=y,
413443
source=source,
414444
color=color,
415-
legend_label=label,
416445
alpha=0.3,
417446
)
418447
if plot_line

0 commit comments

Comments
 (0)