Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pydatalab/src/pydatalab/apps/xrd/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,9 @@ def _make_plots(self, pattern_dfs: list[pd.DataFrame], y_options: list[str]):
return selectable_axes_plot(
pattern_dfs,
x_options=["2θ (°)", "Q (Å⁻¹)", "d (Å)"],
y_default="normalized intensity",
y_default="normalized intensity (staggered)"
if len(pattern_dfs) > 1
else "normalized intensity",
y_options=y_options,
plot_line=True,
plot_points=True,
Expand Down
38 changes: 28 additions & 10 deletions pydatalab/src/pydatalab/bokeh_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from bokeh.themes import Theme
from scipy.signal import find_peaks

from .utils import shrink_label
from .utils import generate_unique_labels

FONTSIZE = "12pt"
TYPEFACE = "Helvetica, sans-serif"
Expand Down Expand Up @@ -335,6 +335,27 @@ def selectable_axes_plot(
if isinstance(df, dict):
labels = list(df.keys())

original_labels_list = []

for ind, df_ in enumerate(df):
if isinstance(df, dict):
df_temp = df[df_]
else:
df_temp = df_

if labels:
orig = labels[ind]
else:
if hasattr(df_temp, "attrs") and "original_filename" in df_temp.attrs:
orig = df_temp.attrs["original_filename"] if len(df) > 1 else ""
else:
orig = df_temp.index.name if len(df) > 1 else ""

original_labels_list.append(orig)

legend_labels = (
generate_unique_labels(original_labels_list) if len(df) > 1 else original_labels_list
)
plot_columns = []

for ind, df_ in enumerate(df):
Expand All @@ -344,12 +365,12 @@ def selectable_axes_plot(
if isinstance(df, dict):
df_ = df[df_]

if labels:
label = labels[ind]
else:
label = df_.index.name if len(df) > 1 else ""
label = legend_labels[ind] if legend_labels else ""

label = shrink_label(label)
if hasattr(df_, "attrs"):
for attr in ["item_id", "original_filename", "wavelength"]:
if attr in df_.attrs:
df_[attr] = df_.attrs[attr]

source = ColumnDataSource(df_)

Expand Down Expand Up @@ -383,7 +404,6 @@ def selectable_axes_plot(
size=point_size,
line_color=color,
fill_color=fill_color,
legend_label=label,
hatch_pattern=hatch_patterns[ind % len(hatch_patterns)],
hatch_color=color,
)
Expand Down Expand Up @@ -412,7 +432,6 @@ def selectable_axes_plot(
y=y,
source=source,
color=color,
legend_label=label,
alpha=0.3,
)
if plot_line
Expand Down Expand Up @@ -457,9 +476,8 @@ def selectable_axes_plot(

external_legend = Legend(
items=legend_items,
click_policy="hide",
click_policy="none",
background_fill_alpha=0.8,
label_text_font_size="9pt",
spacing=1,
margin=2,
)
Expand Down
187 changes: 186 additions & 1 deletion pydatalab/src/pydatalab/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import datetime
import re
from json import JSONEncoder
from math import ceil

Expand Down Expand Up @@ -59,13 +60,33 @@ def default(o):


def shrink_label(label: str | None, max_length: int = 15) -> str:
"""Shrink label to fit within max_length, preserving file extension when possible."""
"""Shrink label to fit within max_length, preserving file extension."""
if not label or len(label) <= max_length:
return label or ""

if "." in label:
name, ext = label.rsplit(".", 1)
if len(ext) < 6:
pattern = r"(\d+)"
match = re.search(pattern, name)
if match:
number = match.group(1)
if len(number) > 4 and number.startswith("0"):
number_stripped = number.lstrip("0") or "0"
name_shortened = name[: match.start()] + number_stripped + name[match.end() :]

if len(name_shortened) + len(ext) + 1 <= max_length:
return f"{name_shortened}.{ext}"

number_with_ext_length = len(number_stripped) + len(ext) + 1
available_for_prefix = max_length - number_with_ext_length - 3
if available_for_prefix >= 2:
prefix = name[: match.start()][:available_for_prefix]
return f"{prefix}...{number_stripped}.{ext}"

if number_with_ext_length <= max_length:
return f"{number_stripped}.{ext}"

available = max_length - len(ext) - 4
if available > 3:
return f"{name[:available]}...{ext}"
Expand All @@ -75,3 +96,167 @@ def shrink_label(label: str | None, max_length: int = 15) -> str:
return f"{label[:12]}..."
else:
return f"{label[:12]}..."


def generate_unique_labels(
filenames: list[str],
max_length: int = 15,
) -> list[str]:
if not filenames or len(filenames) == 1:
return filenames if filenames else []

common_prefix = _find_common_prefix_smart(filenames)
common_suffix = _find_common_suffix_smart(filenames)

extension = ""
if all("." in f for f in filenames):
extensions = [f.rsplit(".", 1)[1] for f in filenames]
if len(set(extensions)) == 1:
extension = f".{extensions[0]}"

unique_parts = []
for filename in filenames:
start_idx = len(common_prefix)
end_idx = len(filename) - len(common_suffix)
unique_part = filename[start_idx:end_idx] if start_idx < end_idx else filename

if not unique_part.strip():
unique_part = filename

if extension and unique_part.endswith(extension):
unique_part = unique_part[: -len(extension)]

unique_part_without_ext = unique_part

if len(unique_part_without_ext) < 5 and len(filename) <= max_length:
unique_parts.append(filename)
else:
if extension:
unique_part = unique_part + extension
unique_parts.append(unique_part)

labels = []
for i, part in enumerate(unique_parts):
shrunken = shrink_label(part, max_length)

if "." in shrunken:
name_part, ext_part = shrunken.rsplit(".", 1)
if name_part.replace("0", "").isdigit() and name_part.startswith("0"):
stripped = name_part.lstrip("0") or "0"
if len(stripped) < 6 and common_prefix:
available = max_length - len(stripped) - len(ext_part) - 4
prefix_length = min(available, 4)
if prefix_length >= 2:
prefix_to_add = common_prefix[:prefix_length].rstrip("-_. /\\")
if prefix_to_add:
shrunken = f"{prefix_to_add}...{stripped}.{ext_part}"

if len(shrunken) < 8 and common_prefix:
available = max_length - len(shrunken) - 3
prefix_length = min(available, 4)
if prefix_length >= 2:
prefix_to_add = common_prefix[:prefix_length].rstrip("-_. /\\")
if prefix_to_add and "..." not in shrunken:
shrunken = f"{prefix_to_add}...{shrunken}"

labels.append(shrunken)

return _add_numbering_for_duplicates(labels)


def _find_common_prefix_smart(strings: list[str]) -> str:
if not strings or len(strings) < 2:
return ""

prefix = _find_common_prefix(strings)

if not prefix:
return ""

if prefix[-1] in ("-", "_", " ", "/", "\\"):
return prefix

last_sep = max(
prefix.rfind("-"),
prefix.rfind("_"),
prefix.rfind(" "),
prefix.rfind("/"),
prefix.rfind("\\"),
)

if last_sep > 0:
return prefix[: last_sep + 1]

return ""


def _find_common_suffix_smart(strings: list[str]) -> str:
if not strings or len(strings) < 2:
return ""

suffix = _find_common_suffix(strings)

if not suffix:
return ""

if suffix.startswith("."):
return ""

if suffix[0] in ("-", "_", " ", "/", "\\"):
return suffix

first_sep = len(suffix)
for sep in ("-", "_", " ", "/", "\\"):
pos = suffix.find(sep)
if pos != -1 and pos < first_sep:
first_sep = pos

if first_sep < len(suffix):
return suffix[first_sep:]

return ""


def _find_common_prefix(strings: list[str]) -> str:
if not strings or len(strings) < 2:
return ""

min_str = min(strings)
max_str = max(strings)

for i, char in enumerate(min_str):
if char != max_str[i]:
return min_str[:i]

return min_str


def _find_common_suffix(strings: list[str]) -> str:
if not strings or len(strings) < 2:
return ""

reversed_strings = [s[::-1] for s in strings]
common_reversed_prefix = _find_common_prefix(reversed_strings)

return common_reversed_prefix[::-1]


def _add_numbering_for_duplicates(labels: list[str]) -> list[str]:
label_counts: dict[str, int] = {}
for label in labels:
label_counts[label] = label_counts.get(label, 0) + 1

if all(count == 1 for count in label_counts.values()):
return labels

label_counter: dict[str, int] = {}
numbered_labels = []

for label in labels:
if label_counts[label] > 1:
label_counter[label] = label_counter.get(label, 0) + 1
numbered_labels.append(f"{label} [{label_counter[label]:02d}]")
else:
numbered_labels.append(label)

return numbered_labels
64 changes: 64 additions & 0 deletions pydatalab/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from pydatalab.utils import generate_unique_labels


def test_generate_unique_labels_single_file():
result = generate_unique_labels(["sample_xrd_pattern.cif"])
assert result == ["sample_xrd_pattern.cif"]


def test_generate_unique_labels_empty():
result = generate_unique_labels([])
assert result == []


def test_generate_unique_labels_common_suffix():
filenames = ["sample1-xrd.xrdml", "sample2-xrd.xrdml"]
result = generate_unique_labels(filenames)
assert result == ["sample1.xrdml", "sample2.xrdml"]


def test_generate_unique_labels_prefix_and_suffix():
filenames = [
"experiment_run1_final.dat",
"experiment_run2_final.dat",
"experiment_run3_final.dat",
]
result = generate_unique_labels(filenames)
assert result == ["run1.dat", "run2.dat", "run3.dat"]


def test_generate_unique_labels_long_unique_parts():
filenames = [
"very_long_sample_name_with_many_characters_001.cif",
"very_long_sample_name_with_many_characters_002.cif",
]
result = generate_unique_labels(filenames, max_length=10)
assert all(len(label) <= 15 for label in result)
assert result[0] != result[1]


def test_generate_unique_labels_duplicates_after_shortening():
filenames = [
"CIF_0000000000000001.xrdml",
"CIF_0000000000000002.xrdml",
]
result = generate_unique_labels(filenames, max_length=15)
assert result == ["CIF...1.xrdml", "CIF...2.xrdml"]


def test_generate_unique_labels_common_prefix():
filenames = ["ICSDCollCode-000002.cif", "ICSDCollCode-000003.cif"]
result = generate_unique_labels(filenames)
assert result == ["ICSD...2.cif", "ICSD...3.cif"]


def test_generate_unique_labels_same_extension():
filenames = ["sample_A.cif", "sample_B.cif", "sample_C.cif"]
result = generate_unique_labels(filenames)
assert result == ["sample_A.cif", "sample_B.cif", "sample_C.cif"]


def test_generate_unique_labels_cif_pattern():
filenames = ["CIF_00000001.cif", "CIF_00000002.cif"]
result = generate_unique_labels(filenames)
assert result == ["CIF...1.cif", "CIF...2.cif"]