From 1082105217b09a88eee6cc32f6676f4f96fae607 Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Wed, 27 Aug 2025 11:09:11 +0900 Subject: [PATCH 01/30] ENH: Support arbitrary number of wavelengths (>=2) in NIRS/SNIRF processing --- mne/io/snirf/_snirf.py | 7 +- mne/preprocessing/nirs/_beer_lambert_law.py | 97 ++++++++++++---- .../nirs/_scalp_coupling_index.py | 51 +++++++-- mne/preprocessing/nirs/nirs.py | 107 +++++++++++------- 4 files changed, 185 insertions(+), 77 deletions(-) diff --git a/mne/io/snirf/_snirf.py b/mne/io/snirf/_snirf.py index 55f3c54605c..cfb9c8f71e2 100644 --- a/mne/io/snirf/_snirf.py +++ b/mne/io/snirf/_snirf.py @@ -148,14 +148,13 @@ def __init__( # Extract wavelengths fnirs_wavelengths = np.array(dat.get("nirs/probe/wavelengths")) fnirs_wavelengths = [int(w) for w in fnirs_wavelengths] - if len(fnirs_wavelengths) != 2: + if len(fnirs_wavelengths) < 2: raise RuntimeError( f"The data contains " f"{len(fnirs_wavelengths)}" f" wavelengths: {fnirs_wavelengths}. " - f"MNE only supports reading continuous" - " wave amplitude SNIRF files " - "with two wavelengths." + f"MNE requires at least two wavelengths for " + "continuous wave amplitude SNIRF files." ) # Extract channels diff --git a/mne/preprocessing/nirs/_beer_lambert_law.py b/mne/preprocessing/nirs/_beer_lambert_law.py index c17cf31110c..ec5d20842fd 100644 --- a/mne/preprocessing/nirs/_beer_lambert_law.py +++ b/mne/preprocessing/nirs/_beer_lambert_law.py @@ -36,23 +36,32 @@ def beer_lambert_law(raw, ppf=6.0): _validate_type(raw, BaseRaw, "raw") _validate_type(ppf, ("numeric", "array-like"), "ppf") ppf = np.array(ppf, float) - if ppf.ndim == 0: # upcast single float to shape (2,) - ppf = np.array([ppf, ppf]) - if ppf.shape != (2,): - raise ValueError( - f"ppf must be float or array-like of shape (2,), got shape {ppf.shape}" - ) - ppf = ppf[:, np.newaxis] # shape (2, 1) picks = _validate_nirs_info(raw.info, fnirs="od", which="Beer-lambert") # This is the one place we *really* need the actual/accurate frequencies freqs = np.array([raw.info["chs"][pick]["loc"][9] for pick in picks], float) - abs_coef = _load_absorption(freqs) + + # Get unique wavelengths and determine number of wavelengths + unique_freqs = np.unique(freqs) + n_wavelengths = len(unique_freqs) + + # PPF validation for multiple wavelengths + if ppf.ndim == 0: # single float + ppf = np.full((n_wavelengths, 1), ppf) # same PPF for all wavelengths, shape (n_wavelengths, 1) + elif ppf.ndim == 1 and len(ppf) == n_wavelengths: # separate ppf for each wavelength + ppf = ppf[:, np.newaxis] # shape (n_wavelengths, 1) + else: + raise ValueError( + f"ppf must be a single float or an array-like of length {n_wavelengths} " + f"(number of wavelengths), got shape {ppf.shape}" + ) + + abs_coef = _load_absorption(unique_freqs) # shape (n_wavelengths, 2) distances = source_detector_distances(raw.info, picks="all") bad = ~np.isfinite(distances[picks]) bad |= distances[picks] <= 0 if bad.any(): warn( - "Source-detector distances are zero on NaN, some resulting " + "Source-detector distances are zero or NaN, some resulting " "concentrations will be zero. Consider setting a montage " "with raw.set_montage." ) @@ -64,20 +73,43 @@ def beer_lambert_law(raw, ppf=6.0): "likely due to optode locations being stored in a " " unit other than meters." ) + rename = dict() - for ii, jj in zip(picks[::2], picks[1::2]): - EL = abs_coef * distances[ii] * ppf - iEL = pinv(EL) + channels_to_drop_all = [] # Accumulate all channels to drop + + # Iterate over channel groups ([Si_Di all wavelengths, Sj_Dj all wavelengths, ...]) + pick_groups = zip(*[iter(picks)] * n_wavelengths) + for group_picks in pick_groups: - raw._data[[ii, jj]] = iEL @ raw._data[[ii, jj]] * 1e-3 + # Calculate Δc based on the system: ΔOD = E * L * PPF * Δc + # where E is (n_wavelengths, 2), Δc is (2, n_timepoints) + # using pseudo-inverse + EL = abs_coef * distances[group_picks[0]] * ppf + iEL = pinv(EL) # Pseudo-inverse for numerical stability + conc_data = iEL @ raw._data[group_picks] * 1e-3 + + # Replace the first two channels with HbO and HbR + raw._data[group_picks[:2]] = conc_data[:2] # HbO, HbR # Update channel information coil_dict = dict(hbo=FIFF.FIFFV_COIL_FNIRS_HBO, hbr=FIFF.FIFFV_COIL_FNIRS_HBR) - for ki, kind in zip((ii, jj), ("hbo", "hbr")): + for ki, kind in zip(group_picks[:2], ("hbo", "hbr")): ch = raw.info["chs"][ki] ch.update(coil_type=coil_dict[kind], unit=FIFF.FIFF_UNIT_MOL) new_name = f"{ch['ch_name'].split(' ')[0]} {kind}" rename[ch["ch_name"]] = new_name + + # Accumulate extra wavelength channels to drop (keep only HbO and HbR) + if n_wavelengths > 2: + channels_to_drop = group_picks[2:] + channel_names_to_drop = [raw.ch_names[idx] for idx in channels_to_drop] + channels_to_drop_all.extend(channel_names_to_drop) + + # Drop all accumulated extra wavelength channels after processing all groups + # This preserves channel indexing during the loop iterations + if channels_to_drop_all: + raw.drop_channels(channels_to_drop_all) + raw.rename_channels(rename) # Validate the format of data after transformation is valid @@ -86,16 +118,31 @@ def beer_lambert_law(raw, ppf=6.0): def _load_absorption(freqs): - """Load molar extinction coefficients.""" + """Load molar extinction coefficients + + Parameters + ---------- + freqs : array-like + Array of wavelengths (frequencies) in nm. + + Returns + ------- + abs_coef : array, shape (n_wavelengths, 2) + Absorption coefficients for HbO2 and Hb at each wavelength. + abs_coef[:, 0] contains HbO2 coefficients + abs_coef[:, 1] contains Hb coefficients + + E.g. [[HbO2(freq1)], [Hb(freq1)], + [HbO2(freq2)], [Hb(freq2)], + ..., + [HbO2(freqN)], [Hb(freqN)]] + """ # Data from https://omlc.org/spectra/hemoglobin/summary.html # The text was copied to a text file. The text before and # after the table was deleted. The the following was run in # matlab # extinct_coef=importdata('extinction_coef.txt') # save('extinction_coef.mat', 'extinct_coef') - # - # Returns data as [[HbO2(freq1), Hb(freq1)], - # [HbO2(freq2), Hb(freq2)]] extinction_fname = op.join( op.dirname(__file__), "..", "..", "data", "extinction_coef.mat" ) @@ -104,12 +151,12 @@ def _load_absorption(freqs): interp_hbo = interp1d(a[:, 0], a[:, 1], kind="linear") interp_hb = interp1d(a[:, 0], a[:, 2], kind="linear") - ext_coef = np.array( - [ - [interp_hbo(freqs[0]), interp_hb(freqs[0])], - [interp_hbo(freqs[1]), interp_hb(freqs[1])], - ] - ) - abs_coef = ext_coef * 0.2303 + # Build coefficient matrix for all wavelengths + # Shape: (n_wavelengths, 2) where columns are [HbO2, Hb] + ext_coef = np.zeros((len(freqs), 2)) + for i, freq in enumerate(freqs): + ext_coef[i, 0] = interp_hbo(freq) # HbO2 + ext_coef[i, 1] = interp_hb(freq) # Hb + abs_coef = ext_coef * 0.2303 return abs_coef diff --git a/mne/preprocessing/nirs/_scalp_coupling_index.py b/mne/preprocessing/nirs/_scalp_coupling_index.py index 5a82664dfd1..dec933a1014 100644 --- a/mne/preprocessing/nirs/_scalp_coupling_index.py +++ b/mne/preprocessing/nirs/_scalp_coupling_index.py @@ -56,14 +56,51 @@ def scalp_coupling_index( verbose=verbose, ).get_data() + # Determine number of wavelengths per source-detector pair + ch_wavelengths = [c["loc"][9] for c in raw.info["chs"]] + n_wavelengths = len(set(ch_wavelengths)) + + # freqs = np.array([raw.info["chs"][pick]["loc"][9] for pick in picks], float) + # n_wavelengths = len(set(unique_freqs)) + sci = np.zeros(picks.shape) - for ii in range(0, len(picks), 2): - with np.errstate(invalid="ignore"): - c = np.corrcoef(filtered_data[ii], filtered_data[ii + 1])[0][1] - if not np.isfinite(c): # someone had std=0 - c = 0 - sci[ii] = c - sci[ii + 1] = c + + if n_wavelengths == 2: + # Use pairwise correlation for 2 wavelengths (backward compatibility) + for ii in range(0, len(picks), 2): + with np.errstate(invalid="ignore"): + c = np.corrcoef(filtered_data[ii], filtered_data[ii + 1])[0][1] + if not np.isfinite(c): # someone had std=0 + c = 0 + sci[ii] = c + sci[ii + 1] = c + else: + # For multiple wavelengths: calculate all pairwise correlations within each group + # and use the minimum as the quality metric + + # Group picks by number of wavelengths + # Drops last incomplete group, but we're assuming valid data + pick_iter = iter(picks) + pick_groups = zip(*[pick_iter] * n_wavelengths) + + for group_picks in pick_groups: + group_data = filtered_data[group_picks] + + # Calculate pairwise correlations within the group + pair_indices = np.triu_indices(len(group_picks), k=1) + correlations = np.zeros(pair_indices[0].shape[0]) + + for n, (ii, jj) in enumerate(zip(*pair_indices)): + with np.errstate(invalid="ignore"): + c = np.corrcoef(group_data[ii], group_data[jj])[0][1] + correlations[n] = c if np.isfinite(c) else 0 + + # Use minimum correlation as the quality metric (most conservative) + group_sci = correlations.min() + + # Assign the same SCI value to all channels in the group + sci[group_picks] = group_sci + sci[zero_mask] = 0 sci = sci[np.argsort(picks)] # restore original order return sci diff --git a/mne/preprocessing/nirs/nirs.py b/mne/preprocessing/nirs/nirs.py index 94c7c78468c..b9dd8252c0c 100644 --- a/mne/preprocessing/nirs/nirs.py +++ b/mne/preprocessing/nirs/nirs.py @@ -2,6 +2,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +from calendar import c import re import numpy as np @@ -104,7 +105,7 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr # All chromophore fNIRS data picks_chroma = _picks_to_idx(info, ["hbo", "hbr"], exclude=[], allow_empty=True) - if (len(picks_wave) > 0) & (len(picks_chroma) > 0): + if (len(picks_wave) > 0) and (len(picks_chroma) > 0): picks = _throw_or_return_empty( "MNE does not support a combination of amplitude, optical " "density, and haemoglobin data in the same raw structure.", @@ -122,14 +123,14 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr picks = picks_chroma pair_vals = np.array(pair_vals) - if pair_vals.shape != (2,): + if pair_vals.shape[0] < 2: raise ValueError( - f"Exactly two {error_word} must exist in info, got {list(pair_vals)}" + f"At least two {error_word} must exist in info, got {list(pair_vals)}" ) # In principle we do not need to require that these be sorted -- # all we need to do is change our sorted() below to make use of a # pair_vals.index(...) in a sort key -- but in practice we always want - # (hbo, hbr) or (lower_freq, upper_freq) pairings, both of which will + # (hbo, hbr) or (lowest_freq, higher_freq, ...) pairings, both of which will # work with a naive string sort, so let's just enforce sorted-ness here is_str = pair_vals.dtype.kind == "U" pair_vals = list(pair_vals) @@ -145,16 +146,23 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr f"got {pair_vals} instead" ) - if len(picks) % 2 != 0: + # Check that the total number of channels is divisible by the number of pair values + # (e.g., for 2 wavelengths, we need even number of channels) + if len(picks) % len(pair_vals) !=0: picks = _throw_or_return_empty( - "NIRS channels not ordered correctly. An even number of NIRS " - f"channels is required. {len(info.ch_names)} channels were" - f"provided", + f"NIRS channels not ordered correctly. The number of channels " + f"must be a multiple of {len(pair_vals)} values, but " + f"{len(picks)} channels were provided.", throw_errors, ) - # Ensure wavelength info exists for waveform data all_freqs = [info["chs"][ii]["loc"][9] for ii in picks_wave] + if len(pair_vals) != len(set(all_freqs)): + picks = _throw_or_return_empty( + f"The {error_word} in info must match the number of wavelengths, " + f"but the data contains {len(set(all_freqs))} wavelengths instead.", + throw_errors, + ) if np.any(np.isnan(all_freqs)): picks = _throw_or_return_empty( f"NIRS channels is missing wavelength information in the " @@ -189,40 +197,42 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr # Reorder to be paired (naive sort okay here given validation above) picks = picks[np.argsort([info["ch_names"][pick] for pick in picks])] - # Validate our paired ordering - for ii, jj in zip(picks[::2], picks[1::2]): - ch1_name = info["chs"][ii]["ch_name"] - ch2_name = info["chs"][jj]["ch_name"] - ch1_re = use_RE.match(ch1_name) - ch2_re = use_RE.match(ch2_name) - ch1_S, ch1_D, ch1_value = ch1_re.groups()[:3] - ch2_S, ch2_D, ch2_value = ch2_re.groups()[:3] - if len(picks_wave): - ch1_value, ch2_value = float(ch1_value), float(ch2_value) - if ( - (ch1_S != ch2_S) - or (ch1_D != ch2_D) - or (ch1_value != pair_vals[0]) - or (ch2_value != pair_vals[1]) - ): + # Validate channel grouping (same source-detector pairs, all pair_vals match) + for i in range(0, len(picks), len(pair_vals)): + # Extract a group of channels (e.g., all wavelengths for one S-D pair) + group_picks = picks[i:i + len(pair_vals)] + + # Parse channel names using regex to extract source, detector, and value info + group_info = [(use_RE.match(info["ch_names"][pick]).groups() or (pick, 0, 0)) for pick in group_picks] + + # Separate the parsed components: source IDs, detector IDs, and values (freq/chromophore) + s_group, d_group, val_group = zip(*group_info) + + # For wavelength data, convert string frequencies to float for comparison + if len(picks_wave) > 0: + val_group = [float(v) for v in val_group] + + # Verify that all channels in this group have the same source-detector pair + # and that the values match the expected pair_vals sequence + if not (len(set(s_group))==1 and len(set(d_group))==1 and val_group == pair_vals): picks = _throw_or_return_empty( "NIRS channels not ordered correctly. Channels must be " - "ordered as source detector pairs with alternating" - f" {error_word} {pair_vals[0]} & {pair_vals[1]}, but got " - f"S{ch1_S}_D{ch1_D} pair " - f"{repr(ch1_name)} and {repr(ch2_name)}", + "grouped by source-detector pairs with alternating {error_word} " + f"values {pair_vals}, but got mismatching names {[info['ch_names'][pick] for pick in group_picks]}.", throw_errors, ) break if check_bads: - for ii, jj in zip(picks[::2], picks[1::2]): - want = [info.ch_names[ii], info.ch_names[jj]] + for i in range (0, len(picks), len(pair_vals)): + group_picks = picks[i:i + len(pair_vals)] + + want = [info.ch_names[pick] for pick in group_picks] got = list(set(info["bads"]).intersection(want)) - if len(got) == 1: + if 0 < len(got) < len(want): raise RuntimeError( - f"NIRS bad labelling is not consistent, found {got} but " - f"needed {want}" + "NIRS bad labelling is not consistent. " + f"Found {got} but needed {want}. " ) return picks @@ -276,14 +286,29 @@ def _fnirs_spread_bads(info): # as bad and spread the bad marking to all components of the optode pair. picks = _validate_nirs_info(info, check_bads=False) new_bads = set(info["bads"]) - for ii, jj in zip(picks[::2], picks[1::2]): - ch1_name, ch2_name = info.ch_names[ii], info.ch_names[jj] - if ch1_name in new_bads: - new_bads.add(ch2_name) - elif ch2_name in new_bads: - new_bads.add(ch1_name) - info["bads"] = sorted(new_bads) + # Extract SD pair groups from channel names + # E.g. all channels belonging to S1D1, S1D2, etc. + # Assumes valid channels (naming convention and number) + ch_names = [info.ch_names[i] for i in picks] + match = re.compile(r'^(S\d+_D\d+) ') + + # Create dict with keys corresponding to SD pairs + # Defaultdict would require another import + sd_groups = {} + for ch_name in ch_names: + sd_pair = match.match(ch_name).group(1) + if sd_pair not in sd_groups: + sd_groups[sd_pair] = [ch_name] + else: + sd_groups[sd_pair].append(ch_name) + + # Spread bad labeling across SD pairs + for channels in sd_groups.values(): + if any(channel in new_bads for channel in channels): + new_bads.update(channels) + + info["bads"] = sorted(new_bads) return info From cf34d754d2dd756157f7711291928a2068033e59 Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Thu, 28 Aug 2025 16:49:46 +0900 Subject: [PATCH 02/30] fixed minor issue in _scalp_couplying_index: correlation results were in np.zeros so there was need to write 0 to it when correlation was infinite. --- mne/preprocessing/nirs/_scalp_coupling_index.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mne/preprocessing/nirs/_scalp_coupling_index.py b/mne/preprocessing/nirs/_scalp_coupling_index.py index dec933a1014..e829ed94500 100644 --- a/mne/preprocessing/nirs/_scalp_coupling_index.py +++ b/mne/preprocessing/nirs/_scalp_coupling_index.py @@ -93,9 +93,10 @@ def scalp_coupling_index( for n, (ii, jj) in enumerate(zip(*pair_indices)): with np.errstate(invalid="ignore"): c = np.corrcoef(group_data[ii], group_data[jj])[0][1] - correlations[n] = c if np.isfinite(c) else 0 + if np.isfinite(c): + correlations[n] = c - # Use minimum correlation as the quality metric (most conservative) + # Use minimum correlation as the quality metric group_sci = correlations.min() # Assign the same SCI value to all channels in the group From f55cbc1979f3cd0bb33482055c322c3525ff4bef Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Thu, 28 Aug 2025 19:38:38 +0900 Subject: [PATCH 03/30] removed rogue import calendar statement sneaked in by AI --- mne/preprocessing/nirs/nirs.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/mne/preprocessing/nirs/nirs.py b/mne/preprocessing/nirs/nirs.py index b9dd8252c0c..d95901df770 100644 --- a/mne/preprocessing/nirs/nirs.py +++ b/mne/preprocessing/nirs/nirs.py @@ -2,7 +2,6 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -from calendar import c import re import numpy as np @@ -148,7 +147,7 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr # Check that the total number of channels is divisible by the number of pair values # (e.g., for 2 wavelengths, we need even number of channels) - if len(picks) % len(pair_vals) !=0: + if len(picks) % len(pair_vals) != 0: picks = _throw_or_return_empty( f"NIRS channels not ordered correctly. The number of channels " f"must be a multiple of {len(pair_vals)} values, but " @@ -200,10 +199,13 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr # Validate channel grouping (same source-detector pairs, all pair_vals match) for i in range(0, len(picks), len(pair_vals)): # Extract a group of channels (e.g., all wavelengths for one S-D pair) - group_picks = picks[i:i + len(pair_vals)] + group_picks = picks[i : i + len(pair_vals)] # Parse channel names using regex to extract source, detector, and value info - group_info = [(use_RE.match(info["ch_names"][pick]).groups() or (pick, 0, 0)) for pick in group_picks] + group_info = [ + (use_RE.match(info["ch_names"][pick]).groups() or (pick, 0, 0)) + for pick in group_picks + ] # Separate the parsed components: source IDs, detector IDs, and values (freq/chromophore) s_group, d_group, val_group = zip(*group_info) @@ -214,7 +216,9 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr # Verify that all channels in this group have the same source-detector pair # and that the values match the expected pair_vals sequence - if not (len(set(s_group))==1 and len(set(d_group))==1 and val_group == pair_vals): + if not ( + len(set(s_group)) == 1 and len(set(d_group)) == 1 and val_group == pair_vals + ): picks = _throw_or_return_empty( "NIRS channels not ordered correctly. Channels must be " "grouped by source-detector pairs with alternating {error_word} " @@ -224,8 +228,8 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr break if check_bads: - for i in range (0, len(picks), len(pair_vals)): - group_picks = picks[i:i + len(pair_vals)] + for i in range(0, len(picks), len(pair_vals)): + group_picks = picks[i : i + len(pair_vals)] want = [info.ch_names[pick] for pick in group_picks] got = list(set(info["bads"]).intersection(want)) @@ -291,7 +295,7 @@ def _fnirs_spread_bads(info): # E.g. all channels belonging to S1D1, S1D2, etc. # Assumes valid channels (naming convention and number) ch_names = [info.ch_names[i] for i in picks] - match = re.compile(r'^(S\d+_D\d+) ') + match = re.compile(r"^(S\d+_D\d+) ") # Create dict with keys corresponding to SD pairs # Defaultdict would require another import From 7489f10baefb115d2ff6ea1dc8f788e5dc2beaff Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Thu, 28 Aug 2025 19:45:48 +0900 Subject: [PATCH 04/30] reverted doc string on _load_absorption, applied minor changes to comments to show correct return format --- mne/preprocessing/nirs/_beer_lambert_law.py | 35 ++++++++------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/mne/preprocessing/nirs/_beer_lambert_law.py b/mne/preprocessing/nirs/_beer_lambert_law.py index ec5d20842fd..7082a2ee8cd 100644 --- a/mne/preprocessing/nirs/_beer_lambert_law.py +++ b/mne/preprocessing/nirs/_beer_lambert_law.py @@ -46,8 +46,12 @@ def beer_lambert_law(raw, ppf=6.0): # PPF validation for multiple wavelengths if ppf.ndim == 0: # single float - ppf = np.full((n_wavelengths, 1), ppf) # same PPF for all wavelengths, shape (n_wavelengths, 1) - elif ppf.ndim == 1 and len(ppf) == n_wavelengths: # separate ppf for each wavelength + ppf = np.full( + (n_wavelengths, 1), ppf + ) # same PPF for all wavelengths, shape (n_wavelengths, 1) + elif ( + ppf.ndim == 1 and len(ppf) == n_wavelengths + ): # separate ppf for each wavelength ppf = ppf[:, np.newaxis] # shape (n_wavelengths, 1) else: raise ValueError( @@ -118,31 +122,18 @@ def beer_lambert_law(raw, ppf=6.0): def _load_absorption(freqs): - """Load molar extinction coefficients - - Parameters - ---------- - freqs : array-like - Array of wavelengths (frequencies) in nm. - - Returns - ------- - abs_coef : array, shape (n_wavelengths, 2) - Absorption coefficients for HbO2 and Hb at each wavelength. - abs_coef[:, 0] contains HbO2 coefficients - abs_coef[:, 1] contains Hb coefficients - - E.g. [[HbO2(freq1)], [Hb(freq1)], - [HbO2(freq2)], [Hb(freq2)], - ..., - [HbO2(freqN)], [Hb(freqN)]] - """ + """Load molar extinction coefficients""" # Data from https://omlc.org/spectra/hemoglobin/summary.html # The text was copied to a text file. The text before and # after the table was deleted. The the following was run in # matlab # extinct_coef=importdata('extinction_coef.txt') # save('extinction_coef.mat', 'extinct_coef') + # + # Returns data as [[HbO2(freq1), Hb(freq1)], + # [HbO2(freq2), Hb(freq2)], + # ..., + # [HbO2(freqN), Hb(freqN)]] extinction_fname = op.join( op.dirname(__file__), "..", "..", "data", "extinction_coef.mat" ) @@ -156,7 +147,7 @@ def _load_absorption(freqs): ext_coef = np.zeros((len(freqs), 2)) for i, freq in enumerate(freqs): ext_coef[i, 0] = interp_hbo(freq) # HbO2 - ext_coef[i, 1] = interp_hb(freq) # Hb + ext_coef[i, 1] = interp_hb(freq) # Hb abs_coef = ext_coef * 0.2303 return abs_coef From def3e1060a37bd0cf26c42ba6ba8d7b32b1bffde Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Thu, 28 Aug 2025 19:56:17 +0900 Subject: [PATCH 05/30] put inline comments above code to prevent autoformat from spreading code over several lines --- mne/preprocessing/nirs/_beer_lambert_law.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mne/preprocessing/nirs/_beer_lambert_law.py b/mne/preprocessing/nirs/_beer_lambert_law.py index 7082a2ee8cd..c3091585ac1 100644 --- a/mne/preprocessing/nirs/_beer_lambert_law.py +++ b/mne/preprocessing/nirs/_beer_lambert_law.py @@ -46,12 +46,10 @@ def beer_lambert_law(raw, ppf=6.0): # PPF validation for multiple wavelengths if ppf.ndim == 0: # single float - ppf = np.full( - (n_wavelengths, 1), ppf - ) # same PPF for all wavelengths, shape (n_wavelengths, 1) - elif ( - ppf.ndim == 1 and len(ppf) == n_wavelengths - ): # separate ppf for each wavelength + # same PPF for all wavelengths, shape (n_wavelengths, 1) + ppf = np.full((n_wavelengths, 1), ppf) + elif ppf.ndim == 1 and len(ppf) == n_wavelengths: + # separate ppf for each wavelength ppf = ppf[:, np.newaxis] # shape (n_wavelengths, 1) else: raise ValueError( From 6accb5e0c3c4bbc20bfac8e1f028f3ce5e0d38d3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 4 Sep 2025 00:17:52 +0000 Subject: [PATCH 06/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/preprocessing/nirs/_beer_lambert_law.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mne/preprocessing/nirs/_beer_lambert_law.py b/mne/preprocessing/nirs/_beer_lambert_law.py index c3091585ac1..f26ea90a0d3 100644 --- a/mne/preprocessing/nirs/_beer_lambert_law.py +++ b/mne/preprocessing/nirs/_beer_lambert_law.py @@ -82,7 +82,6 @@ def beer_lambert_law(raw, ppf=6.0): # Iterate over channel groups ([Si_Di all wavelengths, Sj_Dj all wavelengths, ...]) pick_groups = zip(*[iter(picks)] * n_wavelengths) for group_picks in pick_groups: - # Calculate Δc based on the system: ΔOD = E * L * PPF * Δc # where E is (n_wavelengths, 2), Δc is (2, n_timepoints) # using pseudo-inverse From 9623db111bbf4fa7aa58339245918267923eb392 Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Thu, 20 Nov 2025 00:58:09 +0900 Subject: [PATCH 07/30] Applied recommended code changes & bug fix in _scalp_coupling_index iteration --- mne/preprocessing/nirs/_beer_lambert_law.py | 2 +- .../nirs/_scalp_coupling_index.py | 53 +++++++------------ mne/preprocessing/nirs/nirs.py | 11 ++-- 3 files changed, 25 insertions(+), 41 deletions(-) diff --git a/mne/preprocessing/nirs/_beer_lambert_law.py b/mne/preprocessing/nirs/_beer_lambert_law.py index f26ea90a0d3..aceaaefcadf 100644 --- a/mne/preprocessing/nirs/_beer_lambert_law.py +++ b/mne/preprocessing/nirs/_beer_lambert_law.py @@ -119,7 +119,7 @@ def beer_lambert_law(raw, ppf=6.0): def _load_absorption(freqs): - """Load molar extinction coefficients""" + """Load molar extinction coefficients.""" # Data from https://omlc.org/spectra/hemoglobin/summary.html # The text was copied to a text file. The text before and # after the table was deleted. The the following was run in diff --git a/mne/preprocessing/nirs/_scalp_coupling_index.py b/mne/preprocessing/nirs/_scalp_coupling_index.py index e829ed94500..fb6521e0aa8 100644 --- a/mne/preprocessing/nirs/_scalp_coupling_index.py +++ b/mne/preprocessing/nirs/_scalp_coupling_index.py @@ -65,42 +65,25 @@ def scalp_coupling_index( sci = np.zeros(picks.shape) - if n_wavelengths == 2: - # Use pairwise correlation for 2 wavelengths (backward compatibility) - for ii in range(0, len(picks), 2): + # Calculate all pairwise correlations within each group and use the minimum as SCI + pair_indices = np.triu_indices(n_wavelengths, k=1) + for gg in range(0, len(picks), n_wavelengths): + group_data = filtered_data[gg : gg + n_wavelengths] + + # Calculate pairwise correlations within the group + correlations = np.zeros(pair_indices[0].shape[0]) + + for n, (ii, jj) in enumerate(zip(*pair_indices)): with np.errstate(invalid="ignore"): - c = np.corrcoef(filtered_data[ii], filtered_data[ii + 1])[0][1] - if not np.isfinite(c): # someone had std=0 - c = 0 - sci[ii] = c - sci[ii + 1] = c - else: - # For multiple wavelengths: calculate all pairwise correlations within each group - # and use the minimum as the quality metric - - # Group picks by number of wavelengths - # Drops last incomplete group, but we're assuming valid data - pick_iter = iter(picks) - pick_groups = zip(*[pick_iter] * n_wavelengths) - - for group_picks in pick_groups: - group_data = filtered_data[group_picks] - - # Calculate pairwise correlations within the group - pair_indices = np.triu_indices(len(group_picks), k=1) - correlations = np.zeros(pair_indices[0].shape[0]) - - for n, (ii, jj) in enumerate(zip(*pair_indices)): - with np.errstate(invalid="ignore"): - c = np.corrcoef(group_data[ii], group_data[jj])[0][1] - if np.isfinite(c): - correlations[n] = c - - # Use minimum correlation as the quality metric - group_sci = correlations.min() - - # Assign the same SCI value to all channels in the group - sci[group_picks] = group_sci + c = np.corrcoef(group_data[ii], group_data[jj])[0][1] + if np.isfinite(c): + correlations[n] = c + + # Use minimum correlation as SCI + group_sci = correlations.min() + + # Assign the same SCI value to all channels in the group + sci[gg : gg + n_wavelengths] = group_sci sci[zero_mask] = 0 sci = sci[np.argsort(picks)] # restore original order diff --git a/mne/preprocessing/nirs/nirs.py b/mne/preprocessing/nirs/nirs.py index d95901df770..f505c4503d1 100644 --- a/mne/preprocessing/nirs/nirs.py +++ b/mne/preprocessing/nirs/nirs.py @@ -149,11 +149,12 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr # (e.g., for 2 wavelengths, we need even number of channels) if len(picks) % len(pair_vals) != 0: picks = _throw_or_return_empty( - f"NIRS channels not ordered correctly. The number of channels " + "NIRS channels not ordered correctly. The number of channels " f"must be a multiple of {len(pair_vals)} values, but " f"{len(picks)} channels were provided.", throw_errors, ) + # Ensure wavelength info exists for waveform data all_freqs = [info["chs"][ii]["loc"][9] for ii in picks_wave] if len(pair_vals) != len(set(all_freqs)): @@ -197,9 +198,9 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr picks = picks[np.argsort([info["ch_names"][pick] for pick in picks])] # Validate channel grouping (same source-detector pairs, all pair_vals match) - for i in range(0, len(picks), len(pair_vals)): + for ii in range(0, len(picks), len(pair_vals)): # Extract a group of channels (e.g., all wavelengths for one S-D pair) - group_picks = picks[i : i + len(pair_vals)] + group_picks = picks[ii : ii + len(pair_vals)] # Parse channel names using regex to extract source, detector, and value info group_info = [ @@ -228,8 +229,8 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr break if check_bads: - for i in range(0, len(picks), len(pair_vals)): - group_picks = picks[i : i + len(pair_vals)] + for ii in range(0, len(picks), len(pair_vals)): + group_picks = picks[ii : ii + len(pair_vals)] want = [info.ch_names[pick] for pick in group_picks] got = list(set(info["bads"]).intersection(want)) From 6ea68170d7b9f0b8c1e496cff8bd964a78dce9f4 Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Thu, 20 Nov 2025 01:03:49 +0900 Subject: [PATCH 08/30] line length fixes for pre-commit in mne/preprocessing/nirs/nirs.py --- mne/preprocessing/nirs/nirs.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mne/preprocessing/nirs/nirs.py b/mne/preprocessing/nirs/nirs.py index f505c4503d1..dc3539d95a0 100644 --- a/mne/preprocessing/nirs/nirs.py +++ b/mne/preprocessing/nirs/nirs.py @@ -208,7 +208,8 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr for pick in group_picks ] - # Separate the parsed components: source IDs, detector IDs, and values (freq/chromophore) + # Separate the parsed components: + # source IDs, detector IDs, and values (freq/chromophore) s_group, d_group, val_group = zip(*group_info) # For wavelength data, convert string frequencies to float for comparison @@ -223,7 +224,8 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr picks = _throw_or_return_empty( "NIRS channels not ordered correctly. Channels must be " "grouped by source-detector pairs with alternating {error_word} " - f"values {pair_vals}, but got mismatching names {[info['ch_names'][pick] for pick in group_picks]}.", + f"values {pair_vals}, but got mismatching names " + f"{[info['ch_names'][pick] for pick in group_picks]}.", throw_errors, ) break From 1323e6629ebc3be100e74994bf106dd3fc6955c8 Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Thu, 20 Nov 2025 12:29:23 +0900 Subject: [PATCH 09/30] tests in mne/preprocessing now run, pre-commit also happy locally --- mne/preprocessing/nirs/_beer_lambert_law.py | 4 +- mne/preprocessing/nirs/nirs.py | 44 ++++++++++--------- .../nirs/tests/test_beer_lambert_law.py | 26 +++++++---- 3 files changed, 43 insertions(+), 31 deletions(-) diff --git a/mne/preprocessing/nirs/_beer_lambert_law.py b/mne/preprocessing/nirs/_beer_lambert_law.py index aceaaefcadf..1bd9b17599f 100644 --- a/mne/preprocessing/nirs/_beer_lambert_law.py +++ b/mne/preprocessing/nirs/_beer_lambert_law.py @@ -80,8 +80,8 @@ def beer_lambert_law(raw, ppf=6.0): channels_to_drop_all = [] # Accumulate all channels to drop # Iterate over channel groups ([Si_Di all wavelengths, Sj_Dj all wavelengths, ...]) - pick_groups = zip(*[iter(picks)] * n_wavelengths) - for group_picks in pick_groups: + for ii in range(0, len(picks), n_wavelengths): + group_picks = picks[ii : ii + n_wavelengths] # Calculate Δc based on the system: ΔOD = E * L * PPF * Δc # where E is (n_wavelengths, 2), Δc is (2, n_timepoints) # using pseudo-inverse diff --git a/mne/preprocessing/nirs/nirs.py b/mne/preprocessing/nirs/nirs.py index dc3539d95a0..224e6f77c50 100644 --- a/mne/preprocessing/nirs/nirs.py +++ b/mne/preprocessing/nirs/nirs.py @@ -132,9 +132,8 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr # (hbo, hbr) or (lowest_freq, higher_freq, ...) pairings, both of which will # work with a naive string sort, so let's just enforce sorted-ness here is_str = pair_vals.dtype.kind == "U" - pair_vals = list(pair_vals) if is_str: - if pair_vals != ["hbo", "hbr"]: + if pair_vals.tolist() != ["hbo", "hbr"]: raise ValueError( f'The {error_word} in info must be ["hbo", "hbr"], but got ' f"{pair_vals} instead" @@ -156,19 +155,22 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr ) # Ensure wavelength info exists for waveform data - all_freqs = [info["chs"][ii]["loc"][9] for ii in picks_wave] - if len(pair_vals) != len(set(all_freqs)): - picks = _throw_or_return_empty( - f"The {error_word} in info must match the number of wavelengths, " - f"but the data contains {len(set(all_freqs))} wavelengths instead.", - throw_errors, - ) - if np.any(np.isnan(all_freqs)): - picks = _throw_or_return_empty( - f"NIRS channels is missing wavelength information in the " - f'info["chs"] structure. The encoded wavelengths are {all_freqs}.', - throw_errors, - ) + if len(picks_wave) > 0: + all_freqs = [info["chs"][ii]["loc"][9] for ii in picks_wave] + # test for nan values first as those mess up the output of set() + if np.any(np.isnan(all_freqs)): + picks = _throw_or_return_empty( + f"NIRS channels is missing wavelength information in the " + f'info["chs"] structure. The encoded wavelengths are {all_freqs}.', + throw_errors, + ) + # test if the info structure has the same number of freqs as pair_vals + if len(pair_vals) != len(set(all_freqs)): + picks = _throw_or_return_empty( + f"The {error_word} in info must match the number of wavelengths, " + f"but the data contains {len(set(all_freqs))} wavelengths instead.", + throw_errors, + ) # Validate the channel naming scheme for pick in picks: @@ -182,8 +184,8 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr ) break value = ch_name_info.groups()[2] - if len(picks_wave): - value = value + if len(picks_wave) > 0: + pass else: # picks_chroma if value not in ["hbo", "hbr"]: picks = _throw_or_return_empty( @@ -214,16 +216,18 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr # For wavelength data, convert string frequencies to float for comparison if len(picks_wave) > 0: - val_group = [float(v) for v in val_group] + val_group = np.array([float(v) for v in val_group]) # Verify that all channels in this group have the same source-detector pair # and that the values match the expected pair_vals sequence if not ( - len(set(s_group)) == 1 and len(set(d_group)) == 1 and val_group == pair_vals + len(set(s_group)) == 1 + and len(set(d_group)) == 1 + and np.array_equal(val_group, pair_vals) ): picks = _throw_or_return_empty( "NIRS channels not ordered correctly. Channels must be " - "grouped by source-detector pairs with alternating {error_word} " + f"grouped by source-detector pairs with alternating {error_word} " f"values {pair_vals}, but got mismatching names " f"{[info['ch_names'][pick] for pick in group_picks]}.", throw_errors, diff --git a/mne/preprocessing/nirs/tests/test_beer_lambert_law.py b/mne/preprocessing/nirs/tests/test_beer_lambert_law.py index 5768ff038ab..ed5a12bbdba 100644 --- a/mne/preprocessing/nirs/tests/test_beer_lambert_law.py +++ b/mne/preprocessing/nirs/tests/test_beer_lambert_law.py @@ -52,20 +52,28 @@ def test_beer_lambert_unordered_errors(): raw = read_raw_nirx(fname_nirx_15_0) raw_od = optical_density(raw) raw_od.pick([0, 1, 2]) - with pytest.raises(ValueError, match="ordered"): + with pytest.raises(ValueError, match="NIRS channels not ordered correctly."): beer_lambert_law(raw_od) # Test that an error is thrown if channel naming frequency doesn't match # what is stored in loc[9], which should hold the light frequency too. + # Introduce 2 new frequencies to make it 4 in total vs 2 stored in loc[9]. + # This way the bad data will have 20 channels and 4 wavelengths, so as not + # to get caught by the check for divisibility (channel % wavelength == 0). raw_od = optical_density(raw) - ch_name = raw.ch_names[0] - assert ch_name == "S1_D1 760" - idx = raw_od.ch_names.index(ch_name) - assert idx == 0 - raw_od.info["chs"][idx]["loc"][9] = 770 - raw_od.rename_channels({ch_name: ch_name.replace("760", "770")}) - assert raw_od.ch_names[0] == "S1_D1 770" - with pytest.raises(ValueError, match="Exactly two frequencies"): + assert raw.ch_names[0] == "S1_D1 760" and raw.ch_names[1] == "S1_D1 850" + assert ( + raw_od.ch_names.index(raw.ch_names[0]) == 0 + and raw_od.ch_names.index(raw.ch_names[1]) == 1 + ) + raw_od.rename_channels( + { + raw.ch_names[0]: raw.ch_names[0].replace("760", "770"), + raw.ch_names[1]: raw.ch_names[1].replace("850", "840"), + } + ) + assert raw_od.ch_names[0] == "S1_D1 770" and raw_od.ch_names[1] == "S1_D1 840" + with pytest.raises(ValueError, match="must match the number of wavelengths"): beer_lambert_law(raw_od) From ba5d48b843565ddd0bd336b4d786d1e004d633e3 Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Thu, 27 Nov 2025 22:48:43 +0900 Subject: [PATCH 10/30] FIX: Use nominal wavelengths in BLL and SCI calculations FIX: mne/io/hitachi tests with actual wavelength data now succeed * Beer-Lambert Law (BLL) and Scalp Coupling Index (SCI) calculations used the info structure to determine the number of wavelengths, but that lead to errors as the info structure can contain arbitrary data (e.g. different wavelengths for each channel). * BLL now uses the nominal wavelengths for the BLL calculation for all channels. Previously, it used the actual wavelengths of the first channel pair for all channels, which was incorrect if the channels have different actual wavelengths. * In the future, there may be an option in BLL to use the actual freq values for each channel, to improve calculation accuracy. * The Hitachi tests have data with actual wavelengths for each channel stored in the info structure. This caused issues with counting the number of wavelengths, which caused tests to fail. --- mne/preprocessing/nirs/_beer_lambert_law.py | 18 +++++++++++++++--- .../nirs/_scalp_coupling_index.py | 8 +++++--- mne/preprocessing/nirs/nirs.py | 9 ++------- .../nirs/tests/test_beer_lambert_law.py | 2 +- 4 files changed, 23 insertions(+), 14 deletions(-) diff --git a/mne/preprocessing/nirs/_beer_lambert_law.py b/mne/preprocessing/nirs/_beer_lambert_law.py index 1bd9b17599f..cace8b4242a 100644 --- a/mne/preprocessing/nirs/_beer_lambert_law.py +++ b/mne/preprocessing/nirs/_beer_lambert_law.py @@ -11,7 +11,7 @@ from ..._fiff.constants import FIFF from ...io import BaseRaw from ...utils import _validate_type, pinv, warn -from ..nirs import _validate_nirs_info, source_detector_distances +from ..nirs import _channel_frequencies, _validate_nirs_info, source_detector_distances def beer_lambert_law(raw, ppf=6.0): @@ -37,8 +37,20 @@ def beer_lambert_law(raw, ppf=6.0): _validate_type(ppf, ("numeric", "array-like"), "ppf") ppf = np.array(ppf, float) picks = _validate_nirs_info(raw.info, fnirs="od", which="Beer-lambert") - # This is the one place we *really* need the actual/accurate frequencies - freqs = np.array([raw.info["chs"][pick]["loc"][9] for pick in picks], float) + + # Use nominal channel frequencies + # + # Notes on implementation: + # 1. Frequencies are calculated the same way as in nirs._validate_nirs_info(). + # 2. Wavelength values in the info structure may contain actual frequencies, + # which may be used for more accurate calculation in the future. + # 3. nirs._channel_frequencies uses both cw_amplitude and OD data to determine + # frequencies, whereas we only need those from OD here. Is there any chance + # that they're different? + # 4. If actual frequencies were used, using np.unique() like below will lead to + # errors. Instead, absorption coefficients will need to be calculated for + # each individual frequency. + freqs = _channel_frequencies(raw.info) # Get unique wavelengths and determine number of wavelengths unique_freqs = np.unique(freqs) diff --git a/mne/preprocessing/nirs/_scalp_coupling_index.py b/mne/preprocessing/nirs/_scalp_coupling_index.py index fb6521e0aa8..366db2633e2 100644 --- a/mne/preprocessing/nirs/_scalp_coupling_index.py +++ b/mne/preprocessing/nirs/_scalp_coupling_index.py @@ -6,7 +6,7 @@ from ...io import BaseRaw from ...utils import _validate_type, verbose -from ..nirs import _validate_nirs_info +from ..nirs import _channel_frequencies, _validate_nirs_info @verbose @@ -57,8 +57,9 @@ def scalp_coupling_index( ).get_data() # Determine number of wavelengths per source-detector pair - ch_wavelengths = [c["loc"][9] for c in raw.info["chs"]] - n_wavelengths = len(set(ch_wavelengths)) + # We use nominal wavelengths as the info structure may contain arbitrary data. + freqs = _channel_frequencies(raw.info) + n_wavelengths = len(np.unique(freqs)) # freqs = np.array([raw.info["chs"][pick]["loc"][9] for pick in picks], float) # n_wavelengths = len(set(unique_freqs)) @@ -67,6 +68,7 @@ def scalp_coupling_index( # Calculate all pairwise correlations within each group and use the minimum as SCI pair_indices = np.triu_indices(n_wavelengths, k=1) + for gg in range(0, len(picks), n_wavelengths): group_data = filtered_data[gg : gg + n_wavelengths] diff --git a/mne/preprocessing/nirs/nirs.py b/mne/preprocessing/nirs/nirs.py index 224e6f77c50..afd4fbaef8d 100644 --- a/mne/preprocessing/nirs/nirs.py +++ b/mne/preprocessing/nirs/nirs.py @@ -155,6 +155,8 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr ) # Ensure wavelength info exists for waveform data + # Note: currently, the only requirement for the wavelength field in info is + # that it cannot be NaN. It depends on the data readers what is stored in it. if len(picks_wave) > 0: all_freqs = [info["chs"][ii]["loc"][9] for ii in picks_wave] # test for nan values first as those mess up the output of set() @@ -164,13 +166,6 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr f'info["chs"] structure. The encoded wavelengths are {all_freqs}.', throw_errors, ) - # test if the info structure has the same number of freqs as pair_vals - if len(pair_vals) != len(set(all_freqs)): - picks = _throw_or_return_empty( - f"The {error_word} in info must match the number of wavelengths, " - f"but the data contains {len(set(all_freqs))} wavelengths instead.", - throw_errors, - ) # Validate the channel naming scheme for pick in picks: diff --git a/mne/preprocessing/nirs/tests/test_beer_lambert_law.py b/mne/preprocessing/nirs/tests/test_beer_lambert_law.py index ed5a12bbdba..d80cd5fd7a7 100644 --- a/mne/preprocessing/nirs/tests/test_beer_lambert_law.py +++ b/mne/preprocessing/nirs/tests/test_beer_lambert_law.py @@ -73,7 +73,7 @@ def test_beer_lambert_unordered_errors(): } ) assert raw_od.ch_names[0] == "S1_D1 770" and raw_od.ch_names[1] == "S1_D1 840" - with pytest.raises(ValueError, match="must match the number of wavelengths"): + with pytest.raises(ValueError, match="NIRS channels not ordered correctly."): beer_lambert_law(raw_od) From 566f564e0dbfc756fa6c576766d87ecec779e083 Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Fri, 28 Nov 2025 01:33:09 +0900 Subject: [PATCH 11/30] Removed unnecessary comments --- mne/preprocessing/nirs/_beer_lambert_law.py | 3 +-- mne/preprocessing/nirs/_scalp_coupling_index.py | 3 --- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/mne/preprocessing/nirs/_beer_lambert_law.py b/mne/preprocessing/nirs/_beer_lambert_law.py index cace8b4242a..0c1482ef3b8 100644 --- a/mne/preprocessing/nirs/_beer_lambert_law.py +++ b/mne/preprocessing/nirs/_beer_lambert_law.py @@ -98,7 +98,7 @@ def beer_lambert_law(raw, ppf=6.0): # where E is (n_wavelengths, 2), Δc is (2, n_timepoints) # using pseudo-inverse EL = abs_coef * distances[group_picks[0]] * ppf - iEL = pinv(EL) # Pseudo-inverse for numerical stability + iEL = pinv(EL) conc_data = iEL @ raw._data[group_picks] * 1e-3 # Replace the first two channels with HbO and HbR @@ -119,7 +119,6 @@ def beer_lambert_law(raw, ppf=6.0): channels_to_drop_all.extend(channel_names_to_drop) # Drop all accumulated extra wavelength channels after processing all groups - # This preserves channel indexing during the loop iterations if channels_to_drop_all: raw.drop_channels(channels_to_drop_all) diff --git a/mne/preprocessing/nirs/_scalp_coupling_index.py b/mne/preprocessing/nirs/_scalp_coupling_index.py index 366db2633e2..062af0be8aa 100644 --- a/mne/preprocessing/nirs/_scalp_coupling_index.py +++ b/mne/preprocessing/nirs/_scalp_coupling_index.py @@ -61,9 +61,6 @@ def scalp_coupling_index( freqs = _channel_frequencies(raw.info) n_wavelengths = len(np.unique(freqs)) - # freqs = np.array([raw.info["chs"][pick]["loc"][9] for pick in picks], float) - # n_wavelengths = len(set(unique_freqs)) - sci = np.zeros(picks.shape) # Calculate all pairwise correlations within each group and use the minimum as SCI From e3ea50a48712ce5d8c7d065140446cfb7f1d513b Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Fri, 12 Dec 2025 14:23:18 +0900 Subject: [PATCH 12/30] Add new author name to doc/changes/names.inc and summary for PR #13408 to doc/changes/dev/13408.newfeature.rst --- doc/changes/dev/13408.newfeature.rst | 2 ++ doc/changes/names.inc | 1 + 2 files changed, 3 insertions(+) create mode 100644 doc/changes/dev/13408.newfeature.rst diff --git a/doc/changes/dev/13408.newfeature.rst b/doc/changes/dev/13408.newfeature.rst new file mode 100644 index 00000000000..9ac833d7027 --- /dev/null +++ b/doc/changes/dev/13408.newfeature.rst @@ -0,0 +1,2 @@ +Add support for multi-wavelength NIRS processing to :func:`mne.preprocessing.nirs.beer_lambert_law`, :func:`mne.preprocessing.nirs.scalp_coupling_index`, :mod:`mne.preprocessing.nirs.nirs`, and SNIRF reader :class:`mne.io.snirf._snirf.RawSNIRF`, by :newcontrib:`Tamas Fehervari`. +` diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 77e665ec6ed..ff6526aad80 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -319,6 +319,7 @@ .. _Sébastien Marti: https://www.researchgate.net/profile/Sebastien-Marti .. _T. Wang: https://github.com/twang5 .. _Tal Linzen: https://tallinzen.net/ +.. _Tamas Fehervari: https://github.com/zEdS15B3GCwq .. _Teon Brooks: https://github.com/teonbrooks .. _Tharupahan Jayawardana: https://github.com/tharu-jwd .. _Thomas Binns: https://github.com/tsbinns From a879f624c73cfb2d84aa68901b8c89a092d3fe7a Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Fri, 12 Dec 2025 16:21:23 +0000 Subject: [PATCH 13/30] Validate `meas_date` input type (#13528) --- doc/changes/dev/13528.bugfix.rst | 1 + mne/_fiff/meas_info.py | 4 ++++ mne/_fiff/tests/test_meas_info.py | 7 +++++++ mne/export/tests/test_export.py | 2 +- 4 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 doc/changes/dev/13528.bugfix.rst diff --git a/doc/changes/dev/13528.bugfix.rst b/doc/changes/dev/13528.bugfix.rst new file mode 100644 index 00000000000..ef24b8c6e24 --- /dev/null +++ b/doc/changes/dev/13528.bugfix.rst @@ -0,0 +1 @@ +Fix bug where invalid date formats passed to :meth:`mne.Info.set_meas_date` were not caught, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/_fiff/meas_info.py b/mne/_fiff/meas_info.py index 708b7135a9c..90454b9a699 100644 --- a/mne/_fiff/meas_info.py +++ b/mne/_fiff/meas_info.py @@ -833,6 +833,10 @@ def set_meas_date(self, meas_date): """ from ..annotations import _handle_meas_date + _validate_type( + meas_date, (datetime.datetime, "numeric", tuple, None), "meas_date" + ) + info = self if isinstance(self, Info) else self.info meas_date = _handle_meas_date(meas_date) diff --git a/mne/_fiff/tests/test_meas_info.py b/mne/_fiff/tests/test_meas_info.py index 4e409d262e0..d0effacde91 100644 --- a/mne/_fiff/tests/test_meas_info.py +++ b/mne/_fiff/tests/test_meas_info.py @@ -1199,6 +1199,13 @@ def test_invalid_subject_birthday(): assert "birthday" not in raw.info["subject_info"] +def test_invalid_set_meas_date(): + """Test set_meas_date catches invalid str input.""" + info = create_info(1, 1000, "eeg") + with pytest.raises(TypeError, match=r"meas_date must be an instance of"): + info.set_meas_date("2025-01-01 00:00:00.000000") + + @pytest.mark.slowtest @pytest.mark.parametrize( "fname", diff --git a/mne/export/tests/test_export.py b/mne/export/tests/test_export.py index ac9551252c5..4651123d499 100644 --- a/mne/export/tests/test_export.py +++ b/mne/export/tests/test_export.py @@ -198,7 +198,7 @@ def _create_raw_for_edf_tests(stim_channel_index=None): def test_double_export_edf(tmp_path): """Test exporting an EDF file multiple times.""" raw = _create_raw_for_edf_tests(stim_channel_index=2) - raw.info.set_meas_date("2023-09-04 14:53:09.000") + raw.info.set_meas_date(datetime(2023, 9, 4, 14, 53, 9, tzinfo=timezone.utc)) raw.set_annotations(Annotations(onset=[1], duration=[0], description=["test"])) # include subject info and measurement date From cfff52878ea3fa0220ea605ee33153fcf021a8d2 Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Fri, 12 Dec 2025 16:22:17 +0000 Subject: [PATCH 14/30] Convert old `Spectrum` and `TFR` birthday info format on read (#13526) --- doc/changes/dev/13526.bugfix.rst | 1 + mne/time_frequency/spectrum.py | 8 +++++-- mne/time_frequency/tests/test_spectrum.py | 27 ++++++++++++++++++----- mne/time_frequency/tests/test_tfr.py | 19 +++++++++++++++- mne/time_frequency/tfr.py | 6 ++--- mne/utils/spectrum.py | 11 +++++++++ 6 files changed, 60 insertions(+), 12 deletions(-) create mode 100644 doc/changes/dev/13526.bugfix.rst diff --git a/doc/changes/dev/13526.bugfix.rst b/doc/changes/dev/13526.bugfix.rst new file mode 100644 index 00000000000..09827e7a581 --- /dev/null +++ b/doc/changes/dev/13526.bugfix.rst @@ -0,0 +1 @@ +Fix bug preventing reading of :class:`mne.time_frequency.Spectrum` and :class:`mne.time_frequency.BaseTFR` objects created in MNE<1.8 using the deprecated subject info birthday tuple format, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index 0d0ce0c30c8..6591527a28b 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -46,7 +46,11 @@ check_fname, ) from ..utils.misc import _pl -from ..utils.spectrum import _get_instance_type_string, _split_psd_kwargs +from ..utils.spectrum import ( + _convert_old_birthday_format, + _get_instance_type_string, + _split_psd_kwargs, +) from ..viz.topo import _plot_timeseries, _plot_timeseries_unified, _plot_topo from ..viz.topomap import _make_head_outlines, _prepare_topomap_plot, plot_psds_topomap from ..viz.utils import ( @@ -391,7 +395,7 @@ def __setstate__(self, state): self._freqs = state["freqs"] self._dims = state["dims"] self._sfreq = state["sfreq"] - self.info = Info(**state["info"]) + self.info = Info(**_convert_old_birthday_format(state["info"])) self._data_type = state["data_type"] self._nave = state.get("nave") # objs saved before #11282 won't have `nave` self._weights = state.get("weights") # objs saved before #12747 won't have diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index b1ad677352d..c3197173492 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -2,6 +2,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +import datetime import re from functools import partial @@ -26,7 +27,7 @@ SpectrumArray, combine_spectrum, ) -from mne.utils import _record_warnings +from mne.utils import _import_h5io_funcs, _record_warnings def test_compute_psd_errors(raw): @@ -178,6 +179,7 @@ def _get_inst(inst, request, *, evoked=None, average_tfr=None): def test_spectrum_io(inst, tmp_path, request, evoked): """Test save/load of spectrum objects.""" pytest.importorskip("h5io") + h5py = pytest.importorskip("h5py") fname = tmp_path / f"{inst}-spectrum.h5" inst = _get_inst(inst, request, evoked=evoked) if isinstance(inst, BaseEpochs): @@ -190,12 +192,25 @@ def test_spectrum_io(inst, tmp_path, request, evoked): orig.save(fname) loaded = read_spectrum(fname) assert orig == loaded + # Only check following for one type + if not isinstance(inst, BaseEpochs): + return + # Test loading with old-style birthday format + fname_subject_info = tmp_path / "subject-info.h5" + _, write_hdf5 = _import_h5io_funcs() + write_hdf5(fname_subject_info, dict(birthday=(2000, 1, 1)), title="subject_info") + with h5py.File(fname, "r+") as f: + del f["mnepython/key_info/key_subject_info"] + f["mnepython/key_info/key_subject_info"] = h5py.ExternalLink( + fname_subject_info, "subject_info" + ) + loaded = read_spectrum(fname) + assert isinstance(loaded.info["subject_info"]["birthday"], datetime.date) # Test Spectrum from EpochsSpectrum.average() can be read (gh-13521) - if isinstance(inst, BaseEpochs): - origavg = orig.average() - origavg.save(fname, overwrite=True) - loadedavg = read_spectrum(fname) - assert origavg == loadedavg + origavg = orig.average() + origavg.save(fname, overwrite=True) + loadedavg = read_spectrum(fname) + assert origavg == loadedavg def test_spectrum_copy(raw_spectrum): diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index ed6ddd6da82..fdf89a836c0 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -51,7 +51,7 @@ tfr_multitaper, write_tfrs, ) -from mne.utils import catch_logging, grand_average +from mne.utils import _import_h5io_funcs, catch_logging, grand_average from mne.utils._testing import _get_suptitle from mne.viz.utils import ( _channel_type_prettyprint, @@ -620,6 +620,7 @@ def test_tfr_io(inst, average_tfr, request, tmp_path): """Test TFR I/O.""" pytest.importorskip("h5io") pd = pytest.importorskip("pandas") + h5py = pytest.importorskip("h5py") tfr = _get_inst(inst, request, average_tfr=average_tfr) fname = tmp_path / "temp_tfr.hdf5" @@ -679,6 +680,22 @@ def test_tfr_io(inst, average_tfr, request, tmp_path): tfravg.save(fname, overwrite=True) tfravg_loaded = read_tfrs(fname) assert tfravg == tfravg_loaded + # test loading with old-style birthday format + fname_multi = tmp_path / "temp_multi_tfr.hdf5" + write_tfrs(fname_multi, tfr) # also check for multiple files from write_tfrs + fname_subject_info = tmp_path / "subject-info.hdf5" + _, write_hdf5 = _import_h5io_funcs() + write_hdf5(fname_subject_info, dict(birthday=(2000, 1, 1)), title="subject_info") + for this_fname in (fname, fname_multi): + with h5py.File(this_fname, "r+") as f: + if f.get("mnepython/key_info/key_subject_info"): + path = "mnepython/key_info/key_subject_info" + else: # multi-files on linux have different path to attrs + path = "mnepython/idx_0/idx_1/key_info/key_subject_info" + del f[path] + f[path] = h5py.ExternalLink(fname_subject_info, "subject_info") + tfr_loaded = read_tfrs(this_fname) + assert isinstance(tfr_loaded.info["subject_info"]["birthday"], datetime.date) # test with taper dimension and weights n_tapers = 3 # anything >= 1 should do weights = np.ones((n_tapers, tfr.shape[2])) # tapers x freqs diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index f64680845c4..f232ff30158 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -59,7 +59,7 @@ verbose, warn, ) -from ..utils.spectrum import _get_instance_type_string +from ..utils.spectrum import _convert_old_birthday_format, _get_instance_type_string from ..viz.topo import _imshow_tfr, _imshow_tfr_unified, _plot_topo from ..viz.topomap import ( _add_colorbar, @@ -1433,7 +1433,7 @@ def __setstate__(self, state): self._dims = defaults["dims"] self._raw_times = np.asarray(defaults["times"], dtype=np.float64) self._baseline = defaults["baseline"] - self.info = Info(**defaults["info"]) + self.info = Info(**_convert_old_birthday_format(defaults["info"])) self._data_type = defaults["data_type"] self._decim = defaults["decim"] self.preload = True @@ -4141,7 +4141,7 @@ def _read_multiple_tfrs(tfr_data, condition=None, *, verbose=None): if key != condition: continue tfr = dict(tfr) - tfr["info"] = Info(tfr["info"]) + tfr["info"] = Info(_convert_old_birthday_format(tfr["info"])) tfr["info"]._check_consistency() if "metadata" in tfr: tfr["metadata"] = _prepare_read_metadata(tfr["metadata"]) diff --git a/mne/utils/spectrum.py b/mne/utils/spectrum.py index 69052f21797..1efd06381c9 100644 --- a/mne/utils/spectrum.py +++ b/mne/utils/spectrum.py @@ -4,6 +4,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +from datetime import datetime from inspect import currentframe, getargvalues, signature from ..utils import warn @@ -102,3 +103,13 @@ def _split_psd_kwargs(*, plot_fun=None, kwargs=None): for k in plot_kwargs: del kwargs[k] return kwargs, plot_kwargs + + +def _convert_old_birthday_format(info): + """Convert deprecated birthday tuple to datetime.""" + subject_info = info.get("subject_info") + if subject_info is not None: + birthday = subject_info.get("birthday") + if isinstance(birthday, tuple): + info["subject_info"]["birthday"] = datetime(*birthday) + return info From 5bcf55e657d7986e949269807b8bd2c68ceb9c71 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 12 Dec 2025 17:47:12 -0500 Subject: [PATCH 15/30] FIX: Fix bug with fitting coil order and GOF (#13525) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- doc/changes/dev/13525.bugfix.rst | 1 + doc/sphinxext/directive_formatting.py | 4 +- ...decoding_time_generalization_conditions.py | 7 +- mne/chpi.py | 78 ++++++++++++++----- mne/datasets/config.py | 4 +- mne/tests/test_chpi.py | 17 ++++ 6 files changed, 83 insertions(+), 28 deletions(-) create mode 100644 doc/changes/dev/13525.bugfix.rst diff --git a/doc/changes/dev/13525.bugfix.rst b/doc/changes/dev/13525.bugfix.rst new file mode 100644 index 00000000000..8477178380a --- /dev/null +++ b/doc/changes/dev/13525.bugfix.rst @@ -0,0 +1 @@ +Fix bug where :func:`mne.chpi.refit_hpi` did not take ``gof_limit`` into account when fitting HPI order, by `Eric Larson`_ diff --git a/doc/sphinxext/directive_formatting.py b/doc/sphinxext/directive_formatting.py index a3090ab4c90..4c65f653d4a 100644 --- a/doc/sphinxext/directive_formatting.py +++ b/doc/sphinxext/directive_formatting.py @@ -60,7 +60,7 @@ def check_directive_formatting(*args): # another directive/another directive's content) if idx == 0: continue - dir_pattern = r"\.\. [a-zA-Z]+::" + dir_pattern = r"^\s*\.\. \w+::" # line might start with whitespace head_pattern = r"^[-|=|\^]+$" directive = re.search(dir_pattern, line) if directive is not None: @@ -84,5 +84,5 @@ def check_directive_formatting(*args): if bad: sphinx_logger.warning( f"{source_type} '{name}' is missing a blank line before the " - f"directive '{directive.group()}'" + f"directive '{directive.group()}' on line {idx + 1}" ) diff --git a/examples/decoding/decoding_time_generalization_conditions.py b/examples/decoding/decoding_time_generalization_conditions.py index e71112e8375..cc9e62f06cf 100644 --- a/examples/decoding/decoding_time_generalization_conditions.py +++ b/examples/decoding/decoding_time_generalization_conditions.py @@ -6,10 +6,9 @@ ========================================================================= This example runs the analysis described in :footcite:`KingDehaene2014`. It -illustrates how one can -fit a linear classifier to identify a discriminatory topography at a given time -instant and subsequently assess whether this linear model can accurately -predict all of the time samples of a second set of conditions. +illustrates how one can fit a linear classifier to identify a discriminatory +topography at a given time instant and subsequently assess whether this linear +model can accurately predict all of the time samples of a second set of conditions. """ # Authors: Jean-Rémi King # Alexandre Gramfort diff --git a/mne/chpi.py b/mne/chpi.py index 711474338c9..cc921a9843e 100644 --- a/mne/chpi.py +++ b/mne/chpi.py @@ -579,27 +579,37 @@ def _chpi_objective(x, coil_dev_rrs, coil_head_rrs): return d.sum() -def _fit_chpi_quat(coil_dev_rrs, coil_head_rrs): +def _fit_chpi_quat(coil_dev_rrs, coil_head_rrs, *, quat=None): """Fit rotation and translation (quaternion) parameters for cHPI coils.""" denom = np.linalg.norm(coil_head_rrs - np.mean(coil_head_rrs, axis=0)) denom *= denom # We could try to solve it the analytic way: # TODO someday we could choose to weight these points by their goodness # of fit somehow, see also https://github.com/mne-tools/mne-python/issues/11330 - quat = _fit_matched_points(coil_dev_rrs, coil_head_rrs)[0] + if quat is None: + quat = _fit_matched_points(coil_dev_rrs, coil_head_rrs)[0] gof = 1.0 - _chpi_objective(quat, coil_dev_rrs, coil_head_rrs) / denom return quat, gof -def _fit_coil_order_dev_head_trans(dev_pnts, head_pnts, *, bias=True, prefix=""): +def _fit_coil_order_dev_head_trans( + dev_pnts, head_pnts, *, bias=True, gofs=None, gof_limit=0.98, prefix="" +): """Compute Device to Head transform allowing for permutiatons of points.""" + n_coils = len(dev_pnts) id_quat = np.zeros(6) - best_order = None + best_order = np.full(n_coils, -1, dtype=int) best_g = -999 best_quat = id_quat - for this_order in itertools.permutations(np.arange(len(head_pnts))): + assert dev_pnts.shape == head_pnts.shape == (n_coils, 3) + gofs = np.ones(n_coils) if gofs is None else gofs + use_mask = _gof_use_mask(gofs, gof_limit=gof_limit) + n_use = int(use_mask.sum()) # explicit int cast for itertools.permutations + dev_pnts_tmp = dev_pnts[use_mask] + # First pass: figure out best order using the good dev points + for this_order in itertools.permutations(np.arange(len(head_pnts)), n_use): head_pnts_tmp = head_pnts[np.array(this_order)] - this_quat, g = _fit_chpi_quat(dev_pnts, head_pnts_tmp) + this_quat, g = _fit_chpi_quat(dev_pnts_tmp, head_pnts_tmp) assert np.linalg.det(quat_to_rot(this_quat[:3])) > 0.9999 if bias: # For symmetrical arrangements, flips can produce roughly @@ -612,17 +622,35 @@ def _fit_coil_order_dev_head_trans(dev_pnts, head_pnts, *, bias=True, prefix="") if check_g > best_g: out_g = g best_g = check_g - best_order = np.array(this_order) + best_order[use_mask] = this_order best_quat = this_quat + del this_order + # Second pass: now fit the remaining (bad) coils using the best order and quat + # from above + missing = np.setdiff1d(np.arange(n_coils), best_order[best_order >= 0]) + best_missing_g = -np.inf + for this_order in itertools.permutations(missing): + full_order = best_order.copy() + full_order[~use_mask] = this_order + assert (full_order >= 0).all() + assert np.array_equal(np.sort(full_order), np.arange(n_coils)) + head_pnts_tmp = head_pnts[np.array(full_order)] + _, g = _fit_chpi_quat(dev_pnts, head_pnts_tmp, quat=best_quat) + if g > best_missing_g: + best_missing_g = g + best_order[:] = full_order + del this_order + assert np.array_equal(np.sort(best_order), np.arange(n_coils)) # Convert Quaterion to transform dev_head_t = _quat_to_affine(best_quat) ang, dist = angle_distance_between_rigid( dev_head_t, angle_units="deg", distance_units="mm" ) + extra = f" using {n_use}/{n_coils} coils" if n_use < n_coils else "" logger.info( f"{prefix}Fitted dev_head_t {ang:0.1f}° and {dist:0.1f} mm " - f"from device origin (GOF: {out_g:.3f})" + f"from device origin{extra} (GOF: {out_g:.3f})" ) return dev_head_t, best_order, out_g @@ -1703,7 +1731,8 @@ def refit_hpi( :func:`~mne.chpi.compute_chpi_locs`. 3. Optionally determine coil digitization order by testing all permutations for the best goodness of fit between digitized coil locations and - (rigid-transformed) fitted coil locations. + (rigid-transformed) fitted coil locations, choosing the order first based on + those that satisfy ``gof_limit`` then the others. 4. Subselect coils to use for fitting ``dev_head_t`` based on ``gof_limit``, ``dist_limit``, and ``use``. 5. Update info inplace by modifying ``info["dev_head_t"]`` and appending new entries @@ -1816,6 +1845,8 @@ def refit_hpi( fit_dev_head_t, fit_order, _g = _fit_coil_order_dev_head_trans( hpi_dev, hpi_head, + gofs=hpi_gofs, + gof_limit=gof_limit, prefix=" ", ) else: @@ -1824,27 +1855,21 @@ def refit_hpi( # 4. Subselect usable coils and determine final dev_head_t if isinstance(use, int) or use is None: - used = np.where(hpi_gofs >= gof_limit)[0] - if len(used) < 3: - gofs = ", ".join(f"{g:.3f}" for g in hpi_gofs) - raise RuntimeError( - f"Only {len(used)} coil{_pl(used)} with goodness of fit >= {gof_limit}" - f", need at least 3 to refit HPI order (got {gofs})." - ) - quat, _g = _fit_chpi_quat(hpi_dev[used], hpi_head[fit_order][used]) + use_mask = _gof_use_mask(hpi_gofs, gof_limit=gof_limit) + quat, _g = _fit_chpi_quat(hpi_dev[use_mask], hpi_head[fit_order][use_mask]) fit_dev_head_t = _quat_to_affine(quat) hpi_head_got = apply_trans(fit_dev_head_t, hpi_dev) dists = np.linalg.norm(hpi_head_got - hpi_head[fit_order], axis=1) dist_str = " ".join(f"{dist * 1e3:.1f}" for dist in dists) logger.info(f" Coil distances after initial fit: {dist_str} mm") - good_dists_idx = np.where(dists[used] <= dist_limit)[0] + good_dists_idx = np.where(dists[use_mask] <= dist_limit)[0] if not len(good_dists_idx) >= 3: raise RuntimeError( - f"Only {len(good_dists_idx)} coil{_pl(good_dists_idx)} have distance " + f"Only {len(good_dists_idx)} coil{_pl(good_dists_idx)} with distance " f"<= {dist_limit * 1e3:.1f} mm, need at least 3 to refit HPI order " f"(got distances: {np.round(1e3 * dists, 1)})." ) - used = used[good_dists_idx] + used = np.where(use_mask)[0][good_dists_idx] if use is not None: used = np.sort(used[np.argsort(hpi_gofs[used])[-use:]]) else: @@ -1927,6 +1952,19 @@ def refit_hpi( return info +def _gof_use_mask(hpi_gofs, *, gof_limit): + assert isinstance(hpi_gofs, np.ndarray) and hpi_gofs.ndim == 1 + use_mask = hpi_gofs >= gof_limit + n_use = use_mask.sum() + if n_use < 3: + gofs = ", ".join(f"{g:.3f}" for g in hpi_gofs) + raise RuntimeError( + f"Only {n_use} coil{_pl(n_use)} with goodness of fit >= {gof_limit}" + f", need at least 3 to refit HPI order (got {gofs})." + ) + return use_mask + + def _sorted_hpi_dig(dig, *, kinds=(FIFF.FIFFV_POINT_HPI,)): return sorted( # need .get here because the hpi_result["dig_points"] does not set it diff --git a/mne/datasets/config.py b/mne/datasets/config.py index 23c1cf9e78b..ca65910dda6 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -87,7 +87,7 @@ # update the checksum in the MNE_DATASETS dict below, and change version # here: ↓↓↓↓↓↓↓↓ RELEASES = dict( - testing="0.169", + testing="0.170", misc="0.27", phantom_kit="0.2", ucl_opm_auditory="0.2", @@ -115,7 +115,7 @@ # Testing and misc are at the top as they're updated most often MNE_DATASETS["testing"] = dict( archive_name=f"{TESTING_VERSIONED}.tar.gz", - hash="md5:bb0524db8605e96fde6333893a969766", + hash="md5:ebd873ea89507cf5a75043f56119d22b", url=( "https://codeload.github.com/mne-tools/mne-testing-data/" f"tar.gz/{RELEASES['testing']}" diff --git a/mne/tests/test_chpi.py b/mne/tests/test_chpi.py index 0ba13f8c708..ec0d9c3c70f 100644 --- a/mne/tests/test_chpi.py +++ b/mne/tests/test_chpi.py @@ -73,6 +73,7 @@ ctf_chpi_fname = data_path / "CTF" / "testdata_ctf_mc.ds" ctf_chpi_pos_fname = data_path / "CTF" / "testdata_ctf_mc.pos" chpi_problem_fname = data_path / "SSS" / "chpi_problematic-info.fif" +chpi_bad_gof_fname = data_path / "SSS" / "chpi_bad_gof-info.fif" art_fname = ( data_path @@ -1011,3 +1012,19 @@ def test_refit_hpi_locs_problematic(): ) assert 3 < ang < 6 assert 82 < dist < 87 + + +@testing.requires_testing_data +def test_refit_hpi_locs_bad_gof(): + """Test that we can handle bad GOF HPI fits.""" + # gh-13524 + info = read_info(chpi_bad_gof_fname) + assert_array_equal(info["hpi_results"][-1]["used"], [2, 3, 4]) + info_new = refit_hpi(info.copy(), amplitudes=False, locs=False) + assert_array_equal(info_new["hpi_results"][-1]["used"], [1, 2, 3, 4]) + assert_trans_allclose( + info["dev_head_t"], + info_new["dev_head_t"], + dist_tol=1e-3, + angle_tol=1, + ) From f9f2895c7de2f3979530f43367e0488668a16b75 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 15 Dec 2025 10:38:09 -0500 Subject: [PATCH 16/30] MAINT: Update for sklearn deprecation (#13545) --- examples/decoding/decoding_xdawn_eeg.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/decoding/decoding_xdawn_eeg.py b/examples/decoding/decoding_xdawn_eeg.py index 1d1bf3f8760..a7d70bcb5bb 100644 --- a/examples/decoding/decoding_xdawn_eeg.py +++ b/examples/decoding/decoding_xdawn_eeg.py @@ -30,6 +30,7 @@ from mne import Epochs, io, pick_types, read_events from mne.datasets import sample from mne.decoding import Vectorizer, XdawnTransformer, get_spatial_filter_from_estimator +from mne.utils import check_version print(__doc__) @@ -70,11 +71,16 @@ ) # Create classification pipeline +kwargs = dict() +if check_version("sklearn", "1.8"): + kwargs["l1_ratio"] = 1 +else: + kwargs["penalty"] = "l1" clf = make_pipeline( XdawnTransformer(n_components=n_filter), Vectorizer(), MinMaxScaler(), - OneVsRestClassifier(LogisticRegression(penalty="l1", solver="liblinear")), + OneVsRestClassifier(LogisticRegression(solver="liblinear", **kwargs)), ) # Get the data and labels From 5a5b597d1eda4875e4bc38ac855e24b0644cd69d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Dec 2025 22:24:39 +0000 Subject: [PATCH 17/30] [pre-commit.ci] pre-commit autoupdate (#13546) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: mne[bot] <50266005+mne-bot@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dd3448a5ed6..5ca5e30ec1c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: # Ruff mne - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.8 + rev: v0.14.9 hooks: - id: ruff-check name: ruff lint mne From 80d2d8364e86b4f1f9ad192930b6ae73c11cec27 Mon Sep 17 00:00:00 2001 From: Michael Straube Date: Wed, 17 Dec 2025 21:36:51 +0100 Subject: [PATCH 18/30] FIX: Fix axis spines zorder in plot_evoked (#13549) Co-authored-by: Daniel McCloy --- doc/changes/dev/13549.bugfix.rst | 1 + mne/viz/evoked.py | 4 ++++ 2 files changed, 5 insertions(+) create mode 100644 doc/changes/dev/13549.bugfix.rst diff --git a/doc/changes/dev/13549.bugfix.rst b/doc/changes/dev/13549.bugfix.rst new file mode 100644 index 00000000000..5c0d111f40c --- /dev/null +++ b/doc/changes/dev/13549.bugfix.rst @@ -0,0 +1 @@ +Fix bug with :func:`mne.viz.plot_evoked` where channels were plotted above axis spines, by `Michael Straube`_. diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index 58265ccd37e..c12c1f0945e 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -808,6 +808,10 @@ def _plot_lines( # Put back the y limits as fill_betweenx messes them up ax.set_ylim(this_ylim) + # Ensure the axis spines are drawn above all Line2D artists + max_zorder = max((line.get_zorder() for line in ax.get_lines()), default=0) + 1 + ax.spines[:].set_zorder(max_zorder) + lines.append(line_list) if selectable: From 69853f7c4fac01b796439c3717c968a1a789dc2b Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Fri, 19 Dec 2025 13:43:20 +0900 Subject: [PATCH 19/30] Remove reference to non-public mod in doc/changes/dev/13408.newfeature.rst Co-authored-by: Eric Larson --- doc/changes/dev/13408.newfeature.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/changes/dev/13408.newfeature.rst b/doc/changes/dev/13408.newfeature.rst index 9ac833d7027..b53a08ca6ce 100644 --- a/doc/changes/dev/13408.newfeature.rst +++ b/doc/changes/dev/13408.newfeature.rst @@ -1,2 +1,2 @@ -Add support for multi-wavelength NIRS processing to :func:`mne.preprocessing.nirs.beer_lambert_law`, :func:`mne.preprocessing.nirs.scalp_coupling_index`, :mod:`mne.preprocessing.nirs.nirs`, and SNIRF reader :class:`mne.io.snirf._snirf.RawSNIRF`, by :newcontrib:`Tamas Fehervari`. +Add support for multi-wavelength NIRS processing to :func:`mne.preprocessing.nirs.beer_lambert_law`, :func:`mne.preprocessing.nirs.scalp_coupling_index`, and SNIRF reader :class:`mne.io.snirf._snirf.RawSNIRF`, by :newcontrib:`Tamas Fehervari`. ` From 0175028a613ca7deb0f7bd170e5919c061d96726 Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Fri, 26 Dec 2025 23:40:46 +0800 Subject: [PATCH 20/30] multi-wave testing in io and preprocessing; missing test_nirs.py --- mne/io/snirf/tests/test_snirf.py | 102 ++++++++++++++++-- mne/preprocessing/nirs/tests/conftest.py | 38 +++++++ .../nirs/tests/test_beer_lambert_law.py | 23 ++++ .../nirs/tests/test_optical_density.py | 19 ++++ .../nirs/tests/test_scalp_coupling_index.py | 85 +++++++++++++++ 5 files changed, 260 insertions(+), 7 deletions(-) create mode 100644 mne/preprocessing/nirs/tests/conftest.py diff --git a/mne/io/snirf/tests/test_snirf.py b/mne/io/snirf/tests/test_snirf.py index 73e3c775ed1..462970eb391 100644 --- a/mne/io/snirf/tests/test_snirf.py +++ b/mne/io/snirf/tests/test_snirf.py @@ -5,16 +5,23 @@ import datetime import shutil from contextlib import nullcontext +from pathlib import Path import numpy as np import pytest -from numpy.testing import assert_allclose, assert_almost_equal, assert_equal +from numpy.testing import ( + assert_allclose, + assert_almost_equal, + assert_array_equal, + assert_equal, +) from mne._fiff.constants import FIFF from mne.datasets.testing import data_path, requires_testing_data from mne.io import read_raw_nirx, read_raw_snirf from mne.io.tests.test_raw import _test_raw_reader from mne.preprocessing.nirs import ( + _channel_frequencies, _reorder_nirx, beer_lambert_law, optical_density, @@ -69,10 +76,68 @@ lumo110 = testing_path / "SNIRF" / "GowerLabs" / "lumomat-1-1-0.snirf" +@pytest.fixture(name="multi_wavelength_snirf_fname", scope="module") +def fixture_multi_wavelength_snirf_fname(tmp_path_factory): + """Return path to a tiny 3-wavelength SNIRF file for io tests.""" + try: + snirf = pytest.importorskip("snirf") + except AttributeError as exc: + # Until https://github.com/BUNPC/pysnirf2/pull/43 is released + pytest.skip(f"snirf import error: {exc}") + out_dir = Path(tmp_path_factory.mktemp("snirf_multi")) + fname = out_dir / "test_multiwl.snirf" + if fname.exists(): + return fname + + # 32 mwasurements with 2 source-detector pairs, each with 3 wavelengths + n_times = 32 + n_freq = 3 + n_channels = 2 * n_freq + + with snirf.Snirf(str(fname), "w") as f: + f.nirs.appendGroup() + f.nirs[0].data.appendGroup() + f.nirs[0].data[0].dataTimeSeries = np.ones((n_times, n_channels)) + f.nirs[0].data[0].time = range(n_times) + for ii in range(n_channels): + f.nirs[0].data[0].measurementList.appendGroup() + f.nirs[0].data[0].measurementList[ii].sourceIndex = ii // n_freq + 1 + f.nirs[0].data[0].measurementList[ii].detectorIndex = ii // n_freq + 1 + f.nirs[0].data[0].measurementList[ii].wavelengthIndex = (ii % 3) + 1 + f.nirs[0].data[0].measurementList[ii].dataType = 1 + f.nirs[0].data[0].measurementList[ii].dataTypeIndex = 0 + f.nirs[0].metaDataTags.SubjectID = "multi" + f.nirs[0].metaDataTags.MeasurementDate = "2000-01-01" + f.nirs[0].metaDataTags.MeasurementTime = "00:00:00" + f.nirs[0].metaDataTags.LengthUnit = "m" + f.nirs[0].metaDataTags.TimeUnit = "s" + f.nirs[0].metaDataTags.FrequencyUnit = "Hz" + f.nirs[0].probe.wavelengths = [700 + x * 30 for x in range(n_freq)] + f.nirs[0].probe.sourcePos3D = [[0.01 * x, 0.0, 0.0] for x in range(n_channels)] + f.nirs[0].probe.detectorPos3D = [ + [0.01 * x, 0.02, 0.0] for x in range(n_channels) + ] + f.save() + + assert fname.exists() + return fname + + def _get_loc(raw, ch_name): return raw.copy().pick(ch_name).info["chs"][0]["loc"] +def _run_basic_processing(fname): + raw = read_raw_snirf(fname, preload=True) + if "fnirs_cw_amplitude" in raw: + raw = optical_density(raw) + if "fnirs_od" in raw: + raw = beer_lambert_law(raw, ppf=6) + assert "hbo" in raw + assert "hbr" in raw + return raw + + @requires_testing_data @pytest.mark.filterwarnings("ignore:.*contains 2D location.*:") @pytest.mark.filterwarnings("ignore:.*measurement date.*:") @@ -93,14 +158,24 @@ def _get_loc(raw, ch_name): ) def test_basic_reading_and_min_process(fname): """Test reading SNIRF files and minimum typical processing.""" - raw = read_raw_snirf(fname, preload=True) - # SNIRF data can contain several types, so only apply appropriate functions - if "fnirs_cw_amplitude" in raw: - raw = optical_density(raw) - if "fnirs_od" in raw: - raw = beer_lambert_law(raw, ppf=6) + _run_basic_processing(fname) + + +def test_basic_reading_and_min_process_multi(multi_wavelength_snirf_fname): + """Ensure synthetic multi-wavelength SNIRF file passes basic processing. + + Same tests as in _run_basic_processing but with checks for number of channels. + """ + raw = read_raw_snirf(multi_wavelength_snirf_fname, preload=True) + assert "fnirs_cw_amplitude" in raw + assert len(raw.ch_names) == 6 + raw = optical_density(raw) + assert "fnirs_od" in raw + assert len(raw.ch_names) == 6 + raw = beer_lambert_law(raw, ppf=6) assert "hbo" in raw assert "hbr" in raw + assert len(raw.ch_names) == 4 @requires_testing_data @@ -574,3 +649,16 @@ def test_sample_rate_jitter(tmp_path): f.create_dataset("nirs/data1/time", data=unacceptable_time_jitter) with pytest.warns(RuntimeWarning, match="non-uniformly-sampled data"): read_raw_snirf(new_file, verbose=True) + + +def test_snirf_multiple_wavelengths(multi_wavelength_snirf_fname): + """Test importing synthetic SNIRF files with >=3 wavelengths.""" + raw = read_raw_snirf(multi_wavelength_snirf_fname, preload=True) + assert raw._data.shape == (6, 32) + assert raw.info["sfreq"] == pytest.approx(1.0) + assert raw.info["ch_names"][:3] == ["S1_D1 700", "S1_D1 730", "S1_D1 760"] + assert len(raw.ch_names) == 6 + freqs = np.unique(_channel_frequencies(raw.info)) + assert_array_equal(freqs, [700, 730, 760]) + distances = source_detector_distances(raw.info) + assert len(distances) == len(raw.ch_names) diff --git a/mne/preprocessing/nirs/tests/conftest.py b/mne/preprocessing/nirs/tests/conftest.py new file mode 100644 index 00000000000..f99c1cc5c76 --- /dev/null +++ b/mne/preprocessing/nirs/tests/conftest.py @@ -0,0 +1,38 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause + +from __future__ import annotations + +import numpy as np +import pytest + +from mne import create_info +from mne.io import RawArray + + +@pytest.fixture +def multi_wavelength_raw(request: pytest.FixtureRequest) -> RawArray: + """Create a raw CW fNIRS object with 3 wavelengths per source-detector pair.""" + n_pairs = getattr(request, "param", None) + if n_pairs is None: + raise RuntimeError( + "parametrize multi_wavelength_raw with the desired number of optode pairs" + ) + sampling_freq = 10.0 + n_times = 128 + freqs = [700, 730, 850] + + ch_names = [f"S{ii}_D{ii} {wl}" for ii in range(1, n_pairs + 1) for wl in freqs] + rng = np.random.default_rng() + data = rng.random((len(ch_names), n_times)) + 0.01 + + info = create_info( + ch_names=ch_names, ch_types="fnirs_cw_amplitude", sfreq=sampling_freq + ) + raw = RawArray(data, info, verbose=True) + for ii, (ch, freq) in enumerate(zip(raw.info["chs"], freqs * n_pairs)): + ch["loc"][9] = freq + ch["loc"][3:6] = (ii // 3 * 0.01, 0.0, 0.0) + ch["loc"][6:9] = (ii // 3 * 0.01, 0.03, 0.0) + + return raw diff --git a/mne/preprocessing/nirs/tests/test_beer_lambert_law.py b/mne/preprocessing/nirs/tests/test_beer_lambert_law.py index d80cd5fd7a7..2acc898f549 100644 --- a/mne/preprocessing/nirs/tests/test_beer_lambert_law.py +++ b/mne/preprocessing/nirs/tests/test_beer_lambert_law.py @@ -103,3 +103,26 @@ def test_beer_lambert_v_matlab(): + matlab_data["type"][idx] ) assert raw.info["ch_names"][idx] == matlab_name + + +@pytest.mark.parametrize("multi_wavelength_raw", [2], indirect=True) +def test_beer_lambert_multi_wavelength(multi_wavelength_raw): + """Ensure Beer-Lambert can process >=3 wavelengths and reduces to 2 channels.""" + # Verify original CW data + raw = multi_wavelength_raw.copy() + assert len(raw.ch_names) == 2 * 3 + assert raw.ch_names[0] == "S1_D1 700" + assert raw.ch_names[5] == "S2_D2 850" + assert set(raw.get_channel_types()) == {"fnirs_cw_amplitude"} + + # Convert to OD (tested elsewhere) + raw = optical_density(raw) + + # Verify data after conversion to Hb; channel numbers reduced to 2 per pair + raw = beer_lambert_law(raw) + _validate_type(raw, BaseRaw, "raw") + assert len(raw.ch_names) == 2 * 2 + assert all(name.endswith(" hbo") or name.endswith(" hbr") for name in raw.ch_names) + assert raw.ch_names[0] == "S1_D1 hbo" + assert raw.ch_names[3] == "S2_D2 hbr" + assert set(raw.get_channel_types()) == {"hbo", "hbr"} diff --git a/mne/preprocessing/nirs/tests/test_optical_density.py b/mne/preprocessing/nirs/tests/test_optical_density.py index 89b9edce713..06accfea601 100644 --- a/mne/preprocessing/nirs/tests/test_optical_density.py +++ b/mne/preprocessing/nirs/tests/test_optical_density.py @@ -31,6 +31,25 @@ def test_optical_density(): optical_density(raw) +@pytest.mark.parametrize("multi_wavelength_raw", [2], indirect=True) +def test_optical_density_multi_wavelength(multi_wavelength_raw): + """Ensure OD can process >=3 wavelengths and preserves channels.""" + # Validate original CW data + raw = multi_wavelength_raw.copy() + assert len(raw.ch_names) == 2 * 3 + assert raw.ch_names[0] == "S1_D1 700" + assert raw.ch_names[5] == "S2_D2 850" + assert set(raw.get_channel_types()) == {"fnirs_cw_amplitude"} + + # Validate that data has been converted to OD, number of channels preserved + raw = optical_density(raw) + _validate_type(raw, BaseRaw, "raw") + assert len(raw.ch_names) == 2 * 3 + assert raw.ch_names[0] == "S1_D1 700" + assert raw.ch_names[5] == "S2_D2 850" + assert set(raw.get_channel_types()) == {"fnirs_od"} + + @testing.requires_testing_data def test_optical_density_zeromean(): """Test that optical density can process zero mean data.""" diff --git a/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py b/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py index 4a8fd3e71d0..d7f1965208d 100644 --- a/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py +++ b/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py @@ -76,3 +76,88 @@ def test_scalp_coupling_index(fname, fmt, tmp_path): raw = beer_lambert_law(raw, ppf=6) with pytest.raises(RuntimeError, match="Scalp"): scalp_coupling_index(raw) + + +@pytest.mark.parametrize("multi_wavelength_raw", [12], indirect=True) +def test_scalp_coupling_index_multi_wavelength(multi_wavelength_raw): + """Validate SCI min-correlation logic for >=3 wavelengths. + + Similar to test in test_scalp_coupling_index, considers cases + specific to multi-wavelength data. Uses the `multi_wavelength_raw` + fixture to generate CW nirs data with the requested number of + channels (S-D optode pairs), each with 3 wavelengths; in total + n_channels x 3 data vectors. + """ + raw = optical_density(multi_wavelength_raw.copy()) + assert len(raw.ch_names) == 12 * 3 + assert raw.ch_names[0] == "S1_D1 700" + times = np.arange(raw.n_times) / raw.info["sfreq"] + signal = np.sin(2 * np.pi * 1.0 * times) + 1 + rng = np.random.default_rng() + + # pre-determined expected results + expected = [] + # group 1: perfect correlation; sci = 1 + raw._data[0] = signal + raw._data[1] = signal + raw._data[2] = signal + expected.extend([1.0] * 3) + # group 2: scale invariance; sci = 1 + raw._data[3] = signal + raw._data[4] = signal * 0.3 + raw._data[5] = signal + expected.extend([1.0] * 3) + # group 3: anti-correlation; minimum value taken, sci = -1 + raw._data[6] = signal + raw._data[7] = signal + raw._data[8] = -signal + expected.extend([-1.0] * 3) + # group 4: one zero std channel; minimum value is sci = 0 + raw._data[9] = 0.0 + raw._data[10] = signal + raw._data[11] = signal + expected.extend([0.0] * 3) + # group 5: three zero std channels; all sci = 0 + raw._data[12] = 0.0 + raw._data[13] = 1.0 + raw._data[14] = 2.0 + expected.extend([0.0] * 3) + # group 6: mixed: 1 signal + 1 negative + 1 random (lowest wins) + raw._data[15] = signal + raw._data[16] = rng.random(signal.shape) + raw._data[17] = -signal + expected.extend([-1.0] * 3) + + # exact results unknown + # group 7: 1 uncorrelated signal out of 3; sci < 0.5 + raw._data[18] = signal + raw._data[19] = rng.random(signal.shape) + raw._data[20] = signal + # group 8: 2 uncorrelated signals out of 3; sci < 0.5 + raw._data[21] = rng.random(signal.shape) + raw._data[22] = rng.random(signal.shape) + raw._data[23] = signal + # group 9: 3 uncorrelated signals; sci < 0.5 + raw._data[24] = rng.random(signal.shape) + raw._data[25] = rng.random(signal.shape) + raw._data[26] = rng.random(signal.shape) + # groups 10-12: ordering invariance; all must be the same + rand1 = rng.random(signal.shape) + rand2 = rng.random(signal.shape) + rand3 = rng.random(signal.shape) + raw._data[27] = rand1 + raw._data[28] = rand2 + raw._data[29] = rand3 + raw._data[30] = rand2 + raw._data[31] = rand1 + raw._data[32] = rand3 + raw._data[33] = rand3 + raw._data[34] = rand1 + raw._data[35] = rand2 + + sci = scalp_coupling_index(raw) + + assert_allclose(sci[:18], expected, atol=1e-4) + for ii in range(18, 27): + assert np.abs(sci[ii]) < 0.5 + assert_allclose(sci[28:], sci[27], atol=1e-4) From 49b4dd91bcb5bf30db54448bfc6dc1092081e884 Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Sat, 27 Dec 2025 00:30:10 +0800 Subject: [PATCH 21/30] fixed some unwanted changes leftover in test_snirf.py --- mne/io/snirf/tests/test_snirf.py | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/mne/io/snirf/tests/test_snirf.py b/mne/io/snirf/tests/test_snirf.py index 462970eb391..33113567aab 100644 --- a/mne/io/snirf/tests/test_snirf.py +++ b/mne/io/snirf/tests/test_snirf.py @@ -127,17 +127,6 @@ def _get_loc(raw, ch_name): return raw.copy().pick(ch_name).info["chs"][0]["loc"] -def _run_basic_processing(fname): - raw = read_raw_snirf(fname, preload=True) - if "fnirs_cw_amplitude" in raw: - raw = optical_density(raw) - if "fnirs_od" in raw: - raw = beer_lambert_law(raw, ppf=6) - assert "hbo" in raw - assert "hbr" in raw - return raw - - @requires_testing_data @pytest.mark.filterwarnings("ignore:.*contains 2D location.*:") @pytest.mark.filterwarnings("ignore:.*measurement date.*:") @@ -158,24 +147,25 @@ def _run_basic_processing(fname): ) def test_basic_reading_and_min_process(fname): """Test reading SNIRF files and minimum typical processing.""" - _run_basic_processing(fname) - + raw = read_raw_snirf(fname, preload=True) + # SNIRF data can contain several types, so only apply appropriate functions + if "fnirs_cw_amplitude" in raw: + raw = optical_density(raw) + if "fnirs_od" in raw: + raw = beer_lambert_law(raw, ppf=6) + assert "hbo" in raw + assert "hbr" in raw -def test_basic_reading_and_min_process_multi(multi_wavelength_snirf_fname): - """Ensure synthetic multi-wavelength SNIRF file passes basic processing. - Same tests as in _run_basic_processing but with checks for number of channels. - """ +def test_basic_reading_and_min_process_multiwl(multi_wavelength_snirf_fname): + """Ensure synthetic multi-wavelength SNIRF file passes basic processing.""" raw = read_raw_snirf(multi_wavelength_snirf_fname, preload=True) assert "fnirs_cw_amplitude" in raw - assert len(raw.ch_names) == 6 raw = optical_density(raw) assert "fnirs_od" in raw - assert len(raw.ch_names) == 6 raw = beer_lambert_law(raw, ppf=6) assert "hbo" in raw assert "hbr" in raw - assert len(raw.ch_names) == 4 @requires_testing_data From 77259d1fc4cebe24c6048b8ec1315fd1bcff1b24 Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Tue, 30 Dec 2025 11:29:23 +0800 Subject: [PATCH 22/30] multi-wavelength-aware testing in io/snirf and preprocessing/nirs, now based on labnirs sample file --- mne/preprocessing/nirs/nirs.py | 2 +- mne/preprocessing/nirs/tests/conftest.py | 38 ---- .../nirs/tests/test_beer_lambert_law.py | 138 ++++++------- mne/preprocessing/nirs/tests/test_nirs.py | 187 +++++++++++++----- .../nirs/tests/test_optical_density.py | 50 +++-- .../nirs/tests/test_scalp_coupling_index.py | 21 +- 6 files changed, 234 insertions(+), 202 deletions(-) delete mode 100644 mne/preprocessing/nirs/tests/conftest.py diff --git a/mne/preprocessing/nirs/nirs.py b/mne/preprocessing/nirs/nirs.py index afd4fbaef8d..49827fd1df5 100644 --- a/mne/preprocessing/nirs/nirs.py +++ b/mne/preprocessing/nirs/nirs.py @@ -145,7 +145,7 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr ) # Check that the total number of channels is divisible by the number of pair values - # (e.g., for 2 wavelengths, we need even number of channels) + # (e.g., for 2 wavelengths, we need an even number of channels) if len(picks) % len(pair_vals) != 0: picks = _throw_or_return_empty( "NIRS channels not ordered correctly. The number of channels " diff --git a/mne/preprocessing/nirs/tests/conftest.py b/mne/preprocessing/nirs/tests/conftest.py deleted file mode 100644 index f99c1cc5c76..00000000000 --- a/mne/preprocessing/nirs/tests/conftest.py +++ /dev/null @@ -1,38 +0,0 @@ -# Authors: The MNE-Python contributors. -# License: BSD-3-Clause - -from __future__ import annotations - -import numpy as np -import pytest - -from mne import create_info -from mne.io import RawArray - - -@pytest.fixture -def multi_wavelength_raw(request: pytest.FixtureRequest) -> RawArray: - """Create a raw CW fNIRS object with 3 wavelengths per source-detector pair.""" - n_pairs = getattr(request, "param", None) - if n_pairs is None: - raise RuntimeError( - "parametrize multi_wavelength_raw with the desired number of optode pairs" - ) - sampling_freq = 10.0 - n_times = 128 - freqs = [700, 730, 850] - - ch_names = [f"S{ii}_D{ii} {wl}" for ii in range(1, n_pairs + 1) for wl in freqs] - rng = np.random.default_rng() - data = rng.random((len(ch_names), n_times)) + 0.01 - - info = create_info( - ch_names=ch_names, ch_types="fnirs_cw_amplitude", sfreq=sampling_freq - ) - raw = RawArray(data, info, verbose=True) - for ii, (ch, freq) in enumerate(zip(raw.info["chs"], freqs * n_pairs)): - ch["loc"][9] = freq - ch["loc"][3:6] = (ii // 3 * 0.01, 0.0, 0.0) - ch["loc"][6:9] = (ii // 3 * 0.01, 0.03, 0.0) - - return raw diff --git a/mne/preprocessing/nirs/tests/test_beer_lambert_law.py b/mne/preprocessing/nirs/tests/test_beer_lambert_law.py index 2acc898f549..aa793ff3bbe 100644 --- a/mne/preprocessing/nirs/tests/test_beer_lambert_law.py +++ b/mne/preprocessing/nirs/tests/test_beer_lambert_law.py @@ -7,8 +7,12 @@ from mne.datasets import testing from mne.datasets.testing import data_path -from mne.io import BaseRaw, read_raw_fif, read_raw_nirx -from mne.preprocessing.nirs import beer_lambert_law, optical_density +from mne.io import BaseRaw, read_raw_fif, read_raw_nirx, read_raw_snirf +from mne.preprocessing.nirs import ( + _channel_frequencies, + beer_lambert_law, + optical_density, +) from mne.utils import _validate_type testing_path = data_path(download=False) @@ -17,64 +21,67 @@ fname_nirx_15_2_short = ( testing_path / "NIRx" / "nirscout" / "nirx_15_2_recording_w_short" ) +fname_labnirs_multi_wavelength = ( + testing_path / "SNIRF" / "Labnirs" / "labnirs_3wl_raw_recording.snirf" +) @testing.requires_testing_data @pytest.mark.parametrize( - "fname", ([fname_nirx_15_2_short, fname_nirx_15_2, fname_nirx_15_0]) + "fname,fmt", + ( + [ + (fname_nirx_15_2_short, "nirx"), + (fname_nirx_15_2_short, "fif"), + (fname_nirx_15_2, "nirx"), + (fname_nirx_15_2, "fif"), + (fname_nirx_15_0, "nirx"), + (fname_nirx_15_0, "fif"), + (fname_labnirs_multi_wavelength, "snirf"), + ] + ), ) -@pytest.mark.parametrize("fmt", ("nirx", "fif")) def test_beer_lambert(fname, fmt, tmp_path): - """Test converting NIRX files.""" - assert fmt in ("nirx", "fif") - raw = read_raw_nirx(fname) - if fmt == "fif": - raw.save(tmp_path / "test_raw.fif") - raw = read_raw_fif(tmp_path / "test_raw.fif") - assert "fnirs_cw_amplitude" in raw - assert "fnirs_od" not in raw - raw = optical_density(raw) - _validate_type(raw, BaseRaw, "raw") - assert "fnirs_cw_amplitude" not in raw - assert "fnirs_od" in raw - assert "hbo" not in raw - raw = beer_lambert_law(raw) - _validate_type(raw, BaseRaw, "raw") - assert "fnirs_cw_amplitude" not in raw - assert "fnirs_od" not in raw - assert "hbo" in raw - assert "hbr" in raw - - -@testing.requires_testing_data -def test_beer_lambert_unordered_errors(): - """NIRS data requires specific ordering and naming of channels.""" - raw = read_raw_nirx(fname_nirx_15_0) - raw_od = optical_density(raw) - raw_od.pick([0, 1, 2]) - with pytest.raises(ValueError, match="NIRS channels not ordered correctly."): - beer_lambert_law(raw_od) - - # Test that an error is thrown if channel naming frequency doesn't match - # what is stored in loc[9], which should hold the light frequency too. - # Introduce 2 new frequencies to make it 4 in total vs 2 stored in loc[9]. - # This way the bad data will have 20 channels and 4 wavelengths, so as not - # to get caught by the check for divisibility (channel % wavelength == 0). - raw_od = optical_density(raw) - assert raw.ch_names[0] == "S1_D1 760" and raw.ch_names[1] == "S1_D1 850" - assert ( - raw_od.ch_names.index(raw.ch_names[0]) == 0 - and raw_od.ch_names.index(raw.ch_names[1]) == 1 - ) - raw_od.rename_channels( - { - raw.ch_names[0]: raw.ch_names[0].replace("760", "770"), - raw.ch_names[1]: raw.ch_names[1].replace("850", "840"), - } - ) - assert raw_od.ch_names[0] == "S1_D1 770" and raw_od.ch_names[1] == "S1_D1 840" - with pytest.raises(ValueError, match="NIRS channels not ordered correctly."): - beer_lambert_law(raw_od) + """Test converting raw CW amplitude files.""" + match fmt: + case "nirx": + raw_volt = read_raw_nirx(fname) + case "fif": + raw_nirx = read_raw_nirx(fname) + raw_nirx.save(tmp_path / "test_raw.fif") + raw_volt = read_raw_fif(tmp_path / "test_raw.fif") + case "snirf": + raw_volt = read_raw_snirf(fname) + case _: + raise ValueError( + f"fmt expexted to be one of 'nirx', 'fif' or 'snirf', got {fmt}" + ) + + raw_od = optical_density(raw_volt) + _validate_type(raw_od, BaseRaw, "raw") + + raw_hb = beer_lambert_law(raw_od) + _validate_type(raw_hb, BaseRaw, "raw") + + # Verify channel numbers (multi-wavelength aware) + # Raw voltage has: optode pairs * number of wavelengths + # OD must have the same number as raw voltage + # Hb data must have: number of optode pairs * 2 + nfreqs = len(set(_channel_frequencies(raw_volt.info))) + assert len(raw_volt.ch_names) % nfreqs == 0 + npairs = len(raw_volt.ch_names) // nfreqs + assert len(raw_hb.ch_names) % npairs == 0 + assert len(raw_hb.ch_names) // npairs == 2.0 + + # Verify data types + assert set(raw_volt.get_channel_types()) == {"fnirs_cw_amplitude"} + assert set(raw_hb.get_channel_types()) == {"hbo", "hbr"} + + # Verify that pair ordering did not change just channel name suffixes + old_prefixes = [name.split(" ")[0] for name in raw_volt.ch_names[::nfreqs]] + new_prefixes = [name.split(" ")[0] for name in raw_hb.ch_names[::2]] + assert old_prefixes == new_prefixes + assert all([name.split(" ")[1] in {"hbo", "hbr"} for name in raw_hb.ch_names]) @testing.requires_testing_data @@ -103,26 +110,3 @@ def test_beer_lambert_v_matlab(): + matlab_data["type"][idx] ) assert raw.info["ch_names"][idx] == matlab_name - - -@pytest.mark.parametrize("multi_wavelength_raw", [2], indirect=True) -def test_beer_lambert_multi_wavelength(multi_wavelength_raw): - """Ensure Beer-Lambert can process >=3 wavelengths and reduces to 2 channels.""" - # Verify original CW data - raw = multi_wavelength_raw.copy() - assert len(raw.ch_names) == 2 * 3 - assert raw.ch_names[0] == "S1_D1 700" - assert raw.ch_names[5] == "S2_D2 850" - assert set(raw.get_channel_types()) == {"fnirs_cw_amplitude"} - - # Convert to OD (tested elsewhere) - raw = optical_density(raw) - - # Verify data after conversion to Hb; channel numbers reduced to 2 per pair - raw = beer_lambert_law(raw) - _validate_type(raw, BaseRaw, "raw") - assert len(raw.ch_names) == 2 * 2 - assert all(name.endswith(" hbo") or name.endswith(" hbr") for name in raw.ch_names) - assert raw.ch_names[0] == "S1_D1 hbo" - assert raw.ch_names[3] == "S2_D2 hbr" - assert set(raw.get_channel_types()) == {"hbo", "hbr"} diff --git a/mne/preprocessing/nirs/tests/test_nirs.py b/mne/preprocessing/nirs/tests/test_nirs.py index 89fa17c0c8d..069657e2501 100644 --- a/mne/preprocessing/nirs/tests/test_nirs.py +++ b/mne/preprocessing/nirs/tests/test_nirs.py @@ -11,7 +11,7 @@ from mne._fiff.pick import _picks_to_idx from mne.datasets import testing from mne.datasets.testing import data_path -from mne.io import RawArray, read_raw_nirx +from mne.io import RawArray, read_raw_nirx, read_raw_snirf from mne.preprocessing.nirs import ( _channel_chromophore, _channel_frequencies, @@ -35,12 +35,22 @@ fname_nirx_15_2_short = ( data_path(download=False) / "NIRx" / "nirscout" / "nirx_15_2_recording_w_short" ) +fname_labnirs_multi_wavelength = ( + data_path(download=False) / "SNIRF" / "Labnirs" / "labnirs_3wl_raw_recording.snirf" +) @testing.requires_testing_data -def test_fnirs_picks(): +@pytest.mark.parametrize( + "fname, readerfn", + [ + (fname_nirx_15_0, read_raw_nirx), + (fname_labnirs_multi_wavelength, read_raw_snirf), + ], +) +def test_fnirs_picks(fname, readerfn): """Test picking of fnirs types after different conversions.""" - raw = read_raw_nirx(fname_nirx_15_0) + raw = readerfn(fname) picks = _picks_to_idx(raw.info, "fnirs_cw_amplitude") assert len(picks) == len(raw.ch_names) raw_subset = raw.copy().pick(picks="fnirs_cw_amplitude") @@ -106,12 +116,18 @@ def _fnirs_check_bads(info): @testing.requires_testing_data @pytest.mark.parametrize( - "fname", ([fname_nirx_15_2_short, fname_nirx_15_2, fname_nirx_15_0]) + "fname, readerfn", + [ + (fname_nirx_15_0, read_raw_nirx), + (fname_nirx_15_2_short, read_raw_nirx), + (fname_nirx_15_2, read_raw_nirx), + (fname_labnirs_multi_wavelength, read_raw_snirf), + ], ) -def test_fnirs_check_bads(fname): +def test_fnirs_check_bads(fname, readerfn): """Test checking of bad markings.""" # No bad channels, so these should all pass - raw = read_raw_nirx(fname) + raw = readerfn(fname) _fnirs_check_bads(raw.info) raw = optical_density(raw) _fnirs_check_bads(raw.info) @@ -119,8 +135,9 @@ def test_fnirs_check_bads(fname): _fnirs_check_bads(raw.info) # Mark pairs of bad channels, so these should all pass - raw = read_raw_nirx(fname) - raw.info["bads"] = raw.ch_names[0:2] + raw = readerfn(fname) + nfreqs = len(set(_channel_frequencies(raw.info))) + raw.info["bads"] = raw.ch_names[0:nfreqs] _fnirs_check_bads(raw.info) raw = optical_density(raw) _fnirs_check_bads(raw.info) @@ -128,7 +145,7 @@ def test_fnirs_check_bads(fname): _fnirs_check_bads(raw.info) # Mark single channel as bad, so these should all fail - raw = read_raw_nirx(fname) + raw = readerfn(fname) raw.info["bads"] = raw.ch_names[0:1] pytest.raises(RuntimeError, _fnirs_check_bads, raw.info) with pytest.raises(RuntimeError, match="bad labelling"): @@ -144,71 +161,90 @@ def test_fnirs_check_bads(fname): @testing.requires_testing_data @pytest.mark.parametrize( - "fname", ([fname_nirx_15_2_short, fname_nirx_15_2, fname_nirx_15_0]) + "fname, readerfn", + [ + (fname_nirx_15_0, read_raw_nirx), + (fname_nirx_15_2_short, read_raw_nirx), + (fname_nirx_15_2, read_raw_nirx), + (fname_labnirs_multi_wavelength, read_raw_snirf), + ], ) -def test_fnirs_spread_bads(fname): +def test_fnirs_spread_bads(fname, readerfn): """Test checking of bad markings.""" # Test spreading upwards in frequency and on raw data - raw = read_raw_nirx(fname) - raw.info["bads"] = ["S1_D1 760"] + raw = readerfn(fname) + nfreqs = len(set(_channel_frequencies(raw.info))) + raw.info["bads"] = [raw.ch_names[0]] info = _fnirs_spread_bads(raw.info) - assert info["bads"] == ["S1_D1 760", "S1_D1 850"] - - # Test spreading downwards in frequency and on od data - raw = optical_density(raw) - raw.info["bads"] = raw.ch_names[5:6] - info = _fnirs_spread_bads(raw.info) - assert info["bads"] == raw.ch_names[4:6] + assert info["bads"] == raw.ch_names[:nfreqs] + + # Test multiple spreading directions on od data + # For each wavelength, mark the nth item in the nth group as bad + # e.g. 3 wavelengths: + # group 0 (channels 0, 1, 2) - channel 0 is marked bad, + # group 1 (channels 3, 4, 5) - channel 4 is bad, + # group 2 (channels 6, 7, 8) - channel 8 is bad. + # This way we can test spreading from each group member in a + # way that's agnostic of the number of wavelengths, and avoid + # hard-coded values. Needs nfreqs**2 number of channels. + raw_od = optical_density(raw) + bads = [raw_od.ch_names[nfreqs * ii + ii] for ii in range(nfreqs)] + expected_bads = raw_od.ch_names[: nfreqs**2] + raw_od.info["bads"] = bads + info = _fnirs_spread_bads(raw_od.info) + # channels might not be sorted but the spreading result is + assert info["bads"] == sorted(expected_bads) # Test spreading multiple bads and on chroma data - raw = beer_lambert_law(raw) - raw.info["bads"] = [raw.ch_names[x] for x in [1, 8]] - info = _fnirs_spread_bads(raw.info) + # Hb data always has 2 channels per S-D pair, this works for any nfreqs + raw_hb = beer_lambert_law(raw_od) + raw_hb.info["bads"] = [raw_hb.ch_names[x] for x in [1, 8]] + info = _fnirs_spread_bads(raw_hb.info) assert info["bads"] == [info.ch_names[x] for x in [0, 1, 8, 9]] @testing.requires_testing_data @pytest.mark.parametrize( - "fname", ([fname_nirx_15_2_short, fname_nirx_15_2, fname_nirx_15_0]) + "fname, readerfn", + [ + (fname_nirx_15_0, read_raw_nirx), + (fname_nirx_15_2_short, read_raw_nirx), + (fname_nirx_15_2, read_raw_nirx), + (fname_labnirs_multi_wavelength, read_raw_snirf), + ], ) -def test_fnirs_channel_naming_and_order_readers(fname): - """Ensure fNIRS channel checking on standard readers.""" +def test_fnirs_channel_frequency_ordering(fname, readerfn): + """Test fNIRS channel frequencies ordering and related errors.""" # fNIRS data requires specific channel naming and ordering. - # All standard readers should pass tests - raw = read_raw_nirx(fname) - freqs = np.unique(_channel_frequencies(raw.info)) - assert_array_equal(freqs, [760, 850]) + # Ensure that freqs are well-ordered after reading in from file, + # and that there are no chroma channels + raw = readerfn(fname) + freqs = np.unique(_channel_frequencies(raw.info)).tolist() + assert sorted(freqs) == freqs chroma = np.unique(_channel_chromophore(raw.info)) assert len(chroma) == 0 - picks = _check_channels_ordered(raw.info, freqs) assert len(picks) == len(raw.ch_names) # as all fNIRS only data - # Check that dropped channels are detected - # For each source detector pair there must be two channels, - # removing one should throw an error. - raw_dropped = raw.copy().drop_channels(raw.ch_names[4]) - with pytest.raises(ValueError, match="not ordered correctly"): - _check_channels_ordered(raw_dropped.info, freqs) - # The ordering must be increasing for the pairs, if provided - raw_names_reversed = raw.copy().ch_names - raw_names_reversed.reverse() - raw_reversed = raw.copy().pick(raw_names_reversed) + raw_reversed = raw.pick(list(reversed(raw.ch_names))) + with pytest.raises(ValueError, match="The frequencies.*sorted.*"): - _check_channels_ordered(raw_reversed.info, [850, 760]) + _check_channels_ordered(raw_reversed.info, list(reversed(freqs))) # So if we flip the second argument it should pass again picks = _check_channels_ordered(raw_reversed.info, freqs) - got_first = set(raw_reversed.ch_names[pick].split()[1] for pick in picks[::2]) - assert got_first == {"760"} - got_second = set(raw_reversed.ch_names[pick].split()[1] for pick in picks[1::2]) - assert got_second == {"850"} + nfreqs = len(freqs) + for ii in range(nfreqs): + suffixes = { + int(raw_reversed.ch_names[pick].split(" ")[1]) for pick in picks[ii::nfreqs] + } + assert suffixes == {freqs[ii]} # Check on OD data raw = optical_density(raw) - freqs = np.unique(_channel_frequencies(raw.info)) - assert_array_equal(freqs, [760, 850]) + freqs = np.unique(_channel_frequencies(raw.info)).tolist() + assert sorted(freqs) == freqs chroma = np.unique(_channel_chromophore(raw.info)) assert len(chroma) == 0 picks = _check_channels_ordered(raw.info, freqs) @@ -553,3 +589,58 @@ def test_order_agnostic(nirx_snirf): tddrs["nirx"].get_data(), r.get_data(orders[key]), err_msg=key, atol=1e-9 ) assert set(r.get_channel_types()) == {"hbo", "hbr"} + + +@testing.requires_testing_data +@pytest.mark.parametrize( + "fname, readerfn", + [ + (fname_nirx_15_0, read_raw_nirx), + (fname_nirx_15_2_short, read_raw_nirx), + (fname_nirx_15_2, read_raw_nirx), + (fname_labnirs_multi_wavelength, read_raw_snirf), + ], +) +def test_nirs_channel_grouping(fname, readerfn): + """Test channel grouping related errors.""" + raw = readerfn(fname) + freqs = np.unique(_channel_frequencies(raw.info)).tolist() + nfreqs = len(freqs) + + # Each source-detector (S-D) optode pair may have data measured at + # >=2 wavelengths. The channels that belong to an optode pair form + # a group, with a size equal to the number of wavelengths (nfreqs). + # An error is raised if these groups are incomplete. + + picks = _check_channels_ordered(raw.info, freqs) + assert len(picks) == len(raw.ch_names) + + # Removing one channel breaks the grouping + raw_dropped = raw.copy().drop_channels(raw.ch_names[4]) + with pytest.raises(ValueError, match="NIRS channels not ordered correctly."): + _check_channels_ordered(raw_dropped.info, freqs) + + # Selecting incomplete groups results in an error + raw_incomplete = raw.copy().pick(list(range(nfreqs + 1))) + with pytest.raises(ValueError, match="NIRS channels not ordered correctly."): + _check_channels_ordered(raw_incomplete.info, freqs) + + # Changing the frequency in one channel's name also breaks groups + raw_extrafreq = raw.copy() + new_ch10_name = f"{raw_extrafreq.ch_names[10].split(' ')[0]} 100" + raw_extrafreq.rename_channels({raw_extrafreq.ch_names[10]: new_ch10_name}) + print(raw_extrafreq.ch_names) + # checks result in error for both old and new set of frequencies + with pytest.raises(ValueError, match="NIRS channels not ordered correctly."): + _check_channels_ordered(raw_extrafreq.info, [100] + freqs) + with pytest.raises(ValueError, match="NIRS channels not ordered correctly."): + _check_channels_ordered(raw_extrafreq.info, freqs) + + # Frequency values are also stored in info['chs'][ii]['loc'][9]. + # Even though the actual values can be arbitrary, they cannot be None. + raw_locnone = raw.copy() + raw_locnone.info["chs"][10]["loc"][9] = None + with pytest.raises( + ValueError, match="NIRS channels is missing wavelength information" + ): + _check_channels_ordered(raw_locnone.info, freqs) diff --git a/mne/preprocessing/nirs/tests/test_optical_density.py b/mne/preprocessing/nirs/tests/test_optical_density.py index 06accfea601..53b66f46238 100644 --- a/mne/preprocessing/nirs/tests/test_optical_density.py +++ b/mne/preprocessing/nirs/tests/test_optical_density.py @@ -8,46 +8,42 @@ from mne.datasets import testing from mne.datasets.testing import data_path -from mne.io import BaseRaw, read_raw_nirx +from mne.io import BaseRaw, read_raw_nirx, read_raw_snirf from mne.preprocessing.nirs import optical_density from mne.utils import _validate_type fname_nirx = ( data_path(download=False) / "NIRx" / "nirscout" / "nirx_15_2_recording_w_short" ) +fname_labnirs_multi_wavelength = ( + data_path(download=False) / "SNIRF" / "Labnirs" / "labnirs_3wl_raw_recording.snirf" +) @testing.requires_testing_data -def test_optical_density(): +@pytest.mark.parametrize( + "fname,readerfn", + [(fname_nirx, read_raw_nirx), (fname_labnirs_multi_wavelength, read_raw_snirf)], +) +def test_optical_density(fname, readerfn): """Test return type for optical density.""" - raw = read_raw_nirx(fname_nirx, preload=False) - assert "fnirs_cw_amplitude" in raw - assert "fnirs_od" not in raw - raw = optical_density(raw) - _validate_type(raw, BaseRaw, "raw") - assert "fnirs_cw_amplitude" not in raw - assert "fnirs_od" in raw - with pytest.raises(RuntimeError, match="on continuous wave"): - optical_density(raw) + raw_volt = readerfn(fname, preload=False) + _validate_type(raw_volt, BaseRaw, "raw") + + raw_od = optical_density(raw_volt) + _validate_type(raw_od, BaseRaw, "raw") + # Verify data types + assert set(raw_volt.get_channel_types()) == {"fnirs_cw_amplitude"} + assert set(raw_od.get_channel_types()) == {"fnirs_od"} -@pytest.mark.parametrize("multi_wavelength_raw", [2], indirect=True) -def test_optical_density_multi_wavelength(multi_wavelength_raw): - """Ensure OD can process >=3 wavelengths and preserves channels.""" - # Validate original CW data - raw = multi_wavelength_raw.copy() - assert len(raw.ch_names) == 2 * 3 - assert raw.ch_names[0] == "S1_D1 700" - assert raw.ch_names[5] == "S2_D2 850" - assert set(raw.get_channel_types()) == {"fnirs_cw_amplitude"} + # Verify that channel names did not change + for oldname, newname in zip(raw_volt.ch_names, raw_od.ch_names): + assert oldname == newname - # Validate that data has been converted to OD, number of channels preserved - raw = optical_density(raw) - _validate_type(raw, BaseRaw, "raw") - assert len(raw.ch_names) == 2 * 3 - assert raw.ch_names[0] == "S1_D1 700" - assert raw.ch_names[5] == "S2_D2 850" - assert set(raw.get_channel_types()) == {"fnirs_od"} + # Cannot run OD conversion on OD data + with pytest.raises(RuntimeError, match="on continuous wave"): + optical_density(raw_od) @testing.requires_testing_data diff --git a/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py b/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py index d7f1965208d..7035fa5376f 100644 --- a/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py +++ b/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py @@ -8,7 +8,7 @@ from mne.datasets import testing from mne.datasets.testing import data_path -from mne.io import read_raw_nirx +from mne.io import read_raw_nirx, read_raw_snirf from mne.preprocessing.nirs import ( beer_lambert_law, optical_density, @@ -24,6 +24,9 @@ fname_nirx_15_2_short = ( data_path(download=False) / "NIRx" / "nirscout" / "nirx_15_2_recording_w_short" ) +fname_labnirs_multi_wavelength = ( + data_path(download=False) / "SNIRF" / "Labnirs" / "labnirs_3wl_raw_recording.snirf" +) @testing.requires_testing_data @@ -78,21 +81,17 @@ def test_scalp_coupling_index(fname, fmt, tmp_path): scalp_coupling_index(raw) -@pytest.mark.parametrize("multi_wavelength_raw", [12], indirect=True) -def test_scalp_coupling_index_multi_wavelength(multi_wavelength_raw): +@testing.requires_testing_data +def test_scalp_coupling_index_multi_wavelength(): """Validate SCI min-correlation logic for >=3 wavelengths. Similar to test in test_scalp_coupling_index, considers cases - specific to multi-wavelength data. Uses the `multi_wavelength_raw` - fixture to generate CW nirs data with the requested number of - channels (S-D optode pairs), each with 3 wavelengths; in total - n_channels x 3 data vectors. + specific to multi-wavelength data. """ - raw = optical_density(multi_wavelength_raw.copy()) - assert len(raw.ch_names) == 12 * 3 - assert raw.ch_names[0] == "S1_D1 700" + raw = optical_density(read_raw_snirf(fname_labnirs_multi_wavelength)) times = np.arange(raw.n_times) / raw.info["sfreq"] signal = np.sin(2 * np.pi * 1.0 * times) + 1 + assert len(raw.ch_names) >= 15 * 3 rng = np.random.default_rng() # pre-determined expected results @@ -160,4 +159,4 @@ def test_scalp_coupling_index_multi_wavelength(multi_wavelength_raw): assert_allclose(sci[:18], expected, atol=1e-4) for ii in range(18, 27): assert np.abs(sci[ii]) < 0.5 - assert_allclose(sci[28:], sci[27], atol=1e-4) + assert_allclose(sci[28:36], sci[27], atol=1e-4) From d8df2d4194b6682306298b314639dd01c3fb13fb Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Tue, 30 Dec 2025 12:08:31 +0800 Subject: [PATCH 23/30] fix mne\io\snirf\tests\test_snirf.py to remove leftover fixture and use labnirs test file --- mne/io/snirf/tests/test_snirf.py | 76 +++++--------------------------- 1 file changed, 12 insertions(+), 64 deletions(-) diff --git a/mne/io/snirf/tests/test_snirf.py b/mne/io/snirf/tests/test_snirf.py index 33113567aab..40fdf71c6f9 100644 --- a/mne/io/snirf/tests/test_snirf.py +++ b/mne/io/snirf/tests/test_snirf.py @@ -75,52 +75,10 @@ # GowerLabs lumo110 = testing_path / "SNIRF" / "GowerLabs" / "lumomat-1-1-0.snirf" - -@pytest.fixture(name="multi_wavelength_snirf_fname", scope="module") -def fixture_multi_wavelength_snirf_fname(tmp_path_factory): - """Return path to a tiny 3-wavelength SNIRF file for io tests.""" - try: - snirf = pytest.importorskip("snirf") - except AttributeError as exc: - # Until https://github.com/BUNPC/pysnirf2/pull/43 is released - pytest.skip(f"snirf import error: {exc}") - out_dir = Path(tmp_path_factory.mktemp("snirf_multi")) - fname = out_dir / "test_multiwl.snirf" - if fname.exists(): - return fname - - # 32 mwasurements with 2 source-detector pairs, each with 3 wavelengths - n_times = 32 - n_freq = 3 - n_channels = 2 * n_freq - - with snirf.Snirf(str(fname), "w") as f: - f.nirs.appendGroup() - f.nirs[0].data.appendGroup() - f.nirs[0].data[0].dataTimeSeries = np.ones((n_times, n_channels)) - f.nirs[0].data[0].time = range(n_times) - for ii in range(n_channels): - f.nirs[0].data[0].measurementList.appendGroup() - f.nirs[0].data[0].measurementList[ii].sourceIndex = ii // n_freq + 1 - f.nirs[0].data[0].measurementList[ii].detectorIndex = ii // n_freq + 1 - f.nirs[0].data[0].measurementList[ii].wavelengthIndex = (ii % 3) + 1 - f.nirs[0].data[0].measurementList[ii].dataType = 1 - f.nirs[0].data[0].measurementList[ii].dataTypeIndex = 0 - f.nirs[0].metaDataTags.SubjectID = "multi" - f.nirs[0].metaDataTags.MeasurementDate = "2000-01-01" - f.nirs[0].metaDataTags.MeasurementTime = "00:00:00" - f.nirs[0].metaDataTags.LengthUnit = "m" - f.nirs[0].metaDataTags.TimeUnit = "s" - f.nirs[0].metaDataTags.FrequencyUnit = "Hz" - f.nirs[0].probe.wavelengths = [700 + x * 30 for x in range(n_freq)] - f.nirs[0].probe.sourcePos3D = [[0.01 * x, 0.0, 0.0] for x in range(n_channels)] - f.nirs[0].probe.detectorPos3D = [ - [0.01 * x, 0.02, 0.0] for x in range(n_channels) - ] - f.save() - - assert fname.exists() - return fname +# Shimadzu Labnirs 3-wavelength converted to snirf using custom tool +labnirs_multi_wavelength = ( + testing_path / "SNIRF" / "Labnirs" / "labnirs_3wl_raw_recording.snirf" +) def _get_loc(raw, ch_name): @@ -142,6 +100,7 @@ def _get_loc(raw, ch_name): nirx_nirsport2_103_2, kernel_hb, lumo110, + labnirs_multi_wavelength, ] ), ) @@ -157,17 +116,6 @@ def test_basic_reading_and_min_process(fname): assert "hbr" in raw -def test_basic_reading_and_min_process_multiwl(multi_wavelength_snirf_fname): - """Ensure synthetic multi-wavelength SNIRF file passes basic processing.""" - raw = read_raw_snirf(multi_wavelength_snirf_fname, preload=True) - assert "fnirs_cw_amplitude" in raw - raw = optical_density(raw) - assert "fnirs_od" in raw - raw = beer_lambert_law(raw, ppf=6) - assert "hbo" in raw - assert "hbr" in raw - - @requires_testing_data @pytest.mark.filterwarnings("ignore:.*measurement date.*:") def test_snirf_gowerlabs(): @@ -641,14 +589,14 @@ def test_sample_rate_jitter(tmp_path): read_raw_snirf(new_file, verbose=True) -def test_snirf_multiple_wavelengths(multi_wavelength_snirf_fname): +def test_snirf_multiple_wavelengths(): """Test importing synthetic SNIRF files with >=3 wavelengths.""" - raw = read_raw_snirf(multi_wavelength_snirf_fname, preload=True) - assert raw._data.shape == (6, 32) - assert raw.info["sfreq"] == pytest.approx(1.0) - assert raw.info["ch_names"][:3] == ["S1_D1 700", "S1_D1 730", "S1_D1 760"] - assert len(raw.ch_names) == 6 + raw = read_raw_snirf(labnirs_multi_wavelength, preload=True) + assert raw._data.shape == (45, 251) + assert raw.info["sfreq"] == pytest.approx(19.6, abs=0.01) + assert raw.info["ch_names"][:3] == ["S2_D2 780", "S2_D2 805", "S2_D2 830"] + assert len(raw.ch_names) == 45 freqs = np.unique(_channel_frequencies(raw.info)) - assert_array_equal(freqs, [700, 730, 760]) + assert_array_equal(freqs, [780, 805, 830]) distances = source_detector_distances(raw.info) assert len(distances) == len(raw.ch_names) From b2336fa7f947f5918d29b3def36efb1a309e0339 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Dec 2025 04:08:49 +0000 Subject: [PATCH 24/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/io/snirf/tests/test_snirf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mne/io/snirf/tests/test_snirf.py b/mne/io/snirf/tests/test_snirf.py index 40fdf71c6f9..a396e5194a9 100644 --- a/mne/io/snirf/tests/test_snirf.py +++ b/mne/io/snirf/tests/test_snirf.py @@ -5,7 +5,6 @@ import datetime import shutil from contextlib import nullcontext -from pathlib import Path import numpy as np import pytest From 77157767a0efb95b7d7adc70f1d8b43843bab691 Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Tue, 30 Dec 2025 14:23:47 +0800 Subject: [PATCH 25/30] Fix typo in test_snirf.py, labnirs data has 250 points not 251 --- mne/io/snirf/tests/test_snirf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/io/snirf/tests/test_snirf.py b/mne/io/snirf/tests/test_snirf.py index a396e5194a9..c8f93180296 100644 --- a/mne/io/snirf/tests/test_snirf.py +++ b/mne/io/snirf/tests/test_snirf.py @@ -591,7 +591,7 @@ def test_sample_rate_jitter(tmp_path): def test_snirf_multiple_wavelengths(): """Test importing synthetic SNIRF files with >=3 wavelengths.""" raw = read_raw_snirf(labnirs_multi_wavelength, preload=True) - assert raw._data.shape == (45, 251) + assert raw._data.shape == (45, 250) assert raw.info["sfreq"] == pytest.approx(19.6, abs=0.01) assert raw.info["ch_names"][:3] == ["S2_D2 780", "S2_D2 805", "S2_D2 830"] assert len(raw.ch_names) == 45 From 088fb633e7f477d39d5b6c08e02714421239e795 Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Fri, 9 Jan 2026 03:23:28 +0900 Subject: [PATCH 26/30] bumping test data version to 0.171 with MD5 --- mne/datasets/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne/datasets/config.py b/mne/datasets/config.py index ca65910dda6..8625e9ee2bd 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -87,7 +87,7 @@ # update the checksum in the MNE_DATASETS dict below, and change version # here: ↓↓↓↓↓↓↓↓ RELEASES = dict( - testing="0.170", + testing="0.171", misc="0.27", phantom_kit="0.2", ucl_opm_auditory="0.2", @@ -115,7 +115,7 @@ # Testing and misc are at the top as they're updated most often MNE_DATASETS["testing"] = dict( archive_name=f"{TESTING_VERSIONED}.tar.gz", - hash="md5:ebd873ea89507cf5a75043f56119d22b", + hash="md5:138caf29bd8a9b0a6b6ea43d92c16201", url=( "https://codeload.github.com/mne-tools/mne-testing-data/" f"tar.gz/{RELEASES['testing']}" From ca4c1a925c38aa83b6c95bcf96bf757cf9f2cb83 Mon Sep 17 00:00:00 2001 From: Tamas Fehervari <58502181+zEdS15B3GCwq@users.noreply.github.com> Date: Fri, 9 Jan 2026 04:26:19 +0900 Subject: [PATCH 27/30] fix typo in test_beer_lambert.py --- mne/preprocessing/nirs/tests/test_beer_lambert_law.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/nirs/tests/test_beer_lambert_law.py b/mne/preprocessing/nirs/tests/test_beer_lambert_law.py index aa793ff3bbe..c889237bae9 100644 --- a/mne/preprocessing/nirs/tests/test_beer_lambert_law.py +++ b/mne/preprocessing/nirs/tests/test_beer_lambert_law.py @@ -54,7 +54,7 @@ def test_beer_lambert(fname, fmt, tmp_path): raw_volt = read_raw_snirf(fname) case _: raise ValueError( - f"fmt expexted to be one of 'nirx', 'fif' or 'snirf', got {fmt}" + f"fmt expected to be one of 'nirx', 'fif' or 'snirf', got {fmt}" ) raw_od = optical_density(raw_volt) From a34d40acd5677319451f4930324cc9e3bc12990f Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 9 Jan 2026 12:01:58 -0500 Subject: [PATCH 28/30] Update 13408.newfeature.rst --- doc/changes/dev/13408.newfeature.rst | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/doc/changes/dev/13408.newfeature.rst b/doc/changes/dev/13408.newfeature.rst index b53a08ca6ce..c43ebb1144f 100644 --- a/doc/changes/dev/13408.newfeature.rst +++ b/doc/changes/dev/13408.newfeature.rst @@ -1,2 +1 @@ -Add support for multi-wavelength NIRS processing to :func:`mne.preprocessing.nirs.beer_lambert_law`, :func:`mne.preprocessing.nirs.scalp_coupling_index`, and SNIRF reader :class:`mne.io.snirf._snirf.RawSNIRF`, by :newcontrib:`Tamas Fehervari`. -` +Add support for multi-wavelength NIRS processing to :func:`mne.preprocessing.nirs.beer_lambert_law`, :func:`mne.preprocessing.nirs.scalp_coupling_index`, and SNIRF reader :func:`mne.io.read_raw_snirf`, by :newcontrib:`Tamas Fehervari`. From 828cd64929cac5ce11f952ba0e10311ed5994004 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 9 Jan 2026 12:05:07 -0500 Subject: [PATCH 29/30] Update mne/preprocessing/nirs/tests/test_scalp_coupling_index.py --- mne/preprocessing/nirs/tests/test_scalp_coupling_index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py b/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py index 7035fa5376f..2b4c776c92a 100644 --- a/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py +++ b/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py @@ -92,7 +92,7 @@ def test_scalp_coupling_index_multi_wavelength(): times = np.arange(raw.n_times) / raw.info["sfreq"] signal = np.sin(2 * np.pi * 1.0 * times) + 1 assert len(raw.ch_names) >= 15 * 3 - rng = np.random.default_rng() + rng = np.random.default_rng(3289745) # pre-determined expected results expected = [] From c115fd06ee5a8cab1fe56c72d8b1dfa07400839a Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 9 Jan 2026 12:23:45 -0500 Subject: [PATCH 30/30] FIX: Paths --- mne/io/snirf/tests/test_snirf.py | 1 + .../nirs/tests/test_scalp_coupling_index.py | 13 +++++-------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/mne/io/snirf/tests/test_snirf.py b/mne/io/snirf/tests/test_snirf.py index c8f93180296..24f6f1174c7 100644 --- a/mne/io/snirf/tests/test_snirf.py +++ b/mne/io/snirf/tests/test_snirf.py @@ -588,6 +588,7 @@ def test_sample_rate_jitter(tmp_path): read_raw_snirf(new_file, verbose=True) +@requires_testing_data def test_snirf_multiple_wavelengths(): """Test importing synthetic SNIRF files with >=3 wavelengths.""" raw = read_raw_snirf(labnirs_multi_wavelength, preload=True) diff --git a/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py b/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py index 2b4c776c92a..832a1158486 100644 --- a/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py +++ b/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py @@ -15,17 +15,14 @@ scalp_coupling_index, ) -fname_nirx_15_0 = ( - data_path(download=False) / "NIRx" / "nirscout" / "nirx_15_0_recording" -) -fname_nirx_15_2 = ( - data_path(download=False) / "NIRx" / "nirscout" / "nirx_15_2_recording" -) +testing_path = data_path(download=False) +fname_nirx_15_0 = testing_path / "NIRx" / "nirscout" / "nirx_15_0_recording" +fname_nirx_15_2 = testing_path / "NIRx" / "nirscout" / "nirx_15_2_recording" fname_nirx_15_2_short = ( - data_path(download=False) / "NIRx" / "nirscout" / "nirx_15_2_recording_w_short" + testing_path / "NIRx" / "nirscout" / "nirx_15_2_recording_w_short" ) fname_labnirs_multi_wavelength = ( - data_path(download=False) / "SNIRF" / "Labnirs" / "labnirs_3wl_raw_recording.snirf" + testing_path / "SNIRF" / "Labnirs" / "labnirs_3wl_raw_recording.snirf" )