diff --git a/doc/changes/dev/13408.newfeature.rst b/doc/changes/dev/13408.newfeature.rst new file mode 100644 index 00000000000..c43ebb1144f --- /dev/null +++ b/doc/changes/dev/13408.newfeature.rst @@ -0,0 +1 @@ +Add support for multi-wavelength NIRS processing to :func:`mne.preprocessing.nirs.beer_lambert_law`, :func:`mne.preprocessing.nirs.scalp_coupling_index`, and SNIRF reader :func:`mne.io.read_raw_snirf`, by :newcontrib:`Tamas Fehervari`. diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 1a9c62b9bea..5a5a1e98f89 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -320,6 +320,7 @@ .. _Sébastien Marti: https://www.researchgate.net/profile/Sebastien-Marti .. _T. Wang: https://github.com/twang5 .. _Tal Linzen: https://tallinzen.net/ +.. _Tamas Fehervari: https://github.com/zEdS15B3GCwq .. _Teon Brooks: https://github.com/teonbrooks .. _Tharupahan Jayawardana: https://github.com/tharu-jwd .. _Thomas Binns: https://github.com/tsbinns diff --git a/mne/datasets/config.py b/mne/datasets/config.py index ca65910dda6..8625e9ee2bd 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -87,7 +87,7 @@ # update the checksum in the MNE_DATASETS dict below, and change version # here: ↓↓↓↓↓↓↓↓ RELEASES = dict( - testing="0.170", + testing="0.171", misc="0.27", phantom_kit="0.2", ucl_opm_auditory="0.2", @@ -115,7 +115,7 @@ # Testing and misc are at the top as they're updated most often MNE_DATASETS["testing"] = dict( archive_name=f"{TESTING_VERSIONED}.tar.gz", - hash="md5:ebd873ea89507cf5a75043f56119d22b", + hash="md5:138caf29bd8a9b0a6b6ea43d92c16201", url=( "https://codeload.github.com/mne-tools/mne-testing-data/" f"tar.gz/{RELEASES['testing']}" diff --git a/mne/io/snirf/_snirf.py b/mne/io/snirf/_snirf.py index 55f3c54605c..cfb9c8f71e2 100644 --- a/mne/io/snirf/_snirf.py +++ b/mne/io/snirf/_snirf.py @@ -148,14 +148,13 @@ def __init__( # Extract wavelengths fnirs_wavelengths = np.array(dat.get("nirs/probe/wavelengths")) fnirs_wavelengths = [int(w) for w in fnirs_wavelengths] - if len(fnirs_wavelengths) != 2: + if len(fnirs_wavelengths) < 2: raise RuntimeError( f"The data contains " f"{len(fnirs_wavelengths)}" f" wavelengths: {fnirs_wavelengths}. " - f"MNE only supports reading continuous" - " wave amplitude SNIRF files " - "with two wavelengths." + f"MNE requires at least two wavelengths for " + "continuous wave amplitude SNIRF files." ) # Extract channels diff --git a/mne/io/snirf/tests/test_snirf.py b/mne/io/snirf/tests/test_snirf.py index 73e3c775ed1..24f6f1174c7 100644 --- a/mne/io/snirf/tests/test_snirf.py +++ b/mne/io/snirf/tests/test_snirf.py @@ -8,13 +8,19 @@ import numpy as np import pytest -from numpy.testing import assert_allclose, assert_almost_equal, assert_equal +from numpy.testing import ( + assert_allclose, + assert_almost_equal, + assert_array_equal, + assert_equal, +) from mne._fiff.constants import FIFF from mne.datasets.testing import data_path, requires_testing_data from mne.io import read_raw_nirx, read_raw_snirf from mne.io.tests.test_raw import _test_raw_reader from mne.preprocessing.nirs import ( + _channel_frequencies, _reorder_nirx, beer_lambert_law, optical_density, @@ -68,6 +74,11 @@ # GowerLabs lumo110 = testing_path / "SNIRF" / "GowerLabs" / "lumomat-1-1-0.snirf" +# Shimadzu Labnirs 3-wavelength converted to snirf using custom tool +labnirs_multi_wavelength = ( + testing_path / "SNIRF" / "Labnirs" / "labnirs_3wl_raw_recording.snirf" +) + def _get_loc(raw, ch_name): return raw.copy().pick(ch_name).info["chs"][0]["loc"] @@ -88,6 +99,7 @@ def _get_loc(raw, ch_name): nirx_nirsport2_103_2, kernel_hb, lumo110, + labnirs_multi_wavelength, ] ), ) @@ -574,3 +586,17 @@ def test_sample_rate_jitter(tmp_path): f.create_dataset("nirs/data1/time", data=unacceptable_time_jitter) with pytest.warns(RuntimeWarning, match="non-uniformly-sampled data"): read_raw_snirf(new_file, verbose=True) + + +@requires_testing_data +def test_snirf_multiple_wavelengths(): + """Test importing synthetic SNIRF files with >=3 wavelengths.""" + raw = read_raw_snirf(labnirs_multi_wavelength, preload=True) + assert raw._data.shape == (45, 250) + assert raw.info["sfreq"] == pytest.approx(19.6, abs=0.01) + assert raw.info["ch_names"][:3] == ["S2_D2 780", "S2_D2 805", "S2_D2 830"] + assert len(raw.ch_names) == 45 + freqs = np.unique(_channel_frequencies(raw.info)) + assert_array_equal(freqs, [780, 805, 830]) + distances = source_detector_distances(raw.info) + assert len(distances) == len(raw.ch_names) diff --git a/mne/preprocessing/nirs/_beer_lambert_law.py b/mne/preprocessing/nirs/_beer_lambert_law.py index c17cf31110c..0c1482ef3b8 100644 --- a/mne/preprocessing/nirs/_beer_lambert_law.py +++ b/mne/preprocessing/nirs/_beer_lambert_law.py @@ -11,7 +11,7 @@ from ..._fiff.constants import FIFF from ...io import BaseRaw from ...utils import _validate_type, pinv, warn -from ..nirs import _validate_nirs_info, source_detector_distances +from ..nirs import _channel_frequencies, _validate_nirs_info, source_detector_distances def beer_lambert_law(raw, ppf=6.0): @@ -36,23 +36,46 @@ def beer_lambert_law(raw, ppf=6.0): _validate_type(raw, BaseRaw, "raw") _validate_type(ppf, ("numeric", "array-like"), "ppf") ppf = np.array(ppf, float) - if ppf.ndim == 0: # upcast single float to shape (2,) - ppf = np.array([ppf, ppf]) - if ppf.shape != (2,): + picks = _validate_nirs_info(raw.info, fnirs="od", which="Beer-lambert") + + # Use nominal channel frequencies + # + # Notes on implementation: + # 1. Frequencies are calculated the same way as in nirs._validate_nirs_info(). + # 2. Wavelength values in the info structure may contain actual frequencies, + # which may be used for more accurate calculation in the future. + # 3. nirs._channel_frequencies uses both cw_amplitude and OD data to determine + # frequencies, whereas we only need those from OD here. Is there any chance + # that they're different? + # 4. If actual frequencies were used, using np.unique() like below will lead to + # errors. Instead, absorption coefficients will need to be calculated for + # each individual frequency. + freqs = _channel_frequencies(raw.info) + + # Get unique wavelengths and determine number of wavelengths + unique_freqs = np.unique(freqs) + n_wavelengths = len(unique_freqs) + + # PPF validation for multiple wavelengths + if ppf.ndim == 0: # single float + # same PPF for all wavelengths, shape (n_wavelengths, 1) + ppf = np.full((n_wavelengths, 1), ppf) + elif ppf.ndim == 1 and len(ppf) == n_wavelengths: + # separate ppf for each wavelength + ppf = ppf[:, np.newaxis] # shape (n_wavelengths, 1) + else: raise ValueError( - f"ppf must be float or array-like of shape (2,), got shape {ppf.shape}" + f"ppf must be a single float or an array-like of length {n_wavelengths} " + f"(number of wavelengths), got shape {ppf.shape}" ) - ppf = ppf[:, np.newaxis] # shape (2, 1) - picks = _validate_nirs_info(raw.info, fnirs="od", which="Beer-lambert") - # This is the one place we *really* need the actual/accurate frequencies - freqs = np.array([raw.info["chs"][pick]["loc"][9] for pick in picks], float) - abs_coef = _load_absorption(freqs) + + abs_coef = _load_absorption(unique_freqs) # shape (n_wavelengths, 2) distances = source_detector_distances(raw.info, picks="all") bad = ~np.isfinite(distances[picks]) bad |= distances[picks] <= 0 if bad.any(): warn( - "Source-detector distances are zero on NaN, some resulting " + "Source-detector distances are zero or NaN, some resulting " "concentrations will be zero. Consider setting a montage " "with raw.set_montage." ) @@ -64,20 +87,41 @@ def beer_lambert_law(raw, ppf=6.0): "likely due to optode locations being stored in a " " unit other than meters." ) + rename = dict() - for ii, jj in zip(picks[::2], picks[1::2]): - EL = abs_coef * distances[ii] * ppf + channels_to_drop_all = [] # Accumulate all channels to drop + + # Iterate over channel groups ([Si_Di all wavelengths, Sj_Dj all wavelengths, ...]) + for ii in range(0, len(picks), n_wavelengths): + group_picks = picks[ii : ii + n_wavelengths] + # Calculate Δc based on the system: ΔOD = E * L * PPF * Δc + # where E is (n_wavelengths, 2), Δc is (2, n_timepoints) + # using pseudo-inverse + EL = abs_coef * distances[group_picks[0]] * ppf iEL = pinv(EL) + conc_data = iEL @ raw._data[group_picks] * 1e-3 - raw._data[[ii, jj]] = iEL @ raw._data[[ii, jj]] * 1e-3 + # Replace the first two channels with HbO and HbR + raw._data[group_picks[:2]] = conc_data[:2] # HbO, HbR # Update channel information coil_dict = dict(hbo=FIFF.FIFFV_COIL_FNIRS_HBO, hbr=FIFF.FIFFV_COIL_FNIRS_HBR) - for ki, kind in zip((ii, jj), ("hbo", "hbr")): + for ki, kind in zip(group_picks[:2], ("hbo", "hbr")): ch = raw.info["chs"][ki] ch.update(coil_type=coil_dict[kind], unit=FIFF.FIFF_UNIT_MOL) new_name = f"{ch['ch_name'].split(' ')[0]} {kind}" rename[ch["ch_name"]] = new_name + + # Accumulate extra wavelength channels to drop (keep only HbO and HbR) + if n_wavelengths > 2: + channels_to_drop = group_picks[2:] + channel_names_to_drop = [raw.ch_names[idx] for idx in channels_to_drop] + channels_to_drop_all.extend(channel_names_to_drop) + + # Drop all accumulated extra wavelength channels after processing all groups + if channels_to_drop_all: + raw.drop_channels(channels_to_drop_all) + raw.rename_channels(rename) # Validate the format of data after transformation is valid @@ -95,7 +139,9 @@ def _load_absorption(freqs): # save('extinction_coef.mat', 'extinct_coef') # # Returns data as [[HbO2(freq1), Hb(freq1)], - # [HbO2(freq2), Hb(freq2)]] + # [HbO2(freq2), Hb(freq2)], + # ..., + # [HbO2(freqN), Hb(freqN)]] extinction_fname = op.join( op.dirname(__file__), "..", "..", "data", "extinction_coef.mat" ) @@ -104,12 +150,12 @@ def _load_absorption(freqs): interp_hbo = interp1d(a[:, 0], a[:, 1], kind="linear") interp_hb = interp1d(a[:, 0], a[:, 2], kind="linear") - ext_coef = np.array( - [ - [interp_hbo(freqs[0]), interp_hb(freqs[0])], - [interp_hbo(freqs[1]), interp_hb(freqs[1])], - ] - ) - abs_coef = ext_coef * 0.2303 + # Build coefficient matrix for all wavelengths + # Shape: (n_wavelengths, 2) where columns are [HbO2, Hb] + ext_coef = np.zeros((len(freqs), 2)) + for i, freq in enumerate(freqs): + ext_coef[i, 0] = interp_hbo(freq) # HbO2 + ext_coef[i, 1] = interp_hb(freq) # Hb + abs_coef = ext_coef * 0.2303 return abs_coef diff --git a/mne/preprocessing/nirs/_scalp_coupling_index.py b/mne/preprocessing/nirs/_scalp_coupling_index.py index 5a82664dfd1..062af0be8aa 100644 --- a/mne/preprocessing/nirs/_scalp_coupling_index.py +++ b/mne/preprocessing/nirs/_scalp_coupling_index.py @@ -6,7 +6,7 @@ from ...io import BaseRaw from ...utils import _validate_type, verbose -from ..nirs import _validate_nirs_info +from ..nirs import _channel_frequencies, _validate_nirs_info @verbose @@ -56,14 +56,34 @@ def scalp_coupling_index( verbose=verbose, ).get_data() + # Determine number of wavelengths per source-detector pair + # We use nominal wavelengths as the info structure may contain arbitrary data. + freqs = _channel_frequencies(raw.info) + n_wavelengths = len(np.unique(freqs)) + sci = np.zeros(picks.shape) - for ii in range(0, len(picks), 2): - with np.errstate(invalid="ignore"): - c = np.corrcoef(filtered_data[ii], filtered_data[ii + 1])[0][1] - if not np.isfinite(c): # someone had std=0 - c = 0 - sci[ii] = c - sci[ii + 1] = c + + # Calculate all pairwise correlations within each group and use the minimum as SCI + pair_indices = np.triu_indices(n_wavelengths, k=1) + + for gg in range(0, len(picks), n_wavelengths): + group_data = filtered_data[gg : gg + n_wavelengths] + + # Calculate pairwise correlations within the group + correlations = np.zeros(pair_indices[0].shape[0]) + + for n, (ii, jj) in enumerate(zip(*pair_indices)): + with np.errstate(invalid="ignore"): + c = np.corrcoef(group_data[ii], group_data[jj])[0][1] + if np.isfinite(c): + correlations[n] = c + + # Use minimum correlation as SCI + group_sci = correlations.min() + + # Assign the same SCI value to all channels in the group + sci[gg : gg + n_wavelengths] = group_sci + sci[zero_mask] = 0 sci = sci[np.argsort(picks)] # restore original order return sci diff --git a/mne/preprocessing/nirs/nirs.py b/mne/preprocessing/nirs/nirs.py index 94c7c78468c..49827fd1df5 100644 --- a/mne/preprocessing/nirs/nirs.py +++ b/mne/preprocessing/nirs/nirs.py @@ -104,7 +104,7 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr # All chromophore fNIRS data picks_chroma = _picks_to_idx(info, ["hbo", "hbr"], exclude=[], allow_empty=True) - if (len(picks_wave) > 0) & (len(picks_chroma) > 0): + if (len(picks_wave) > 0) and (len(picks_chroma) > 0): picks = _throw_or_return_empty( "MNE does not support a combination of amplitude, optical " "density, and haemoglobin data in the same raw structure.", @@ -122,19 +122,18 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr picks = picks_chroma pair_vals = np.array(pair_vals) - if pair_vals.shape != (2,): + if pair_vals.shape[0] < 2: raise ValueError( - f"Exactly two {error_word} must exist in info, got {list(pair_vals)}" + f"At least two {error_word} must exist in info, got {list(pair_vals)}" ) # In principle we do not need to require that these be sorted -- # all we need to do is change our sorted() below to make use of a # pair_vals.index(...) in a sort key -- but in practice we always want - # (hbo, hbr) or (lower_freq, upper_freq) pairings, both of which will + # (hbo, hbr) or (lowest_freq, higher_freq, ...) pairings, both of which will # work with a naive string sort, so let's just enforce sorted-ness here is_str = pair_vals.dtype.kind == "U" - pair_vals = list(pair_vals) if is_str: - if pair_vals != ["hbo", "hbr"]: + if pair_vals.tolist() != ["hbo", "hbr"]: raise ValueError( f'The {error_word} in info must be ["hbo", "hbr"], but got ' f"{pair_vals} instead" @@ -145,22 +144,28 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr f"got {pair_vals} instead" ) - if len(picks) % 2 != 0: + # Check that the total number of channels is divisible by the number of pair values + # (e.g., for 2 wavelengths, we need an even number of channels) + if len(picks) % len(pair_vals) != 0: picks = _throw_or_return_empty( - "NIRS channels not ordered correctly. An even number of NIRS " - f"channels is required. {len(info.ch_names)} channels were" - f"provided", + "NIRS channels not ordered correctly. The number of channels " + f"must be a multiple of {len(pair_vals)} values, but " + f"{len(picks)} channels were provided.", throw_errors, ) # Ensure wavelength info exists for waveform data - all_freqs = [info["chs"][ii]["loc"][9] for ii in picks_wave] - if np.any(np.isnan(all_freqs)): - picks = _throw_or_return_empty( - f"NIRS channels is missing wavelength information in the " - f'info["chs"] structure. The encoded wavelengths are {all_freqs}.', - throw_errors, - ) + # Note: currently, the only requirement for the wavelength field in info is + # that it cannot be NaN. It depends on the data readers what is stored in it. + if len(picks_wave) > 0: + all_freqs = [info["chs"][ii]["loc"][9] for ii in picks_wave] + # test for nan values first as those mess up the output of set() + if np.any(np.isnan(all_freqs)): + picks = _throw_or_return_empty( + f"NIRS channels is missing wavelength information in the " + f'info["chs"] structure. The encoded wavelengths are {all_freqs}.', + throw_errors, + ) # Validate the channel naming scheme for pick in picks: @@ -174,8 +179,8 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr ) break value = ch_name_info.groups()[2] - if len(picks_wave): - value = value + if len(picks_wave) > 0: + pass else: # picks_chroma if value not in ["hbo", "hbr"]: picks = _throw_or_return_empty( @@ -189,40 +194,51 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr # Reorder to be paired (naive sort okay here given validation above) picks = picks[np.argsort([info["ch_names"][pick] for pick in picks])] - # Validate our paired ordering - for ii, jj in zip(picks[::2], picks[1::2]): - ch1_name = info["chs"][ii]["ch_name"] - ch2_name = info["chs"][jj]["ch_name"] - ch1_re = use_RE.match(ch1_name) - ch2_re = use_RE.match(ch2_name) - ch1_S, ch1_D, ch1_value = ch1_re.groups()[:3] - ch2_S, ch2_D, ch2_value = ch2_re.groups()[:3] - if len(picks_wave): - ch1_value, ch2_value = float(ch1_value), float(ch2_value) - if ( - (ch1_S != ch2_S) - or (ch1_D != ch2_D) - or (ch1_value != pair_vals[0]) - or (ch2_value != pair_vals[1]) + # Validate channel grouping (same source-detector pairs, all pair_vals match) + for ii in range(0, len(picks), len(pair_vals)): + # Extract a group of channels (e.g., all wavelengths for one S-D pair) + group_picks = picks[ii : ii + len(pair_vals)] + + # Parse channel names using regex to extract source, detector, and value info + group_info = [ + (use_RE.match(info["ch_names"][pick]).groups() or (pick, 0, 0)) + for pick in group_picks + ] + + # Separate the parsed components: + # source IDs, detector IDs, and values (freq/chromophore) + s_group, d_group, val_group = zip(*group_info) + + # For wavelength data, convert string frequencies to float for comparison + if len(picks_wave) > 0: + val_group = np.array([float(v) for v in val_group]) + + # Verify that all channels in this group have the same source-detector pair + # and that the values match the expected pair_vals sequence + if not ( + len(set(s_group)) == 1 + and len(set(d_group)) == 1 + and np.array_equal(val_group, pair_vals) ): picks = _throw_or_return_empty( "NIRS channels not ordered correctly. Channels must be " - "ordered as source detector pairs with alternating" - f" {error_word} {pair_vals[0]} & {pair_vals[1]}, but got " - f"S{ch1_S}_D{ch1_D} pair " - f"{repr(ch1_name)} and {repr(ch2_name)}", + f"grouped by source-detector pairs with alternating {error_word} " + f"values {pair_vals}, but got mismatching names " + f"{[info['ch_names'][pick] for pick in group_picks]}.", throw_errors, ) break if check_bads: - for ii, jj in zip(picks[::2], picks[1::2]): - want = [info.ch_names[ii], info.ch_names[jj]] + for ii in range(0, len(picks), len(pair_vals)): + group_picks = picks[ii : ii + len(pair_vals)] + + want = [info.ch_names[pick] for pick in group_picks] got = list(set(info["bads"]).intersection(want)) - if len(got) == 1: + if 0 < len(got) < len(want): raise RuntimeError( - f"NIRS bad labelling is not consistent, found {got} but " - f"needed {want}" + "NIRS bad labelling is not consistent. " + f"Found {got} but needed {want}. " ) return picks @@ -276,14 +292,29 @@ def _fnirs_spread_bads(info): # as bad and spread the bad marking to all components of the optode pair. picks = _validate_nirs_info(info, check_bads=False) new_bads = set(info["bads"]) - for ii, jj in zip(picks[::2], picks[1::2]): - ch1_name, ch2_name = info.ch_names[ii], info.ch_names[jj] - if ch1_name in new_bads: - new_bads.add(ch2_name) - elif ch2_name in new_bads: - new_bads.add(ch1_name) - info["bads"] = sorted(new_bads) + # Extract SD pair groups from channel names + # E.g. all channels belonging to S1D1, S1D2, etc. + # Assumes valid channels (naming convention and number) + ch_names = [info.ch_names[i] for i in picks] + match = re.compile(r"^(S\d+_D\d+) ") + + # Create dict with keys corresponding to SD pairs + # Defaultdict would require another import + sd_groups = {} + for ch_name in ch_names: + sd_pair = match.match(ch_name).group(1) + if sd_pair not in sd_groups: + sd_groups[sd_pair] = [ch_name] + else: + sd_groups[sd_pair].append(ch_name) + + # Spread bad labeling across SD pairs + for channels in sd_groups.values(): + if any(channel in new_bads for channel in channels): + new_bads.update(channels) + + info["bads"] = sorted(new_bads) return info diff --git a/mne/preprocessing/nirs/tests/test_beer_lambert_law.py b/mne/preprocessing/nirs/tests/test_beer_lambert_law.py index 5768ff038ab..c889237bae9 100644 --- a/mne/preprocessing/nirs/tests/test_beer_lambert_law.py +++ b/mne/preprocessing/nirs/tests/test_beer_lambert_law.py @@ -7,8 +7,12 @@ from mne.datasets import testing from mne.datasets.testing import data_path -from mne.io import BaseRaw, read_raw_fif, read_raw_nirx -from mne.preprocessing.nirs import beer_lambert_law, optical_density +from mne.io import BaseRaw, read_raw_fif, read_raw_nirx, read_raw_snirf +from mne.preprocessing.nirs import ( + _channel_frequencies, + beer_lambert_law, + optical_density, +) from mne.utils import _validate_type testing_path = data_path(download=False) @@ -17,56 +21,67 @@ fname_nirx_15_2_short = ( testing_path / "NIRx" / "nirscout" / "nirx_15_2_recording_w_short" ) +fname_labnirs_multi_wavelength = ( + testing_path / "SNIRF" / "Labnirs" / "labnirs_3wl_raw_recording.snirf" +) @testing.requires_testing_data @pytest.mark.parametrize( - "fname", ([fname_nirx_15_2_short, fname_nirx_15_2, fname_nirx_15_0]) + "fname,fmt", + ( + [ + (fname_nirx_15_2_short, "nirx"), + (fname_nirx_15_2_short, "fif"), + (fname_nirx_15_2, "nirx"), + (fname_nirx_15_2, "fif"), + (fname_nirx_15_0, "nirx"), + (fname_nirx_15_0, "fif"), + (fname_labnirs_multi_wavelength, "snirf"), + ] + ), ) -@pytest.mark.parametrize("fmt", ("nirx", "fif")) def test_beer_lambert(fname, fmt, tmp_path): - """Test converting NIRX files.""" - assert fmt in ("nirx", "fif") - raw = read_raw_nirx(fname) - if fmt == "fif": - raw.save(tmp_path / "test_raw.fif") - raw = read_raw_fif(tmp_path / "test_raw.fif") - assert "fnirs_cw_amplitude" in raw - assert "fnirs_od" not in raw - raw = optical_density(raw) - _validate_type(raw, BaseRaw, "raw") - assert "fnirs_cw_amplitude" not in raw - assert "fnirs_od" in raw - assert "hbo" not in raw - raw = beer_lambert_law(raw) - _validate_type(raw, BaseRaw, "raw") - assert "fnirs_cw_amplitude" not in raw - assert "fnirs_od" not in raw - assert "hbo" in raw - assert "hbr" in raw + """Test converting raw CW amplitude files.""" + match fmt: + case "nirx": + raw_volt = read_raw_nirx(fname) + case "fif": + raw_nirx = read_raw_nirx(fname) + raw_nirx.save(tmp_path / "test_raw.fif") + raw_volt = read_raw_fif(tmp_path / "test_raw.fif") + case "snirf": + raw_volt = read_raw_snirf(fname) + case _: + raise ValueError( + f"fmt expected to be one of 'nirx', 'fif' or 'snirf', got {fmt}" + ) + raw_od = optical_density(raw_volt) + _validate_type(raw_od, BaseRaw, "raw") -@testing.requires_testing_data -def test_beer_lambert_unordered_errors(): - """NIRS data requires specific ordering and naming of channels.""" - raw = read_raw_nirx(fname_nirx_15_0) - raw_od = optical_density(raw) - raw_od.pick([0, 1, 2]) - with pytest.raises(ValueError, match="ordered"): - beer_lambert_law(raw_od) + raw_hb = beer_lambert_law(raw_od) + _validate_type(raw_hb, BaseRaw, "raw") + + # Verify channel numbers (multi-wavelength aware) + # Raw voltage has: optode pairs * number of wavelengths + # OD must have the same number as raw voltage + # Hb data must have: number of optode pairs * 2 + nfreqs = len(set(_channel_frequencies(raw_volt.info))) + assert len(raw_volt.ch_names) % nfreqs == 0 + npairs = len(raw_volt.ch_names) // nfreqs + assert len(raw_hb.ch_names) % npairs == 0 + assert len(raw_hb.ch_names) // npairs == 2.0 + + # Verify data types + assert set(raw_volt.get_channel_types()) == {"fnirs_cw_amplitude"} + assert set(raw_hb.get_channel_types()) == {"hbo", "hbr"} - # Test that an error is thrown if channel naming frequency doesn't match - # what is stored in loc[9], which should hold the light frequency too. - raw_od = optical_density(raw) - ch_name = raw.ch_names[0] - assert ch_name == "S1_D1 760" - idx = raw_od.ch_names.index(ch_name) - assert idx == 0 - raw_od.info["chs"][idx]["loc"][9] = 770 - raw_od.rename_channels({ch_name: ch_name.replace("760", "770")}) - assert raw_od.ch_names[0] == "S1_D1 770" - with pytest.raises(ValueError, match="Exactly two frequencies"): - beer_lambert_law(raw_od) + # Verify that pair ordering did not change just channel name suffixes + old_prefixes = [name.split(" ")[0] for name in raw_volt.ch_names[::nfreqs]] + new_prefixes = [name.split(" ")[0] for name in raw_hb.ch_names[::2]] + assert old_prefixes == new_prefixes + assert all([name.split(" ")[1] in {"hbo", "hbr"} for name in raw_hb.ch_names]) @testing.requires_testing_data diff --git a/mne/preprocessing/nirs/tests/test_nirs.py b/mne/preprocessing/nirs/tests/test_nirs.py index 89fa17c0c8d..069657e2501 100644 --- a/mne/preprocessing/nirs/tests/test_nirs.py +++ b/mne/preprocessing/nirs/tests/test_nirs.py @@ -11,7 +11,7 @@ from mne._fiff.pick import _picks_to_idx from mne.datasets import testing from mne.datasets.testing import data_path -from mne.io import RawArray, read_raw_nirx +from mne.io import RawArray, read_raw_nirx, read_raw_snirf from mne.preprocessing.nirs import ( _channel_chromophore, _channel_frequencies, @@ -35,12 +35,22 @@ fname_nirx_15_2_short = ( data_path(download=False) / "NIRx" / "nirscout" / "nirx_15_2_recording_w_short" ) +fname_labnirs_multi_wavelength = ( + data_path(download=False) / "SNIRF" / "Labnirs" / "labnirs_3wl_raw_recording.snirf" +) @testing.requires_testing_data -def test_fnirs_picks(): +@pytest.mark.parametrize( + "fname, readerfn", + [ + (fname_nirx_15_0, read_raw_nirx), + (fname_labnirs_multi_wavelength, read_raw_snirf), + ], +) +def test_fnirs_picks(fname, readerfn): """Test picking of fnirs types after different conversions.""" - raw = read_raw_nirx(fname_nirx_15_0) + raw = readerfn(fname) picks = _picks_to_idx(raw.info, "fnirs_cw_amplitude") assert len(picks) == len(raw.ch_names) raw_subset = raw.copy().pick(picks="fnirs_cw_amplitude") @@ -106,12 +116,18 @@ def _fnirs_check_bads(info): @testing.requires_testing_data @pytest.mark.parametrize( - "fname", ([fname_nirx_15_2_short, fname_nirx_15_2, fname_nirx_15_0]) + "fname, readerfn", + [ + (fname_nirx_15_0, read_raw_nirx), + (fname_nirx_15_2_short, read_raw_nirx), + (fname_nirx_15_2, read_raw_nirx), + (fname_labnirs_multi_wavelength, read_raw_snirf), + ], ) -def test_fnirs_check_bads(fname): +def test_fnirs_check_bads(fname, readerfn): """Test checking of bad markings.""" # No bad channels, so these should all pass - raw = read_raw_nirx(fname) + raw = readerfn(fname) _fnirs_check_bads(raw.info) raw = optical_density(raw) _fnirs_check_bads(raw.info) @@ -119,8 +135,9 @@ def test_fnirs_check_bads(fname): _fnirs_check_bads(raw.info) # Mark pairs of bad channels, so these should all pass - raw = read_raw_nirx(fname) - raw.info["bads"] = raw.ch_names[0:2] + raw = readerfn(fname) + nfreqs = len(set(_channel_frequencies(raw.info))) + raw.info["bads"] = raw.ch_names[0:nfreqs] _fnirs_check_bads(raw.info) raw = optical_density(raw) _fnirs_check_bads(raw.info) @@ -128,7 +145,7 @@ def test_fnirs_check_bads(fname): _fnirs_check_bads(raw.info) # Mark single channel as bad, so these should all fail - raw = read_raw_nirx(fname) + raw = readerfn(fname) raw.info["bads"] = raw.ch_names[0:1] pytest.raises(RuntimeError, _fnirs_check_bads, raw.info) with pytest.raises(RuntimeError, match="bad labelling"): @@ -144,71 +161,90 @@ def test_fnirs_check_bads(fname): @testing.requires_testing_data @pytest.mark.parametrize( - "fname", ([fname_nirx_15_2_short, fname_nirx_15_2, fname_nirx_15_0]) + "fname, readerfn", + [ + (fname_nirx_15_0, read_raw_nirx), + (fname_nirx_15_2_short, read_raw_nirx), + (fname_nirx_15_2, read_raw_nirx), + (fname_labnirs_multi_wavelength, read_raw_snirf), + ], ) -def test_fnirs_spread_bads(fname): +def test_fnirs_spread_bads(fname, readerfn): """Test checking of bad markings.""" # Test spreading upwards in frequency and on raw data - raw = read_raw_nirx(fname) - raw.info["bads"] = ["S1_D1 760"] + raw = readerfn(fname) + nfreqs = len(set(_channel_frequencies(raw.info))) + raw.info["bads"] = [raw.ch_names[0]] info = _fnirs_spread_bads(raw.info) - assert info["bads"] == ["S1_D1 760", "S1_D1 850"] - - # Test spreading downwards in frequency and on od data - raw = optical_density(raw) - raw.info["bads"] = raw.ch_names[5:6] - info = _fnirs_spread_bads(raw.info) - assert info["bads"] == raw.ch_names[4:6] + assert info["bads"] == raw.ch_names[:nfreqs] + + # Test multiple spreading directions on od data + # For each wavelength, mark the nth item in the nth group as bad + # e.g. 3 wavelengths: + # group 0 (channels 0, 1, 2) - channel 0 is marked bad, + # group 1 (channels 3, 4, 5) - channel 4 is bad, + # group 2 (channels 6, 7, 8) - channel 8 is bad. + # This way we can test spreading from each group member in a + # way that's agnostic of the number of wavelengths, and avoid + # hard-coded values. Needs nfreqs**2 number of channels. + raw_od = optical_density(raw) + bads = [raw_od.ch_names[nfreqs * ii + ii] for ii in range(nfreqs)] + expected_bads = raw_od.ch_names[: nfreqs**2] + raw_od.info["bads"] = bads + info = _fnirs_spread_bads(raw_od.info) + # channels might not be sorted but the spreading result is + assert info["bads"] == sorted(expected_bads) # Test spreading multiple bads and on chroma data - raw = beer_lambert_law(raw) - raw.info["bads"] = [raw.ch_names[x] for x in [1, 8]] - info = _fnirs_spread_bads(raw.info) + # Hb data always has 2 channels per S-D pair, this works for any nfreqs + raw_hb = beer_lambert_law(raw_od) + raw_hb.info["bads"] = [raw_hb.ch_names[x] for x in [1, 8]] + info = _fnirs_spread_bads(raw_hb.info) assert info["bads"] == [info.ch_names[x] for x in [0, 1, 8, 9]] @testing.requires_testing_data @pytest.mark.parametrize( - "fname", ([fname_nirx_15_2_short, fname_nirx_15_2, fname_nirx_15_0]) + "fname, readerfn", + [ + (fname_nirx_15_0, read_raw_nirx), + (fname_nirx_15_2_short, read_raw_nirx), + (fname_nirx_15_2, read_raw_nirx), + (fname_labnirs_multi_wavelength, read_raw_snirf), + ], ) -def test_fnirs_channel_naming_and_order_readers(fname): - """Ensure fNIRS channel checking on standard readers.""" +def test_fnirs_channel_frequency_ordering(fname, readerfn): + """Test fNIRS channel frequencies ordering and related errors.""" # fNIRS data requires specific channel naming and ordering. - # All standard readers should pass tests - raw = read_raw_nirx(fname) - freqs = np.unique(_channel_frequencies(raw.info)) - assert_array_equal(freqs, [760, 850]) + # Ensure that freqs are well-ordered after reading in from file, + # and that there are no chroma channels + raw = readerfn(fname) + freqs = np.unique(_channel_frequencies(raw.info)).tolist() + assert sorted(freqs) == freqs chroma = np.unique(_channel_chromophore(raw.info)) assert len(chroma) == 0 - picks = _check_channels_ordered(raw.info, freqs) assert len(picks) == len(raw.ch_names) # as all fNIRS only data - # Check that dropped channels are detected - # For each source detector pair there must be two channels, - # removing one should throw an error. - raw_dropped = raw.copy().drop_channels(raw.ch_names[4]) - with pytest.raises(ValueError, match="not ordered correctly"): - _check_channels_ordered(raw_dropped.info, freqs) - # The ordering must be increasing for the pairs, if provided - raw_names_reversed = raw.copy().ch_names - raw_names_reversed.reverse() - raw_reversed = raw.copy().pick(raw_names_reversed) + raw_reversed = raw.pick(list(reversed(raw.ch_names))) + with pytest.raises(ValueError, match="The frequencies.*sorted.*"): - _check_channels_ordered(raw_reversed.info, [850, 760]) + _check_channels_ordered(raw_reversed.info, list(reversed(freqs))) # So if we flip the second argument it should pass again picks = _check_channels_ordered(raw_reversed.info, freqs) - got_first = set(raw_reversed.ch_names[pick].split()[1] for pick in picks[::2]) - assert got_first == {"760"} - got_second = set(raw_reversed.ch_names[pick].split()[1] for pick in picks[1::2]) - assert got_second == {"850"} + nfreqs = len(freqs) + for ii in range(nfreqs): + suffixes = { + int(raw_reversed.ch_names[pick].split(" ")[1]) for pick in picks[ii::nfreqs] + } + assert suffixes == {freqs[ii]} # Check on OD data raw = optical_density(raw) - freqs = np.unique(_channel_frequencies(raw.info)) - assert_array_equal(freqs, [760, 850]) + freqs = np.unique(_channel_frequencies(raw.info)).tolist() + assert sorted(freqs) == freqs chroma = np.unique(_channel_chromophore(raw.info)) assert len(chroma) == 0 picks = _check_channels_ordered(raw.info, freqs) @@ -553,3 +589,58 @@ def test_order_agnostic(nirx_snirf): tddrs["nirx"].get_data(), r.get_data(orders[key]), err_msg=key, atol=1e-9 ) assert set(r.get_channel_types()) == {"hbo", "hbr"} + + +@testing.requires_testing_data +@pytest.mark.parametrize( + "fname, readerfn", + [ + (fname_nirx_15_0, read_raw_nirx), + (fname_nirx_15_2_short, read_raw_nirx), + (fname_nirx_15_2, read_raw_nirx), + (fname_labnirs_multi_wavelength, read_raw_snirf), + ], +) +def test_nirs_channel_grouping(fname, readerfn): + """Test channel grouping related errors.""" + raw = readerfn(fname) + freqs = np.unique(_channel_frequencies(raw.info)).tolist() + nfreqs = len(freqs) + + # Each source-detector (S-D) optode pair may have data measured at + # >=2 wavelengths. The channels that belong to an optode pair form + # a group, with a size equal to the number of wavelengths (nfreqs). + # An error is raised if these groups are incomplete. + + picks = _check_channels_ordered(raw.info, freqs) + assert len(picks) == len(raw.ch_names) + + # Removing one channel breaks the grouping + raw_dropped = raw.copy().drop_channels(raw.ch_names[4]) + with pytest.raises(ValueError, match="NIRS channels not ordered correctly."): + _check_channels_ordered(raw_dropped.info, freqs) + + # Selecting incomplete groups results in an error + raw_incomplete = raw.copy().pick(list(range(nfreqs + 1))) + with pytest.raises(ValueError, match="NIRS channels not ordered correctly."): + _check_channels_ordered(raw_incomplete.info, freqs) + + # Changing the frequency in one channel's name also breaks groups + raw_extrafreq = raw.copy() + new_ch10_name = f"{raw_extrafreq.ch_names[10].split(' ')[0]} 100" + raw_extrafreq.rename_channels({raw_extrafreq.ch_names[10]: new_ch10_name}) + print(raw_extrafreq.ch_names) + # checks result in error for both old and new set of frequencies + with pytest.raises(ValueError, match="NIRS channels not ordered correctly."): + _check_channels_ordered(raw_extrafreq.info, [100] + freqs) + with pytest.raises(ValueError, match="NIRS channels not ordered correctly."): + _check_channels_ordered(raw_extrafreq.info, freqs) + + # Frequency values are also stored in info['chs'][ii]['loc'][9]. + # Even though the actual values can be arbitrary, they cannot be None. + raw_locnone = raw.copy() + raw_locnone.info["chs"][10]["loc"][9] = None + with pytest.raises( + ValueError, match="NIRS channels is missing wavelength information" + ): + _check_channels_ordered(raw_locnone.info, freqs) diff --git a/mne/preprocessing/nirs/tests/test_optical_density.py b/mne/preprocessing/nirs/tests/test_optical_density.py index 89b9edce713..53b66f46238 100644 --- a/mne/preprocessing/nirs/tests/test_optical_density.py +++ b/mne/preprocessing/nirs/tests/test_optical_density.py @@ -8,27 +8,42 @@ from mne.datasets import testing from mne.datasets.testing import data_path -from mne.io import BaseRaw, read_raw_nirx +from mne.io import BaseRaw, read_raw_nirx, read_raw_snirf from mne.preprocessing.nirs import optical_density from mne.utils import _validate_type fname_nirx = ( data_path(download=False) / "NIRx" / "nirscout" / "nirx_15_2_recording_w_short" ) +fname_labnirs_multi_wavelength = ( + data_path(download=False) / "SNIRF" / "Labnirs" / "labnirs_3wl_raw_recording.snirf" +) @testing.requires_testing_data -def test_optical_density(): +@pytest.mark.parametrize( + "fname,readerfn", + [(fname_nirx, read_raw_nirx), (fname_labnirs_multi_wavelength, read_raw_snirf)], +) +def test_optical_density(fname, readerfn): """Test return type for optical density.""" - raw = read_raw_nirx(fname_nirx, preload=False) - assert "fnirs_cw_amplitude" in raw - assert "fnirs_od" not in raw - raw = optical_density(raw) - _validate_type(raw, BaseRaw, "raw") - assert "fnirs_cw_amplitude" not in raw - assert "fnirs_od" in raw + raw_volt = readerfn(fname, preload=False) + _validate_type(raw_volt, BaseRaw, "raw") + + raw_od = optical_density(raw_volt) + _validate_type(raw_od, BaseRaw, "raw") + + # Verify data types + assert set(raw_volt.get_channel_types()) == {"fnirs_cw_amplitude"} + assert set(raw_od.get_channel_types()) == {"fnirs_od"} + + # Verify that channel names did not change + for oldname, newname in zip(raw_volt.ch_names, raw_od.ch_names): + assert oldname == newname + + # Cannot run OD conversion on OD data with pytest.raises(RuntimeError, match="on continuous wave"): - optical_density(raw) + optical_density(raw_od) @testing.requires_testing_data diff --git a/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py b/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py index 4a8fd3e71d0..832a1158486 100644 --- a/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py +++ b/mne/preprocessing/nirs/tests/test_scalp_coupling_index.py @@ -8,21 +8,21 @@ from mne.datasets import testing from mne.datasets.testing import data_path -from mne.io import read_raw_nirx +from mne.io import read_raw_nirx, read_raw_snirf from mne.preprocessing.nirs import ( beer_lambert_law, optical_density, scalp_coupling_index, ) -fname_nirx_15_0 = ( - data_path(download=False) / "NIRx" / "nirscout" / "nirx_15_0_recording" -) -fname_nirx_15_2 = ( - data_path(download=False) / "NIRx" / "nirscout" / "nirx_15_2_recording" -) +testing_path = data_path(download=False) +fname_nirx_15_0 = testing_path / "NIRx" / "nirscout" / "nirx_15_0_recording" +fname_nirx_15_2 = testing_path / "NIRx" / "nirscout" / "nirx_15_2_recording" fname_nirx_15_2_short = ( - data_path(download=False) / "NIRx" / "nirscout" / "nirx_15_2_recording_w_short" + testing_path / "NIRx" / "nirscout" / "nirx_15_2_recording_w_short" +) +fname_labnirs_multi_wavelength = ( + testing_path / "SNIRF" / "Labnirs" / "labnirs_3wl_raw_recording.snirf" ) @@ -76,3 +76,84 @@ def test_scalp_coupling_index(fname, fmt, tmp_path): raw = beer_lambert_law(raw, ppf=6) with pytest.raises(RuntimeError, match="Scalp"): scalp_coupling_index(raw) + + +@testing.requires_testing_data +def test_scalp_coupling_index_multi_wavelength(): + """Validate SCI min-correlation logic for >=3 wavelengths. + + Similar to test in test_scalp_coupling_index, considers cases + specific to multi-wavelength data. + """ + raw = optical_density(read_raw_snirf(fname_labnirs_multi_wavelength)) + times = np.arange(raw.n_times) / raw.info["sfreq"] + signal = np.sin(2 * np.pi * 1.0 * times) + 1 + assert len(raw.ch_names) >= 15 * 3 + rng = np.random.default_rng(3289745) + + # pre-determined expected results + expected = [] + # group 1: perfect correlation; sci = 1 + raw._data[0] = signal + raw._data[1] = signal + raw._data[2] = signal + expected.extend([1.0] * 3) + # group 2: scale invariance; sci = 1 + raw._data[3] = signal + raw._data[4] = signal * 0.3 + raw._data[5] = signal + expected.extend([1.0] * 3) + # group 3: anti-correlation; minimum value taken, sci = -1 + raw._data[6] = signal + raw._data[7] = signal + raw._data[8] = -signal + expected.extend([-1.0] * 3) + # group 4: one zero std channel; minimum value is sci = 0 + raw._data[9] = 0.0 + raw._data[10] = signal + raw._data[11] = signal + expected.extend([0.0] * 3) + # group 5: three zero std channels; all sci = 0 + raw._data[12] = 0.0 + raw._data[13] = 1.0 + raw._data[14] = 2.0 + expected.extend([0.0] * 3) + # group 6: mixed: 1 signal + 1 negative + 1 random (lowest wins) + raw._data[15] = signal + raw._data[16] = rng.random(signal.shape) + raw._data[17] = -signal + expected.extend([-1.0] * 3) + + # exact results unknown + # group 7: 1 uncorrelated signal out of 3; sci < 0.5 + raw._data[18] = signal + raw._data[19] = rng.random(signal.shape) + raw._data[20] = signal + # group 8: 2 uncorrelated signals out of 3; sci < 0.5 + raw._data[21] = rng.random(signal.shape) + raw._data[22] = rng.random(signal.shape) + raw._data[23] = signal + # group 9: 3 uncorrelated signals; sci < 0.5 + raw._data[24] = rng.random(signal.shape) + raw._data[25] = rng.random(signal.shape) + raw._data[26] = rng.random(signal.shape) + # groups 10-12: ordering invariance; all must be the same + rand1 = rng.random(signal.shape) + rand2 = rng.random(signal.shape) + rand3 = rng.random(signal.shape) + raw._data[27] = rand1 + raw._data[28] = rand2 + raw._data[29] = rand3 + raw._data[30] = rand2 + raw._data[31] = rand1 + raw._data[32] = rand3 + raw._data[33] = rand3 + raw._data[34] = rand1 + raw._data[35] = rand2 + + sci = scalp_coupling_index(raw) + + assert_allclose(sci[:18], expected, atol=1e-4) + for ii in range(18, 27): + assert np.abs(sci[ii]) < 0.5 + assert_allclose(sci[28:36], sci[27], atol=1e-4)