diff --git a/pydatalab/src/pydatalab/apps/xrd/blocks.py b/pydatalab/src/pydatalab/apps/xrd/blocks.py
index 2bbc68564..c4744ef11 100644
--- a/pydatalab/src/pydatalab/apps/xrd/blocks.py
+++ b/pydatalab/src/pydatalab/apps/xrd/blocks.py
@@ -322,7 +322,9 @@ def _make_plots(self, pattern_dfs: list[pd.DataFrame], y_options: list[str]):
         return selectable_axes_plot(
             pattern_dfs,
             x_options=["2θ (°)", "Q (Å⁻¹)", "d (Å)"],
-            y_default="normalized intensity",
+            y_default="normalized intensity (staggered)"
+            if len(pattern_dfs) > 1
+            else "normalized intensity",
             y_options=y_options,
             plot_line=True,
             plot_points=True,
diff --git a/pydatalab/src/pydatalab/bokeh_plots.py b/pydatalab/src/pydatalab/bokeh_plots.py
index 82cb303da..6d5554633 100644
--- a/pydatalab/src/pydatalab/bokeh_plots.py
+++ b/pydatalab/src/pydatalab/bokeh_plots.py
@@ -28,7 +28,7 @@
 from bokeh.themes import Theme
 from scipy.signal import find_peaks
 
-from .utils import shrink_label
+from .utils import generate_unique_labels
 
 FONTSIZE = "12pt"
 TYPEFACE = "Helvetica, sans-serif"
@@ -335,6 +335,27 @@ def selectable_axes_plot(
 
     if isinstance(df, dict):
         labels = list(df.keys())
+    original_labels_list = []
+
+    for ind, df_ in enumerate(df):
+        if isinstance(df, dict):
+            df_temp = df[df_]
+        else:
+            df_temp = df_
+
+        if labels:
+            orig = labels[ind]
+        else:
+            if hasattr(df_temp, "attrs") and "original_filename" in df_temp.attrs:
+                orig = df_temp.attrs["original_filename"] if len(df) > 1 else ""
+            else:
+                orig = df_temp.index.name if len(df) > 1 else ""
+
+        original_labels_list.append(orig)
+
+    legend_labels = (
+        generate_unique_labels(original_labels_list) if len(df) > 1 else original_labels_list
+    )
 
     plot_columns = []
     for ind, df_ in enumerate(df):
@@ -344,12 +365,12 @@ def selectable_axes_plot(
         if isinstance(df, dict):
             df_ = df[df_]
 
-        if labels:
-            label = labels[ind]
-        else:
-            label = df_.index.name if len(df) > 1 else ""
+        label = legend_labels[ind] if legend_labels else ""
 
-        label = shrink_label(label)
+        if hasattr(df_, "attrs"):
+            for attr in ["item_id", "original_filename", "wavelength"]:
+                if attr in df_.attrs:
+                    df_[attr] = df_.attrs[attr]
 
         source = ColumnDataSource(df_)
 
@@ -383,7 +404,6 @@ def selectable_axes_plot(
                 size=point_size,
                 line_color=color,
                 fill_color=fill_color,
-                legend_label=label,
                 hatch_pattern=hatch_patterns[ind % len(hatch_patterns)],
                 hatch_color=color,
             )
@@ -412,7 +432,6 @@ def selectable_axes_plot(
                 y=y,
                 source=source,
                 color=color,
-                legend_label=label,
                 alpha=0.3,
             )
             if plot_line
@@ -457,9 +476,8 @@ def selectable_axes_plot(
 
         external_legend = Legend(
             items=legend_items,
-            click_policy="hide",
+            click_policy="none",
             background_fill_alpha=0.8,
-            label_text_font_size="9pt",
             spacing=1,
             margin=2,
         )
diff --git a/pydatalab/src/pydatalab/utils.py b/pydatalab/src/pydatalab/utils.py
index 7ef676062..1cb8c5d0a 100644
--- a/pydatalab/src/pydatalab/utils.py
+++ b/pydatalab/src/pydatalab/utils.py
@@ -4,6 +4,7 @@
 """
 
 import datetime
+import re
 from json import JSONEncoder
 from math import ceil
 
@@ -59,13 +60,33 @@ def default(o):
 
 
 def shrink_label(label: str | None, max_length: int = 15) -> str:
-    """Shrink label to fit within max_length, preserving file extension when possible."""
+    """Shrink label to fit within max_length, preserving file extension."""
     if not label or len(label) <= max_length:
         return label or ""
 
     if "." in label:
         name, ext = label.rsplit(".", 1)
         if len(ext) < 6:
+            pattern = r"(\d+)"
+            match = re.search(pattern, name)
+            if match:
+                number = match.group(1)
+                if len(number) > 4 and number.startswith("0"):
+                    number_stripped = number.lstrip("0") or "0"
+                    name_shortened = name[: match.start()] + number_stripped + name[match.end() :]
+
+                    if len(name_shortened) + len(ext) + 1 <= max_length:
+                        return f"{name_shortened}.{ext}"
+
+                    number_with_ext_length = len(number_stripped) + len(ext) + 1
+                    available_for_prefix = max_length - number_with_ext_length - 3
+                    if available_for_prefix >= 2:
+                        prefix = name[: match.start()][:available_for_prefix]
+                        return f"{prefix}...{number_stripped}.{ext}"
+
+                    if number_with_ext_length <= max_length:
+                        return f"{number_stripped}.{ext}"
+
             available = max_length - len(ext) - 4
             if available > 3:
                 return f"{name[:available]}...{ext}"
@@ -75,3 +96,220 @@ def shrink_label(label: str | None, max_length: int = 15) -> str:
             return f"{label[:12]}..."
     else:
         return f"{label[:12]}..."
+
+
+def generate_unique_labels(
+    filenames: list[str | None],
+    max_length: int = 15,
+) -> list[str]:
+    """Generate short, distinguishable labels for a collection of filenames.
+
+    The prefix and suffix shared by all names are cut back to the nearest
+    separator and stripped, the remainder is passed through `shrink_label`,
+    and labels that still collide are numbered.
+
+    Parameters:
+        filenames: The original labels. `None` entries (e.g., the name of an
+            unnamed index) are treated as empty strings.
+        max_length: The length limit handed to `shrink_label`; numbering
+            added for duplicate labels is not counted against it.
+
+    Returns:
+        One label per entry of `filenames`, in the same order.
+
+    """
+    if not filenames:
+        return []
+
+    # The prefix/suffix search compares the entries as strings and would raise
+    # a `TypeError` on `None`, so normalise everything up front.
+    filenames = ["" if f is None else str(f) for f in filenames]
+
+    if len(filenames) == 1:
+        return filenames
+
+    common_prefix = _find_common_prefix_smart(filenames)
+    common_suffix = _find_common_suffix_smart(filenames)
+
+    extension = ""
+    if all("." in f for f in filenames):
+        extensions = [f.rsplit(".", 1)[1] for f in filenames]
+        if len(set(extensions)) == 1:
+            extension = f".{extensions[0]}"
+
+    unique_parts = []
+    for filename in filenames:
+        start_idx = len(common_prefix)
+        end_idx = len(filename) - len(common_suffix)
+        unique_part = filename[start_idx:end_idx] if start_idx < end_idx else filename
+
+        if not unique_part.strip():
+            unique_part = filename
+
+        if extension and unique_part.endswith(extension):
+            unique_part = unique_part[: -len(extension)]
+
+        # Keep the full name when the distinguishing part is shorter than five
+        # characters and the full name already fits within `max_length`.
+        if len(unique_part) < 5 and len(filename) <= max_length:
+            unique_parts.append(filename)
+        else:
+            if extension:
+                unique_part = unique_part + extension
+            unique_parts.append(unique_part)
+
+    labels = []
+    for part in unique_parts:
+        shrunken = shrink_label(part, max_length)
+
+        # A zero-padded, purely numeric name loses its padding and is prefixed
+        # with up to four characters of the common prefix.
+        if "." in shrunken:
+            name_part, ext_part = shrunken.rsplit(".", 1)
+            if name_part.replace("0", "").isdigit() and name_part.startswith("0"):
+                stripped = name_part.lstrip("0") or "0"
+                if len(stripped) < 6 and common_prefix:
+                    available = max_length - len(stripped) - len(ext_part) - 4
+                    prefix_length = min(available, 4)
+                    if prefix_length >= 2:
+                        prefix_to_add = common_prefix[:prefix_length].rstrip("-_. /\\")
+                        if prefix_to_add:
+                            shrunken = f"{prefix_to_add}...{stripped}.{ext_part}"
+
+        # Labels shorter than eight characters also get a piece of the common
+        # prefix, unless they already contain an ellipsis.
+        if len(shrunken) < 8 and common_prefix:
+            available = max_length - len(shrunken) - 3
+            prefix_length = min(available, 4)
+            if prefix_length >= 2:
+                prefix_to_add = common_prefix[:prefix_length].rstrip("-_. /\\")
+                if prefix_to_add and "..." not in shrunken:
+                    shrunken = f"{prefix_to_add}...{shrunken}"
+
+        labels.append(shrunken)
+
+    return _add_numbering_for_duplicates(labels)
+
+
+def _find_common_prefix_smart(strings: list[str]) -> str:
+    """Return the prefix shared by all `strings`, cut back to end on a separator.
+
+    Separators are `-`, `_`, space, `/` and backslash. An empty string is
+    returned when the shared prefix has no separator after its first character.
+
+    """
+    if not strings or len(strings) < 2:
+        return ""
+
+    prefix = _find_common_prefix(strings)
+
+    if not prefix:
+        return ""
+
+    if prefix[-1] in ("-", "_", " ", "/", "\\"):
+        return prefix
+
+    last_sep = max(
+        prefix.rfind("-"),
+        prefix.rfind("_"),
+        prefix.rfind(" "),
+        prefix.rfind("/"),
+        prefix.rfind("\\"),
+    )
+
+    if last_sep > 0:
+        return prefix[: last_sep + 1]
+
+    return ""
+
+
+def _find_common_suffix_smart(strings: list[str]) -> str:
+    """Return the suffix shared by all `strings`, cut forward to start on a separator.
+
+    Separators are `-`, `_`, space, `/` and backslash. An empty string is
+    returned when the shared suffix starts with `.` or contains no separator.
+
+    """
+    if not strings or len(strings) < 2:
+        return ""
+
+    suffix = _find_common_suffix(strings)
+
+    if not suffix:
+        return ""
+
+    if suffix.startswith("."):
+        return ""
+
+    if suffix[0] in ("-", "_", " ", "/", "\\"):
+        return suffix
+
+    first_sep = len(suffix)
+    for sep in ("-", "_", " ", "/", "\\"):
+        pos = suffix.find(sep)
+        if pos != -1 and pos < first_sep:
+            first_sep = pos
+
+    if first_sep < len(suffix):
+        return suffix[first_sep:]
+
+    return ""
+
+
+def _find_common_prefix(strings: list[str]) -> str:
+    """Return the longest prefix shared by every entry of `strings`.
+
+    Only the lexicographically smallest and largest entries are compared, as
+    any prefix they share is shared by everything that sorts between them.
+
+    """
+    if not strings or len(strings) < 2:
+        return ""
+
+    min_str = min(strings)
+    max_str = max(strings)
+
+    for i, char in enumerate(min_str):
+        if char != max_str[i]:
+            return min_str[:i]
+
+    return min_str
+
+
+def _find_common_suffix(strings: list[str]) -> str:
+    """Return the longest suffix shared by every entry of `strings`."""
+    if not strings or len(strings) < 2:
+        return ""
+
+    reversed_strings = [s[::-1] for s in strings]
+    common_reversed_prefix = _find_common_prefix(reversed_strings)
+
+    return common_reversed_prefix[::-1]
+
+
+def _add_numbering_for_duplicates(labels: list[str]) -> list[str]:
+    """Append a two-digit `[NN]` counter to every label that occurs more than once.
+
+    Labels that are already unique are returned unchanged.
+
+    """
+    label_counts: dict[str, int] = {}
+    for label in labels:
+        label_counts[label] = label_counts.get(label, 0) + 1
+
+    if all(count == 1 for count in label_counts.values()):
+        return labels
+
+    label_counter: dict[str, int] = {}
+    numbered_labels = []
+
+    for label in labels:
+        if label_counts[label] > 1:
+            label_counter[label] = label_counter.get(label, 0) + 1
+            suffix = f"[{label_counter[label]:02d}]"
+            # Avoid a leading space when the duplicated label is empty.
+            numbered_labels.append(f"{label} {suffix}" if label else suffix)
+        else:
+            numbered_labels.append(label)
+
+    return numbered_labels
diff --git a/pydatalab/tests/test_utils.py b/pydatalab/tests/test_utils.py
new file mode 100644
index 000000000..38291e341
--- /dev/null
+++ b/pydatalab/tests/test_utils.py
@@ -0,0 +1,74 @@
+from pydatalab.utils import generate_unique_labels
+
+
+def test_generate_unique_labels_single_file():
+    result = generate_unique_labels(["sample_xrd_pattern.cif"])
+    assert result == ["sample_xrd_pattern.cif"]
+
+
+def test_generate_unique_labels_empty():
+    result = generate_unique_labels([])
+    assert result == []
+
+
+def test_generate_unique_labels_common_suffix():
+    filenames = ["sample1-xrd.xrdml", "sample2-xrd.xrdml"]
+    result = generate_unique_labels(filenames)
+    assert result == ["sample1.xrdml", "sample2.xrdml"]
+
+
+def test_generate_unique_labels_prefix_and_suffix():
+    filenames = [
+        "experiment_run1_final.dat",
+        "experiment_run2_final.dat",
+        "experiment_run3_final.dat",
+    ]
+    result = generate_unique_labels(filenames)
+    assert result == ["run1.dat", "run2.dat", "run3.dat"]
+
+
+def test_generate_unique_labels_long_unique_parts():
+    filenames = [
+        "very_long_sample_name_with_many_characters_001.cif",
+        "very_long_sample_name_with_many_characters_002.cif",
+    ]
+    result = generate_unique_labels(filenames, max_length=10)
+    assert all(len(label) <= 10 for label in result)
+    assert result[0] != result[1]
+
+
+def test_generate_unique_labels_duplicates_after_shortening():
+    filenames = [
+        "CIF_0000000000000001.xrdml",
+        "CIF_0000000000000002.xrdml",
+    ]
+    result = generate_unique_labels(filenames, max_length=15)
+    assert result == ["CIF...1.xrdml", "CIF...2.xrdml"]
+
+
+def test_generate_unique_labels_common_prefix():
+    filenames = ["ICSDCollCode-000002.cif", "ICSDCollCode-000003.cif"]
+    result = generate_unique_labels(filenames)
+    assert result == ["ICSD...2.cif", "ICSD...3.cif"]
+
+
+def test_generate_unique_labels_same_extension():
+    filenames = ["sample_A.cif", "sample_B.cif", "sample_C.cif"]
+    result = generate_unique_labels(filenames)
+    assert result == ["sample_A.cif", "sample_B.cif", "sample_C.cif"]
+
+
+def test_generate_unique_labels_cif_pattern():
+    filenames = ["CIF_00000001.cif", "CIF_00000002.cif"]
+    result = generate_unique_labels(filenames)
+    assert result == ["CIF...1.cif", "CIF...2.cif"]
+
+
+def test_generate_unique_labels_none_entries():
+    result = generate_unique_labels([None, "b.cif"])
+    assert result == ["", "b.cif"]
+
+
+def test_generate_unique_labels_all_unnamed():
+    result = generate_unique_labels([None, None])
+    assert result == ["[01]", "[02]"]