Skip to content

Commit 29e518a

Browse files
BenjaminCharmesml-evs
authored andcommitted
Fix XRD legend labels to preserve distinguishing characters
1 parent 88ed143 commit 29e518a

File tree

3 files changed

+257
-18
lines changed

3 files changed

+257
-18
lines changed

pydatalab/src/pydatalab/bokeh_plots.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from bokeh.themes import Theme
2929
from scipy.signal import find_peaks
3030

31-
from .utils import shrink_label
31+
from .utils import generate_unique_labels
3232

3333
FONTSIZE = "12pt"
3434
TYPEFACE = "Helvetica, sans-serif"
@@ -335,7 +335,6 @@ def selectable_axes_plot(
335335
if isinstance(df, dict):
336336
labels = list(df.keys())
337337

338-
label_counts: dict[str, int] = {}
339338
original_labels_list = []
340339

341340
for ind, df_ in enumerate(df):
@@ -353,10 +352,10 @@ def selectable_axes_plot(
353352
orig = df_temp.index.name if len(df) > 1 else ""
354353

355354
original_labels_list.append(orig)
356-
shrunk = shrink_label(orig)
357-
label_counts[shrunk] = label_counts.get(shrunk, 0) + 1
358355

359-
label_counter: dict[str, int] = {}
356+
legend_labels = (
357+
generate_unique_labels(original_labels_list) if len(df) > 1 else original_labels_list
358+
)
360359
plot_columns = []
361360

362361
for ind, df_ in enumerate(df):
@@ -366,23 +365,14 @@ def selectable_axes_plot(
366365
if isinstance(df, dict):
367366
df_ = df[df_]
368367

369-
df_with_metadata = df_.copy()
370-
371-
original_label = original_labels_list[ind]
372-
label = shrink_label(original_label)
373-
374-
if label and len(df) > 1 and label_counts.get(label, 0) > 1:
375-
if label not in label_counter:
376-
label_counter[label] = 0
377-
label_counter[label] += 1
378-
label = f"{label} [{label_counter[label]:02d}]"
368+
label = legend_labels[ind] if legend_labels else ""
379369

380370
if hasattr(df_, "attrs"):
381371
for attr in ["item_id", "original_filename", "wavelength"]:
382372
if attr in df_.attrs:
383-
df_with_metadata[attr] = df_.attrs[attr]
373+
df_[attr] = df_.attrs[attr]
384374

385-
source = ColumnDataSource(df_with_metadata)
375+
source = ColumnDataSource(df_)
386376

387377
if color_options:
388378
color = {"field": color_options[0], "transform": color_mapper}

pydatalab/src/pydatalab/utils.py

Lines changed: 186 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
import datetime
7+
import re
78
from json import JSONEncoder
89
from math import ceil
910

@@ -59,13 +60,33 @@ def default(o):
5960

6061

6162
def shrink_label(label: str | None, max_length: int = 15) -> str:
62-
"""Shrink label to fit within max_length, preserving file extension when possible."""
63+
"""Shrink label to fit within max_length, preserving file extension."""
6364
if not label or len(label) <= max_length:
6465
return label or ""
6566

6667
if "." in label:
6768
name, ext = label.rsplit(".", 1)
6869
if len(ext) < 6:
70+
pattern = r"(\d+)"
71+
match = re.search(pattern, name)
72+
if match:
73+
number = match.group(1)
74+
if len(number) > 4 and number.startswith("0"):
75+
number_stripped = number.lstrip("0") or "0"
76+
name_shortened = name[: match.start()] + number_stripped + name[match.end() :]
77+
78+
if len(name_shortened) + len(ext) + 1 <= max_length:
79+
return f"{name_shortened}.{ext}"
80+
81+
number_with_ext_length = len(number_stripped) + len(ext) + 1
82+
available_for_prefix = max_length - number_with_ext_length - 3
83+
if available_for_prefix >= 2:
84+
prefix = name[: match.start()][:available_for_prefix]
85+
return f"{prefix}...{number_stripped}.{ext}"
86+
87+
if number_with_ext_length <= max_length:
88+
return f"{number_stripped}.{ext}"
89+
6990
available = max_length - len(ext) - 4
7091
if available > 3:
7192
return f"{name[:available]}...{ext}"
@@ -75,3 +96,167 @@ def shrink_label(label: str | None, max_length: int = 15) -> str:
7596
return f"{label[:12]}..."
7697
else:
7798
return f"{label[:12]}..."
99+
100+
101+
def generate_unique_labels(
102+
filenames: list[str],
103+
max_length: int = 15,
104+
) -> list[str]:
105+
if not filenames or len(filenames) == 1:
106+
return filenames if filenames else []
107+
108+
common_prefix = _find_common_prefix_smart(filenames)
109+
common_suffix = _find_common_suffix_smart(filenames)
110+
111+
extension = ""
112+
if all("." in f for f in filenames):
113+
extensions = [f.rsplit(".", 1)[1] for f in filenames]
114+
if len(set(extensions)) == 1:
115+
extension = f".{extensions[0]}"
116+
117+
unique_parts = []
118+
for filename in filenames:
119+
start_idx = len(common_prefix)
120+
end_idx = len(filename) - len(common_suffix)
121+
unique_part = filename[start_idx:end_idx] if start_idx < end_idx else filename
122+
123+
if not unique_part.strip():
124+
unique_part = filename
125+
126+
if extension and unique_part.endswith(extension):
127+
unique_part = unique_part[: -len(extension)]
128+
129+
unique_part_without_ext = unique_part
130+
131+
if len(unique_part_without_ext) < 5 and len(filename) <= max_length:
132+
unique_parts.append(filename)
133+
else:
134+
if extension:
135+
unique_part = unique_part + extension
136+
unique_parts.append(unique_part)
137+
138+
labels = []
139+
for i, part in enumerate(unique_parts):
140+
shrunken = shrink_label(part, max_length)
141+
142+
if "." in shrunken:
143+
name_part, ext_part = shrunken.rsplit(".", 1)
144+
if name_part.replace("0", "").isdigit() and name_part.startswith("0"):
145+
stripped = name_part.lstrip("0") or "0"
146+
if len(stripped) < 6 and common_prefix:
147+
available = max_length - len(stripped) - len(ext_part) - 4
148+
prefix_length = min(available, 4)
149+
if prefix_length >= 2:
150+
prefix_to_add = common_prefix[:prefix_length].rstrip("-_. /\\")
151+
if prefix_to_add:
152+
shrunken = f"{prefix_to_add}...{stripped}.{ext_part}"
153+
154+
if len(shrunken) < 8 and common_prefix:
155+
available = max_length - len(shrunken) - 3
156+
prefix_length = min(available, 4)
157+
if prefix_length >= 2:
158+
prefix_to_add = common_prefix[:prefix_length].rstrip("-_. /\\")
159+
if prefix_to_add and "..." not in shrunken:
160+
shrunken = f"{prefix_to_add}...{shrunken}"
161+
162+
labels.append(shrunken)
163+
164+
return _add_numbering_for_duplicates(labels)
165+
166+
167+
def _find_common_prefix_smart(strings: list[str]) -> str:
168+
if not strings or len(strings) < 2:
169+
return ""
170+
171+
prefix = _find_common_prefix(strings)
172+
173+
if not prefix:
174+
return ""
175+
176+
if prefix[-1] in ("-", "_", " ", "/", "\\"):
177+
return prefix
178+
179+
last_sep = max(
180+
prefix.rfind("-"),
181+
prefix.rfind("_"),
182+
prefix.rfind(" "),
183+
prefix.rfind("/"),
184+
prefix.rfind("\\"),
185+
)
186+
187+
if last_sep > 0:
188+
return prefix[: last_sep + 1]
189+
190+
return ""
191+
192+
193+
def _find_common_suffix_smart(strings: list[str]) -> str:
194+
if not strings or len(strings) < 2:
195+
return ""
196+
197+
suffix = _find_common_suffix(strings)
198+
199+
if not suffix:
200+
return ""
201+
202+
if suffix.startswith("."):
203+
return ""
204+
205+
if suffix[0] in ("-", "_", " ", "/", "\\"):
206+
return suffix
207+
208+
first_sep = len(suffix)
209+
for sep in ("-", "_", " ", "/", "\\"):
210+
pos = suffix.find(sep)
211+
if pos != -1 and pos < first_sep:
212+
first_sep = pos
213+
214+
if first_sep < len(suffix):
215+
return suffix[first_sep:]
216+
217+
return ""
218+
219+
220+
def _find_common_prefix(strings: list[str]) -> str:
221+
if not strings or len(strings) < 2:
222+
return ""
223+
224+
min_str = min(strings)
225+
max_str = max(strings)
226+
227+
for i, char in enumerate(min_str):
228+
if char != max_str[i]:
229+
return min_str[:i]
230+
231+
return min_str
232+
233+
234+
def _find_common_suffix(strings: list[str]) -> str:
235+
if not strings or len(strings) < 2:
236+
return ""
237+
238+
reversed_strings = [s[::-1] for s in strings]
239+
common_reversed_prefix = _find_common_prefix(reversed_strings)
240+
241+
return common_reversed_prefix[::-1]
242+
243+
244+
def _add_numbering_for_duplicates(labels: list[str]) -> list[str]:
245+
label_counts: dict[str, int] = {}
246+
for label in labels:
247+
label_counts[label] = label_counts.get(label, 0) + 1
248+
249+
if all(count == 1 for count in label_counts.values()):
250+
return labels
251+
252+
label_counter: dict[str, int] = {}
253+
numbered_labels = []
254+
255+
for label in labels:
256+
if label_counts[label] > 1:
257+
label_counter[label] = label_counter.get(label, 0) + 1
258+
numbered_labels.append(f"{label} [{label_counter[label]:02d}]")
259+
else:
260+
numbered_labels.append(label)
261+
262+
return numbered_labels

pydatalab/tests/test_utils.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from pydatalab.utils import generate_unique_labels
2+
3+
4+
def test_generate_unique_labels_single_file():
5+
result = generate_unique_labels(["sample_xrd_pattern.cif"])
6+
assert result == ["sample_xrd_pattern.cif"]
7+
8+
9+
def test_generate_unique_labels_empty():
10+
result = generate_unique_labels([])
11+
assert result == []
12+
13+
14+
def test_generate_unique_labels_common_suffix():
15+
filenames = ["sample1-xrd.xrdml", "sample2-xrd.xrdml"]
16+
result = generate_unique_labels(filenames)
17+
assert result == ["sample1.xrdml", "sample2.xrdml"]
18+
19+
20+
def test_generate_unique_labels_prefix_and_suffix():
21+
filenames = [
22+
"experiment_run1_final.dat",
23+
"experiment_run2_final.dat",
24+
"experiment_run3_final.dat",
25+
]
26+
result = generate_unique_labels(filenames)
27+
assert result == ["run1.dat", "run2.dat", "run3.dat"]
28+
29+
30+
def test_generate_unique_labels_long_unique_parts():
31+
filenames = [
32+
"very_long_sample_name_with_many_characters_001.cif",
33+
"very_long_sample_name_with_many_characters_002.cif",
34+
]
35+
result = generate_unique_labels(filenames, max_length=10)
36+
assert all(len(label) <= 15 for label in result)
37+
assert result[0] != result[1]
38+
39+
40+
def test_generate_unique_labels_duplicates_after_shortening():
41+
filenames = [
42+
"CIF_0000000000000001.xrdml",
43+
"CIF_0000000000000002.xrdml",
44+
]
45+
result = generate_unique_labels(filenames, max_length=15)
46+
assert result == ["CIF...1.xrdml", "CIF...2.xrdml"]
47+
48+
49+
def test_generate_unique_labels_common_prefix():
50+
filenames = ["ICSDCollCode-000002.cif", "ICSDCollCode-000003.cif"]
51+
result = generate_unique_labels(filenames)
52+
assert result == ["ICSD...2.cif", "ICSD...3.cif"]
53+
54+
55+
def test_generate_unique_labels_same_extension():
56+
filenames = ["sample_A.cif", "sample_B.cif", "sample_C.cif"]
57+
result = generate_unique_labels(filenames)
58+
assert result == ["sample_A.cif", "sample_B.cif", "sample_C.cif"]
59+
60+
61+
def test_generate_unique_labels_cif_pattern():
62+
filenames = ["CIF_00000001.cif", "CIF_00000002.cif"]
63+
result = generate_unique_labels(filenames)
64+
assert result == ["CIF...1.cif", "CIF...2.cif"]

0 commit comments

Comments
 (0)