From 054ad12f8bf3bbb166056c1075f3ead122f52956 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 24 Apr 2025 20:09:57 +0200 Subject: [PATCH 01/57] Add check --- mne/utils/__init__.pyi | 2 ++ mne/utils/check.py | 72 ++++++++++++++++++++++++++++++++++++------ 2 files changed, 64 insertions(+), 10 deletions(-) diff --git a/mne/utils/__init__.pyi b/mne/utils/__init__.pyi index 46d272e972d..b793c1d66a1 100644 --- a/mne/utils/__init__.pyi +++ b/mne/utils/__init__.pyi @@ -34,6 +34,7 @@ __all__ = [ "_check_eeglabio_installed", "_check_event_id", "_check_fname", + "_check_forbidden_values", "_check_freesurfer_home", "_check_head_radius", "_check_if_nan", @@ -232,6 +233,7 @@ from .check import ( _check_eeglabio_installed, _check_event_id, _check_fname, + _check_forbidden_values, _check_freesurfer_home, _check_head_radius, _check_if_nan, diff --git a/mne/utils/check.py b/mne/utils/check.py index 085c51b6996..998b5561752 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -609,11 +609,11 @@ def _validate_type(item, types=None, item_name=None, type_name=None, *, extra="" check_types = sum( ( - (type(None),) - if type_ is None - else (type_,) - if not isinstance(type_, str) - else _multi[type_] + ( + (type(None),) + if type_ is None + else (type_,) if not isinstance(type_, str) else _multi[type_] + ) for type_ in types ), (), @@ -622,11 +622,11 @@ def _validate_type(item, types=None, item_name=None, type_name=None, *, extra="" if not isinstance(item, check_types): if type_name is None: type_name = [ - "None" - if cls_ is None - else cls_.__name__ - if not isinstance(cls_, str) - else cls_ + ( + "None" + if cls_ is None + else cls_.__name__ if not isinstance(cls_, str) else cls_ + ) for cls_ in types ] if len(type_name) == 1: @@ -932,6 +932,58 @@ def _check_option(parameter, value, allowed_values, extra=""): ) +def _check_forbidden_values(parameter, value, invalid_values, extra=""): + """Check the value of a parameter against a list of invalid options. + + Return the value if it is valid, otherwise raise a ValueError with a + readable error message. + + Parameters + ---------- + parameter : str + The name of the parameter to check. This is used in the error message. + value : any type + The value of the parameter to check. + invalid_values : list + The list of forbidden values for the parameter. + extra : str + Extra string to append to the invalid value sentence, e.g. + "when using ico mode". + + Raises + ------ + ValueError + When the value of the parameter is one of the invalid options. + + Returns + ------- + value : any type + The value if it is valid. + """ + if value not in invalid_values: + return value + + # Prepare a nice error message for the user + extra = f" {extra}" if extra else extra + msg = ( + "Invalid value for the '{parameter}' parameter{extra}. " + "{forbidden}, but got {value!r} instead." + ) + invalid_values = list(invalid_values) # e.g., if a dict was given + if len(invalid_values) == 1: + forbidden = f"The following value is not allowed: {repr(invalid_values[0])}" + else: + forbidden = "The following values are not allowed: " + if len(invalid_values) == 2: + forbidden += " and ".join(repr(v) for v in invalid_values) + else: + forbidden += ", ".join(repr(v) for v in invalid_values[:-1]) + forbidden += f", and {repr(invalid_values[-1])}" + raise ValueError( + msg.format(parameter=parameter, forbidden=forbidden, value=value, extra=extra) + ) + + def _check_all_same_channel_names(instances): """Check if a collection of instances all have the same channels.""" ch_names = instances[0].info["ch_names"] From ee92613283f66942538f0397528f892a3731a72d Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 24 Apr 2025 20:10:44 +0200 Subject: [PATCH 02/57] Add details attribute to annotations --- mne/annotations.py | 177 +++++++++++++++++++++++++++++++++++++++------ mne/epochs.py | 22 ++++-- 2 files changed, 171 insertions(+), 28 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 629ee7b20cb..ff4b3f2fc3b 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -31,6 +31,7 @@ ) from .utils import ( _check_dict_keys, + _check_forbidden_values, _check_dt, _check_fname, _check_option, @@ -58,7 +59,43 @@ _datetime = datetime -def _check_o_d_s_c(onset, duration, description, ch_names): +class DetailsDict(dict): + def __setitem__(self, key: str, value: str | int | float | None) -> None: + _validate_type(key, str, "key", "string") + _check_forbidden_values( + "details key", + key, + ["onset", "duration", "description", "ch_names"], + ) + _validate_type( + value, (str, int, float, None), "value", "string, int, float or None" + ) + return super().__setitem__(key, value) + + +def _validate_details(details, length: int): + _validate_type(details, (None, list), "details") + if details is None: + return [None] * length + if len(details) != length: + raise ValueError( + f"Details must be None or a list of length {length}, got " + f"{len(details)}." + ) + for i, d in enumerate(details): + _validate_type(d, (dict, DetailsDict, None), f"details[{i}]", "dict or None") + out = [] + for d in details: + if d is None: + out.append(None) + else: + dd = DetailsDict() + dd.update(d) + out.append(dd) + return out + + +def _check_o_d_s_c_d(onset, duration, description, ch_names, details): onset = np.atleast_1d(np.array(onset, dtype=float)) if onset.ndim != 1: raise ValueError( @@ -100,7 +137,9 @@ def _check_o_d_s_c(onset, duration, description, ch_names): f"equal in sizes, got {len(onset)}, {len(duration)}, " f"{len(description)}, and {len(ch_names)}." ) - return onset, duration, description, ch_names + + details = _validate_details(details, len(onset)) + return onset, duration, description, ch_names, details def _ndarray_ch_names(ch_names): @@ -146,6 +185,11 @@ class Annotations: %(ch_names_annot)s .. versionadded:: 0.23 + details : list[dict | None] | None + Optional list fo dicts containing additional details for each annotation. + The number of items must match the number of annotations. + + .. versionadded:: 1.10.0 See Also -------- @@ -274,10 +318,12 @@ class Annotations: :meth:`Raw.save() ` notes for details. """ # noqa: E501 - def __init__(self, onset, duration, description, orig_time=None, ch_names=None): + def __init__( + self, onset, duration, description, orig_time=None, ch_names=None, details=None + ): self._orig_time = _handle_meas_date(orig_time) - self.onset, self.duration, self.description, self.ch_names = _check_o_d_s_c( - onset, duration, description, ch_names + self.onset, self.duration, self.description, self.ch_names, self.details = ( + _check_o_d_s_c_d(onset, duration, description, ch_names, details) ) self._sort() # ensure we're sorted @@ -286,6 +332,26 @@ def orig_time(self): """The time base of the Annotations.""" return self._orig_time + @property + def details(self): + """The details of the Annotations.""" + return self._details + + @details.setter + def details(self, details): + self._details = _validate_details(details, len(self.onset)) + + @property + def details_columns(self) -> set[str]: + """The set containing all the keys in all details dicts.""" + return {k for d in self.details if d is not None for k in d.keys()} + + @property + def details_data_frame(self): + """The details of the Annotations as a DataFrame.""" + pd = _check_pandas_installed(strict=True) + return pd.DataFrame([d if d is not None else {} for d in self.details]) + def __eq__(self, other): """Compare to another Annotations instance.""" if not isinstance(other, Annotations): @@ -339,7 +405,11 @@ def __iadd__(self, other): f"{self.orig_time} != {other.orig_time})" ) return self.append( - other.onset, other.duration, other.description, other.ch_names + other.onset, + other.duration, + other.description, + other.ch_names, + other.details, ) def __iter__(self): @@ -350,7 +420,7 @@ def __iter__(self): for idx in range(len(self.onset)): yield self.__getitem__(idx, with_ch_names=with_ch_names) - def __getitem__(self, key, *, with_ch_names=None): + def __getitem__(self, key, *, with_ch_names=None, with_details=True): """Propagate indexing and slicing to the underlying numpy structure.""" if isinstance(key, int_like): out_keys = ("onset", "duration", "description", "orig_time") @@ -363,6 +433,9 @@ def __getitem__(self, key, *, with_ch_names=None): if with_ch_names or (with_ch_names is None and self._any_ch_names()): out_keys += ("ch_names",) out_vals += (self.ch_names[key],) + if with_details: + out_keys += ("details",) + out_vals += (self.details[key],) return OrderedDict(zip(out_keys, out_vals)) else: key = list(key) if isinstance(key, tuple) else key @@ -372,10 +445,11 @@ def __getitem__(self, key, *, with_ch_names=None): description=self.description[key], orig_time=self.orig_time, ch_names=self.ch_names[key], + details=[self.details[i] for i in np.arange(len(self.details))[key]], ) @fill_doc - def append(self, onset, duration, description, ch_names=None): + def append(self, onset, duration, description, ch_names=None, details=None): """Add an annotated segment. Operates inplace. Parameters @@ -391,6 +465,11 @@ def append(self, onset, duration, description, ch_names=None): %(ch_names_annot)s .. versionadded:: 0.23 + details : list[dict | None] | None + Optional list fo dicts containing additional details for each annotation. + The number of items must match the number of annotations. + + .. versionadded:: 1.10.0 Returns ------- @@ -403,13 +482,14 @@ def append(self, onset, duration, description, ch_names=None): to not only ``list.append``, but also `list.extend `__. """ # noqa: E501 - onset, duration, description, ch_names = _check_o_d_s_c( - onset, duration, description, ch_names + onset, duration, description, ch_names, details = _check_o_d_s_c_d( + onset, duration, description, ch_names, details ) self.onset = np.append(self.onset, onset) self.duration = np.append(self.duration, duration) self.description = np.append(self.description, description) self.ch_names = np.append(self.ch_names, ch_names) + self.details = self.details + details self._sort() return self @@ -436,6 +516,11 @@ def delete(self, idx): self.duration = np.delete(self.duration, idx) self.description = np.delete(self.description, idx) self.ch_names = np.delete(self.ch_names, idx) + if isinstance(idx, int): + del self.details[idx] + else: + for i in np.sort(np.arange(len(self.details))[idx])[::-1]: + del self.details[i] @fill_doc def to_data_frame(self, time_format="datetime"): @@ -466,6 +551,7 @@ def to_data_frame(self, time_format="datetime"): if self._any_ch_names(): df.update(ch_names=self.ch_names) df = pd.DataFrame(df) + df = pd.concat([df, self.details_data_frame], axis=1, ignore_index=True) return df def count(self): @@ -567,6 +653,7 @@ def _sort(self): self.duration = self.duration[order] self.description = self.description[order] self.ch_names = self.ch_names[order] + self.details = [self.details[i] for i in order] @verbose def crop( @@ -619,10 +706,12 @@ def crop( ) logger.debug(f"Cropping annotations {absolute_tmin} - {absolute_tmax}") - onsets, durations, descriptions, ch_names = [], [], [], [] + onsets, durations, descriptions, ch_names, details = [], [], [], [], [] out_of_bounds, clip_left_elem, clip_right_elem = [], [], [] - for idx, (onset, duration, description, ch) in enumerate( - zip(self.onset, self.duration, self.description, self.ch_names) + for idx, (onset, duration, description, ch, detail) in enumerate( + zip( + self.onset, self.duration, self.description, self.ch_names, self.details + ) ): # if duration is NaN behave like a zero if np.isnan(duration): @@ -660,12 +749,14 @@ def crop( ) descriptions.append(description) ch_names.append(ch) + details.append(detail) logger.debug(f"Cropping complete (kept {len(onsets)})") self.onset = np.array(onsets, float) self.duration = np.array(durations, float) assert (self.duration >= 0).all() self.description = np.array(descriptions, dtype=str) self.ch_names = _ndarray_ch_names(ch_names) + self.details = details if emit_warning: omitted = np.array(out_of_bounds).sum() @@ -892,6 +983,7 @@ def get_annotations_per_epoch(self): this_annot["onset"] - this_tzero, this_annot["duration"], this_annot["description"], + this_annot["details"], ) # ...then add it to the correct sublist of `epoch_annot_list` epoch_annot_list[epo_ix].append(annot) @@ -957,6 +1049,7 @@ def add_annotations_to_metadata(self, overwrite=False): # onsets, durations, and descriptions epoch_annot_list = self.get_annotations_per_epoch() onset, duration, description = [], [], [] + details = {k: [] for k in self.annotations.details_columns} for epoch_annot in epoch_annot_list: for ix, annot_prop in enumerate((onset, duration, description)): entry = [annot[ix] for annot in epoch_annot] @@ -966,12 +1059,20 @@ def add_annotations_to_metadata(self, overwrite=False): entry = np.round(entry, decimals=12).tolist() annot_prop.append(entry) + for k in details.keys(): + entry = [ + None if annot[3] is None else annot[3].get(k, None) + for annot in epoch_annot + ] + details[k].append(entry) # Create a new Annotations column that is instantiated as an empty # list per Epoch. metadata["annot_onset"] = pd.Series(onset) metadata["annot_duration"] = pd.Series(duration) metadata["annot_description"] = pd.Series(description) + for k, v in details.items(): + metadata[f"annot_{k}"] = pd.Series(v) # reset the metadata self.metadata = metadata @@ -1100,6 +1201,8 @@ def _write_annotations(fid, annotations): write_string( fid, FIFF.FIFF_MNE_EPOCHS_DROP_LOG, json.dumps(tuple(annotations.ch_names)) ) + if any(d is not None for d in annotations.details): + write_string(fid, FIFF.FIFF_FREE_LIST, json.dumps(annotations.details)) end_block(fid, FIFF.FIFFB_MNE_ANNOTATIONS) @@ -1328,28 +1431,51 @@ def _read_annotations_txt_parse_header(fname): def is_orig_time(x): return x.startswith("# orig_time :") + def is_columns(x): + return x.startswith("# onset, duration, description") + with open(fname) as fid: header = list(takewhile(lambda x: x.startswith("#"), fid)) orig_values = [h[13:].strip() for h in header if is_orig_time(h)] orig_values = [_handle_meas_date(orig) for orig in orig_values if _is_iso8601(orig)] - return None if not orig_values else orig_values[0] + columns = [[c.strip() for c in h[2:].split(",")] for h in header if is_columns(h)] + + return None if not orig_values else orig_values[0], ( + None if not columns else columns[0] + ) def _read_annotations_txt(fname): with warnings.catch_warnings(record=True): warnings.simplefilter("ignore") out = np.loadtxt(fname, delimiter=",", dtype=np.bytes_, unpack=True) - ch_names = None + orig_time, columns = _read_annotations_txt_parse_header(fname) + ch_names = details = None if len(out) == 0: onset, duration, desc = [], [], [] else: - _check_option("text header", len(out), (3, 4)) - if len(out) == 3: - onset, duration, desc = out + if columns is None: + _check_option("text header", len(out), (3, 4)) + columns = ["onset", "duration", "description"] + ( + ["ch_names"] if len(out) == 4 else [] + ) else: - onset, duration, desc, ch_names = out + _check_option( + "text header", columns[:3], (["onset", "duration", "description"],) + ) + _check_option("text header len", len(out), (len(columns),)) + onset, duration, desc = out[:3] + i_col = 3 + if len(columns) > i_col and columns[i_col] == "ch_names": + ch_names = out[i_col] + i_col += 1 + if len(columns) > i_col: + details = [ + {columns[j_col]: out[j_col][i] for j_col in range(i_col, len(columns))} + for i in range(len(onset)) + ] onset = [float(o.decode()) for o in np.atleast_1d(onset)] duration = [float(d.decode()) for d in np.atleast_1d(duration)] @@ -1360,14 +1486,13 @@ def _read_annotations_txt(fname): for ci, ch in enumerate(ch_names) ] - orig_time = _read_annotations_txt_parse_header(fname) - annotations = Annotations( onset=onset, duration=duration, description=desc, orig_time=orig_time, ch_names=ch_names, + details=details, ) return annotations @@ -1380,7 +1505,7 @@ def _read_annotations_fif(fid, tree): annotations = None else: annot_data = annot_data[0] - orig_time = ch_names = None + orig_time = ch_names = details = None onset, duration, description = list(), list(), list() for ent in annot_data["directory"]: kind = ent.kind @@ -1402,8 +1527,14 @@ def _read_annotations_fif(fid, tree): orig_time = tuple(orig_time) # new way elif kind == FIFF.FIFF_MNE_EPOCHS_DROP_LOG: ch_names = tuple(tuple(x) for x in json.loads(tag.data)) + elif kind == FIFF.FIFF_FREE_LIST: + details = json.loads(tag.data) assert len(onset) == len(duration) == len(description) - annotations = Annotations(onset, duration, description, orig_time, ch_names) + if details is not None: + assert len(details) == len(onset) + annotations = Annotations( + onset, duration, description, orig_time, ch_names, details + ) return annotations diff --git a/mne/epochs.py b/mne/epochs.py index 96f247875d9..3d799aa3dd7 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -2465,11 +2465,13 @@ def equalize_event_counts( # 2b. for non-tag ids, just pass them directly # 3. do this for every input event_ids = [ - [ - k for k in ids if all(tag in k.split("/") for tag in id_) - ] # ids matching all tags - if all(id__ not in ids for id__ in id_) - else id_ # straight pass for non-tag inputs + ( + [ + k for k in ids if all(tag in k.split("/") for tag in id_) + ] # ids matching all tags + if all(id__ not in ids for id__ in id_) + else id_ + ) # straight pass for non-tag inputs for id_ in event_ids ] for ii, id_ in enumerate(event_ids): @@ -3575,6 +3577,16 @@ def __init__( raw, events, event_id, annotations, on_missing ) + # add the annotations.details to the metadata + if not all(d is None for d in annotations.details): + if metadata is None: + metadata = annotations.details_data_frame + else: + pd = _check_pandas_installed(strict=True) + details_df = annotations.details_data_frame + details_df.set_index(metadata.index, inplace=True) + metadata = pd.concat([metadata, details_df], axis=1, ignore_index=False) + # call BaseEpochs constructor super().__init__( info, From fdceb5b3f9eafeea03761679120709205b95329e Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 24 Apr 2025 20:10:53 +0200 Subject: [PATCH 03/57] Update tests --- mne/tests/test_annotations.py | 53 +++++++++++++++++++++++++++++------ mne/tests/test_epochs.py | 27 ++++++++++++------ 2 files changed, 63 insertions(+), 17 deletions(-) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 4d0db170e2a..c2a9a69424b 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -630,10 +630,24 @@ def test_annotation_epoching(): assert_equal([0, 2, 4], epochs.selection) -def test_annotation_concat(): +@pytest.mark.parametrize("with_details", [True, False]) +def test_annotation_concat(with_details): """Test if two Annotations objects can be concatenated.""" + details = None + if with_details: + details = [ + {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None}, + None, + None, + ] a = Annotations([1, 2, 3], [5, 5, 8], ["a", "b", "c"], ch_names=[["1"], ["2"], []]) - b = Annotations([11, 12, 13], [1, 2, 2], ["x", "y", "z"], ch_names=[[], ["3"], []]) + b = Annotations( + [11, 12, 13], + [1, 2, 2], + ["x", "y", "z"], + ch_names=[[], ["3"], []], + details=details, + ) # test + operator (does not modify a or b) c = a + b @@ -656,6 +670,10 @@ def test_annotation_concat(): assert_equal(len(a), 6) assert_equal(len(b), 3) + if with_details: + all_details = [None] * 3 + details + assert all(c.details[i] == all_details[i] for i in range(len(all_details))) + # test += operator (modifies a in place) b._orig_time = _handle_meas_date(1038942070.7201) with pytest.raises(ValueError, match="orig_time should be the same"): @@ -963,9 +981,10 @@ def _assert_annotations_equal(a, b, tol=0): _ORIG_TIME = datetime.fromtimestamp(1038942071.7201, timezone.utc) -@pytest.fixture(scope="function", params=("ch_names", "fmt")) -def dummy_annotation_file(tmp_path_factory, ch_names, fmt): +@pytest.fixture(scope="function", params=("ch_names", "fmt", "with_details")) +def dummy_annotation_file(tmp_path_factory, ch_names, fmt, with_details): """Create csv file for testing.""" + details_row0 = {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None} if fmt == "csv": content = ( "onset,duration,description\n" @@ -982,7 +1001,10 @@ def dummy_annotation_file(tmp_path_factory, ch_names, fmt): ) else: assert fmt == "fif" - content = Annotations([0, 9], [1, 2.425], ["AA", "BB"], orig_time=_ORIG_TIME) + details = [details_row0, None] if with_details else None + content = Annotations( + [0, 9], [1, 2.425], ["AA", "BB"], orig_time=_ORIG_TIME, details=details + ) if ch_names: if isinstance(content, Annotations): @@ -995,6 +1017,13 @@ def dummy_annotation_file(tmp_path_factory, ch_names, fmt): content[-1] += ",MEG0111:MEG2563" content = "\n".join(content) + if with_details and not isinstance(content, Annotations): + content = content.splitlines() + content[-3] += "," + ",".join(details_row0.keys()) + content[-2] += "," + ",".join([str(v) for v in details_row0.values()]) + content[-1] += "," * len(details_row0) + content = "\n".join(content) + fname = tmp_path_factory.mktemp("data") / f"annotations-annot.{fmt}" if isinstance(content, str): with open(fname, "w") as f: @@ -1006,7 +1035,8 @@ def dummy_annotation_file(tmp_path_factory, ch_names, fmt): @pytest.mark.parametrize("ch_names", (False, True)) @pytest.mark.parametrize("fmt", [pytest.param("csv", marks=needs_pandas), "txt", "fif"]) -def test_io_annotation(dummy_annotation_file, tmp_path, fmt, ch_names): +@pytest.mark.parametrize("with_details", [True, False]) +def test_io_annotation(dummy_annotation_file, tmp_path, fmt, ch_names, with_details): """Test CSV, TXT, and FIF input/output (which support ch_names).""" annot = read_annotations(dummy_annotation_file) assert annot.orig_time == _ORIG_TIME @@ -1123,7 +1153,7 @@ def test_read_annotation_txt_header(tmp_path): fname = tmp_path / "header.txt" with open(fname, "w") as f: f.write(content) - orig_time = _read_annotations_txt_parse_header(fname) + orig_time, _ = _read_annotations_txt_parse_header(fname) want = datetime.fromtimestamp(1038942071.7201, timezone.utc) assert orig_time == want @@ -1178,29 +1208,34 @@ def test_annotations_slices(): NUM_ANNOT = 5 EXPECTED_ONSETS = EXPECTED_DURATIONS = [x for x in range(NUM_ANNOT)] EXPECTED_DESCS = [x.__repr__() for x in range(NUM_ANNOT)] + DETAILS_ROW = {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None} + EXPECTED_DETAILS = [DETAILS_ROW] * NUM_ANNOT annot = Annotations( onset=EXPECTED_ONSETS, duration=EXPECTED_DURATIONS, description=EXPECTED_DESCS, orig_time=None, + details=EXPECTED_DETAILS, ) # Indexing returns a copy. So this has no effect in annot annot[0]["onset"] = 42 annot[0]["duration"] = 3.14 annot[0]["description"] = "foobar" + annot[0]["details"] = DETAILS_ROW annot[:1].onset[0] = 42 annot[:1].duration[0] = 3.14 annot[:1].description[0] = "foobar" + annot[:1].details[0] = DETAILS_ROW # Slicing with single element returns a dictionary for ii in EXPECTED_ONSETS: assert annot[ii] == dict( zip( - ["onset", "duration", "description", "orig_time"], - [ii, ii, str(ii), None], + ["onset", "duration", "description", "orig_time", "details"], + [ii, ii, str(ii), None, DETAILS_ROW], ) ) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 88f2d9cdc13..b6ebb1cf861 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -479,12 +479,12 @@ def test_average_movements(): def _assert_drop_log_types(drop_log): __tracebackhide__ = True assert isinstance(drop_log, tuple), "drop_log should be tuple" - assert all(isinstance(log, tuple) for log in drop_log), ( - "drop_log[ii] should be tuple" - ) - assert all(isinstance(s, str) for log in drop_log for s in log), ( - "drop_log[ii][jj] should be str" - ) + assert all( + isinstance(log, tuple) for log in drop_log + ), "drop_log[ii] should be tuple" + assert all( + isinstance(s, str) for log in drop_log for s in log + ), "drop_log[ii][jj] should be str" def test_reject(): @@ -4917,9 +4917,15 @@ def test_add_channels_picks(): @pytest.mark.parametrize("first_samp", [0, 10]) @pytest.mark.parametrize( - "meas_date, orig_date", [[None, None], [np.pi, None], [np.pi, timedelta(seconds=1)]] + "meas_date, orig_date, with_details", + [ + [None, None, False], + [np.pi, None, False], + [np.pi, timedelta(seconds=1), False], + [None, None, True], + ], ) -def test_epoch_annotations(first_samp, meas_date, orig_date, tmp_path): +def test_epoch_annotations(first_samp, meas_date, orig_date, with_details, tmp_path): """Test Epoch Annotations from RawArray with dates. Tests the following cases crossed with each other: @@ -4942,11 +4948,14 @@ def test_epoch_annotations(first_samp, meas_date, orig_date, tmp_path): if orig_date is not None: orig_date = meas_date + orig_date ant_dur = 0.1 + details_row0 = {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None} + details = [details_row0, None, None] if with_details else None ants = Annotations( onset=[1.1, 1.2, 2.1], duration=[ant_dur, ant_dur, ant_dur], description=["x", "y", "z"], orig_time=orig_date, + details=details, ) raw.set_annotations(ants) epochs = make_fixed_length_epochs(raw, duration=1, overlap=0.5) @@ -4957,6 +4966,8 @@ def test_epoch_annotations(first_samp, meas_date, orig_date, tmp_path): assert "annot_onset" in metadata.columns assert "annot_duration" in metadata.columns assert "annot_description" in metadata.columns + if with_details: + assert all(f"annot_{k}" in metadata.columns for k in details_row0.keys()) # Test that writing and reading back these new metadata works temp_fname = tmp_path / "test-epo.fif" From a421357bae9cb2b3532e9da3bf3d8342600f0117 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 24 Apr 2025 18:26:26 +0000 Subject: [PATCH 04/57] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/annotations.py | 5 ++--- mne/epochs.py | 4 +++- mne/tests/test_epochs.py | 12 ++++++------ mne/utils/check.py | 8 ++++++-- 4 files changed, 17 insertions(+), 12 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index ff4b3f2fc3b..6b61df82d5d 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -31,9 +31,9 @@ ) from .utils import ( _check_dict_keys, - _check_forbidden_values, _check_dt, _check_fname, + _check_forbidden_values, _check_option, _check_pandas_installed, _check_time_format, @@ -79,8 +79,7 @@ def _validate_details(details, length: int): return [None] * length if len(details) != length: raise ValueError( - f"Details must be None or a list of length {length}, got " - f"{len(details)}." + f"Details must be None or a list of length {length}, got {len(details)}." ) for i, d in enumerate(details): _validate_type(d, (dict, DetailsDict, None), f"details[{i}]", "dict or None") diff --git a/mne/epochs.py b/mne/epochs.py index 3d799aa3dd7..b55f6f3523e 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -3585,7 +3585,9 @@ def __init__( pd = _check_pandas_installed(strict=True) details_df = annotations.details_data_frame details_df.set_index(metadata.index, inplace=True) - metadata = pd.concat([metadata, details_df], axis=1, ignore_index=False) + metadata = pd.concat( + [metadata, details_df], axis=1, ignore_index=False + ) # call BaseEpochs constructor super().__init__( diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index b6ebb1cf861..29ee5ace846 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -479,12 +479,12 @@ def test_average_movements(): def _assert_drop_log_types(drop_log): __tracebackhide__ = True assert isinstance(drop_log, tuple), "drop_log should be tuple" - assert all( - isinstance(log, tuple) for log in drop_log - ), "drop_log[ii] should be tuple" - assert all( - isinstance(s, str) for log in drop_log for s in log - ), "drop_log[ii][jj] should be str" + assert all(isinstance(log, tuple) for log in drop_log), ( + "drop_log[ii] should be tuple" + ) + assert all(isinstance(s, str) for log in drop_log for s in log), ( + "drop_log[ii][jj] should be str" + ) def test_reject(): diff --git a/mne/utils/check.py b/mne/utils/check.py index 998b5561752..84ea56ab5a4 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -612,7 +612,9 @@ def _validate_type(item, types=None, item_name=None, type_name=None, *, extra="" ( (type(None),) if type_ is None - else (type_,) if not isinstance(type_, str) else _multi[type_] + else (type_,) + if not isinstance(type_, str) + else _multi[type_] ) for type_ in types ), @@ -625,7 +627,9 @@ def _validate_type(item, types=None, item_name=None, type_name=None, *, extra="" ( "None" if cls_ is None - else cls_.__name__ if not isinstance(cls_, str) else cls_ + else cls_.__name__ + if not isinstance(cls_, str) + else cls_ ) for cls_ in types ] From 078d8375f635e967ef85bc2d85e76b11ed2d2578 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 24 Apr 2025 20:30:43 +0200 Subject: [PATCH 05/57] fix coquille --- mne/annotations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/annotations.py b/mne/annotations.py index ff4b3f2fc3b..514d00ee599 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -551,7 +551,7 @@ def to_data_frame(self, time_format="datetime"): if self._any_ch_names(): df.update(ch_names=self.ch_names) df = pd.DataFrame(df) - df = pd.concat([df, self.details_data_frame], axis=1, ignore_index=True) + df = pd.concat([df, self.details_data_frame], axis=1) return df def count(self): From 68fb8b3a804ff8df227fe23b3d0471deaed5676c Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 24 Apr 2025 20:33:21 +0200 Subject: [PATCH 06/57] Update change log --- doc/changes/devel/13228.newfeature.rst | 1 + doc/changes/names.inc | 1 + 2 files changed, 2 insertions(+) create mode 100644 doc/changes/devel/13228.newfeature.rst diff --git a/doc/changes/devel/13228.newfeature.rst b/doc/changes/devel/13228.newfeature.rst new file mode 100644 index 00000000000..2681654210d --- /dev/null +++ b/doc/changes/devel/13228.newfeature.rst @@ -0,0 +1 @@ +Add a ``details`` attribute to :class:`mne.Annotations`, by `Pierre Guetschel`_. \ No newline at end of file diff --git a/doc/changes/names.inc b/doc/changes/names.inc index ad162ee8f68..8e46884dc76 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -231,6 +231,7 @@ .. _Peter Molfese: https://github.com/pmolfese .. _Phillip Alday: https://palday.bitbucket.io .. _Pierre Ablin: https://pierreablin.com +.. _Pierre Guetschel: https://github.com/PierreGtch .. _Pierre-Antoine Bannier: https://github.com/PABannier .. _Ping-Keng Jao: https://github.com/nafraw .. _Proloy Das: https://github.com/proloyd From ecac5e2f95b815a88650655091861556bf70a3d3 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 24 Apr 2025 20:45:07 +0200 Subject: [PATCH 07/57] Make DetailsDict private and add docstring --- mne/annotations.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index fb16a7655c0..747ab925e1a 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -59,7 +59,13 @@ _datetime = datetime -class DetailsDict(dict): +class _DetailsDict(dict): + """A dictionary for storing details of annotations. + + The keys of the dictionary are strings, and the values can be + strings, integers, floats, or None. + """ + def __setitem__(self, key: str, value: str | int | float | None) -> None: _validate_type(key, str, "key", "string") _check_forbidden_values( @@ -82,13 +88,13 @@ def _validate_details(details, length: int): f"Details must be None or a list of length {length}, got {len(details)}." ) for i, d in enumerate(details): - _validate_type(d, (dict, DetailsDict, None), f"details[{i}]", "dict or None") + _validate_type(d, (dict, _DetailsDict, None), f"details[{i}]", "dict or None") out = [] for d in details: if d is None: out.append(None) else: - dd = DetailsDict() + dd = _DetailsDict() dd.update(d) out.append(dd) return out From ec348337b55f31851ab77e885c4fc7ad623cd253 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Fri, 25 Apr 2025 08:57:34 +0200 Subject: [PATCH 08/57] Fix test --- mne/io/egi/tests/test_egi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/io/egi/tests/test_egi.py b/mne/io/egi/tests/test_egi.py index 923c5ce925a..8e3275a733e 100644 --- a/mne/io/egi/tests/test_egi.py +++ b/mne/io/egi/tests/test_egi.py @@ -212,7 +212,7 @@ def test_io_egi_mff(events_as_annotations): if events_as_annotations: # Grab the first annotation. Should be the first "DIN1" event. assert len(raw.annotations) - onset, dur, desc, _ = raw.annotations[0].values() + onset, dur, desc, _, _ = raw.annotations[0].values() assert_allclose(onset, 2.438) assert np.isclose(dur, 0) assert desc == "DIN1" From 7b1d70bc6f8d79a1dd26805ff76e7cfde528d345 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Fri, 25 Apr 2025 10:11:35 +0200 Subject: [PATCH 09/57] ise int_like instead of int --- mne/annotations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 747ab925e1a..9b445c55d26 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -521,9 +521,9 @@ def delete(self, idx): self.duration = np.delete(self.duration, idx) self.description = np.delete(self.description, idx) self.ch_names = np.delete(self.ch_names, idx) - if isinstance(idx, int): + if isinstance(idx, int_like): del self.details[idx] - else: + elif len(idx) > 0: for i in np.sort(np.arange(len(self.details))[idx])[::-1]: del self.details[i] From fe6e4a1f1ad4a1d114d3f269e751f654a9610cb5 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel <25532709+PierreGtch@users.noreply.github.com> Date: Sat, 26 Apr 2025 09:54:16 +0200 Subject: [PATCH 10/57] Apply suggestions from code review Co-authored-by: Alexandre Gramfort Co-authored-by: Daniel McCloy --- mne/annotations.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 9b445c55d26..6113b44e51e 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -60,7 +60,7 @@ class _DetailsDict(dict): - """A dictionary for storing details of annotations. + """A dictionary for storing extra fields of annotations. The keys of the dictionary are strings, and the values can be strings, integers, floats, or None. @@ -327,7 +327,7 @@ def __init__( self, onset, duration, description, orig_time=None, ch_names=None, details=None ): self._orig_time = _handle_meas_date(orig_time) - self.onset, self.duration, self.description, self.ch_names, self.details = ( + self.onset, self.duration, self.description, self.ch_names, self._details = ( _check_o_d_s_c_d(onset, duration, description, ch_names, details) ) self._sort() # ensure we're sorted @@ -471,7 +471,7 @@ def append(self, onset, duration, description, ch_names=None, details=None): .. versionadded:: 0.23 details : list[dict | None] | None - Optional list fo dicts containing additional details for each annotation. + Optional list of dicts containing additional details for each annotation. The number of items must match the number of annotations. .. versionadded:: 1.10.0 @@ -494,7 +494,7 @@ def append(self, onset, duration, description, ch_names=None, details=None): self.duration = np.append(self.duration, duration) self.description = np.append(self.description, description) self.ch_names = np.append(self.ch_names, ch_names) - self.details = self.details + details + self.details.extend(details) self._sort() return self From 9bcbc3feac743a2c2b116d467b2660a880690de6 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 26 Apr 2025 09:59:51 +0200 Subject: [PATCH 11/57] Remove _check_forbidden_values --- mne/annotations.py | 8 ++----- mne/utils/check.py | 60 ++-------------------------------------------- 2 files changed, 4 insertions(+), 64 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 6113b44e51e..7912c8590c2 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -33,7 +33,6 @@ _check_dict_keys, _check_dt, _check_fname, - _check_forbidden_values, _check_option, _check_pandas_installed, _check_time_format, @@ -68,11 +67,8 @@ class _DetailsDict(dict): def __setitem__(self, key: str, value: str | int | float | None) -> None: _validate_type(key, str, "key", "string") - _check_forbidden_values( - "details key", - key, - ["onset", "duration", "description", "ch_names"], - ) + if key in ("onset", "duration", "description", "ch_names"): + raise ValueError(f"Key '{key}' is reserved and cannot be used in details.") _validate_type( value, (str, int, float, None), "value", "string, int, float or None" ) diff --git a/mne/utils/check.py b/mne/utils/check.py index 84ea56ab5a4..cd25577dac3 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -612,9 +612,7 @@ def _validate_type(item, types=None, item_name=None, type_name=None, *, extra="" ( (type(None),) if type_ is None - else (type_,) - if not isinstance(type_, str) - else _multi[type_] + else (type_,) if not isinstance(type_, str) else _multi[type_] ) for type_ in types ), @@ -627,9 +625,7 @@ def _validate_type(item, types=None, item_name=None, type_name=None, *, extra="" ( "None" if cls_ is None - else cls_.__name__ - if not isinstance(cls_, str) - else cls_ + else cls_.__name__ if not isinstance(cls_, str) else cls_ ) for cls_ in types ] @@ -936,58 +932,6 @@ def _check_option(parameter, value, allowed_values, extra=""): ) -def _check_forbidden_values(parameter, value, invalid_values, extra=""): - """Check the value of a parameter against a list of invalid options. - - Return the value if it is valid, otherwise raise a ValueError with a - readable error message. - - Parameters - ---------- - parameter : str - The name of the parameter to check. This is used in the error message. - value : any type - The value of the parameter to check. - invalid_values : list - The list of forbidden values for the parameter. - extra : str - Extra string to append to the invalid value sentence, e.g. - "when using ico mode". - - Raises - ------ - ValueError - When the value of the parameter is one of the invalid options. - - Returns - ------- - value : any type - The value if it is valid. - """ - if value not in invalid_values: - return value - - # Prepare a nice error message for the user - extra = f" {extra}" if extra else extra - msg = ( - "Invalid value for the '{parameter}' parameter{extra}. " - "{forbidden}, but got {value!r} instead." - ) - invalid_values = list(invalid_values) # e.g., if a dict was given - if len(invalid_values) == 1: - forbidden = f"The following value is not allowed: {repr(invalid_values[0])}" - else: - forbidden = "The following values are not allowed: " - if len(invalid_values) == 2: - forbidden += " and ".join(repr(v) for v in invalid_values) - else: - forbidden += ", ".join(repr(v) for v in invalid_values[:-1]) - forbidden += f", and {repr(invalid_values[-1])}" - raise ValueError( - msg.format(parameter=parameter, forbidden=forbidden, value=value, extra=extra) - ) - - def _check_all_same_channel_names(instances): """Check if a collection of instances all have the same channels.""" ch_names = instances[0].info["ch_names"] From f8999b2b7c2c1de86d1c09541aac9d36f768693f Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 26 Apr 2025 10:01:35 +0200 Subject: [PATCH 12/57] Remove _check_forbidden_values --- mne/utils/__init__.pyi | 2 -- 1 file changed, 2 deletions(-) diff --git a/mne/utils/__init__.pyi b/mne/utils/__init__.pyi index b793c1d66a1..46d272e972d 100644 --- a/mne/utils/__init__.pyi +++ b/mne/utils/__init__.pyi @@ -34,7 +34,6 @@ __all__ = [ "_check_eeglabio_installed", "_check_event_id", "_check_fname", - "_check_forbidden_values", "_check_freesurfer_home", "_check_head_radius", "_check_if_nan", @@ -233,7 +232,6 @@ from .check import ( _check_eeglabio_installed, _check_event_id, _check_fname, - _check_forbidden_values, _check_freesurfer_home, _check_head_radius, _check_if_nan, From e8f6728043c09d88236a9cd80fda6f69c6a7a450 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 26 Apr 2025 10:19:31 +0200 Subject: [PATCH 13/57] Rename details to extras --- mne/annotations.py | 140 +++++++++++++++++----------------- mne/epochs.py | 12 +-- mne/tests/test_annotations.py | 54 ++++++------- mne/tests/test_epochs.py | 14 ++-- 4 files changed, 110 insertions(+), 110 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 7912c8590c2..5443423a2b5 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -58,7 +58,7 @@ _datetime = datetime -class _DetailsDict(dict): +class _AnnotationsExtrasDict(dict): """A dictionary for storing extra fields of annotations. The keys of the dictionary are strings, and the values can be @@ -68,35 +68,37 @@ class _DetailsDict(dict): def __setitem__(self, key: str, value: str | int | float | None) -> None: _validate_type(key, str, "key", "string") if key in ("onset", "duration", "description", "ch_names"): - raise ValueError(f"Key '{key}' is reserved and cannot be used in details.") + raise ValueError(f"Key '{key}' is reserved and cannot be used in extras.") _validate_type( value, (str, int, float, None), "value", "string, int, float or None" ) return super().__setitem__(key, value) -def _validate_details(details, length: int): - _validate_type(details, (None, list), "details") - if details is None: +def _validate_extras(extras, length: int): + _validate_type(extras, (None, list), "extras") + if extras is None: return [None] * length - if len(details) != length: + if len(extras) != length: raise ValueError( - f"Details must be None or a list of length {length}, got {len(details)}." + f"extras must be None or a list of length {length}, got {len(extras)}." + ) + for i, d in enumerate(extras): + _validate_type( + d, (dict, _AnnotationsExtrasDict, None), f"extras[{i}]", "dict or None" ) - for i, d in enumerate(details): - _validate_type(d, (dict, _DetailsDict, None), f"details[{i}]", "dict or None") out = [] - for d in details: + for d in extras: if d is None: out.append(None) else: - dd = _DetailsDict() + dd = _AnnotationsExtrasDict() dd.update(d) out.append(dd) return out -def _check_o_d_s_c_d(onset, duration, description, ch_names, details): +def _check_o_d_s_c_e(onset, duration, description, ch_names, extras): onset = np.atleast_1d(np.array(onset, dtype=float)) if onset.ndim != 1: raise ValueError( @@ -139,8 +141,8 @@ def _check_o_d_s_c_d(onset, duration, description, ch_names, details): f"{len(description)}, and {len(ch_names)}." ) - details = _validate_details(details, len(onset)) - return onset, duration, description, ch_names, details + extras = _validate_extras(extras, len(onset)) + return onset, duration, description, ch_names, extras def _ndarray_ch_names(ch_names): @@ -186,8 +188,8 @@ class Annotations: %(ch_names_annot)s .. versionadded:: 0.23 - details : list[dict | None] | None - Optional list fo dicts containing additional details for each annotation. + extras : list[dict | None] | None + Optional list fo dicts containing extra fields for each annotation. The number of items must match the number of annotations. .. versionadded:: 1.10.0 @@ -320,11 +322,11 @@ class Annotations: """ # noqa: E501 def __init__( - self, onset, duration, description, orig_time=None, ch_names=None, details=None + self, onset, duration, description, orig_time=None, ch_names=None, extras=None ): self._orig_time = _handle_meas_date(orig_time) - self.onset, self.duration, self.description, self.ch_names, self._details = ( - _check_o_d_s_c_d(onset, duration, description, ch_names, details) + self.onset, self.duration, self.description, self.ch_names, self._extras = ( + _check_o_d_s_c_e(onset, duration, description, ch_names, extras) ) self._sort() # ensure we're sorted @@ -334,24 +336,24 @@ def orig_time(self): return self._orig_time @property - def details(self): - """The details of the Annotations.""" - return self._details + def extras(self): + """The extras of the Annotations.""" + return self._extras - @details.setter - def details(self, details): - self._details = _validate_details(details, len(self.onset)) + @extras.setter + def extras(self, extras): + self._extras = _validate_extras(extras, len(self.onset)) @property - def details_columns(self) -> set[str]: - """The set containing all the keys in all details dicts.""" - return {k for d in self.details if d is not None for k in d.keys()} + def extras_columns(self) -> set[str]: + """The set containing all the keys in all extras dicts.""" + return {k for d in self.extras if d is not None for k in d.keys()} @property - def details_data_frame(self): - """The details of the Annotations as a DataFrame.""" + def extras_data_frame(self): + """The extras of the Annotations as a DataFrame.""" pd = _check_pandas_installed(strict=True) - return pd.DataFrame([d if d is not None else {} for d in self.details]) + return pd.DataFrame([d if d is not None else {} for d in self.extras]) def __eq__(self, other): """Compare to another Annotations instance.""" @@ -410,7 +412,7 @@ def __iadd__(self, other): other.duration, other.description, other.ch_names, - other.details, + other.extras, ) def __iter__(self): @@ -421,7 +423,7 @@ def __iter__(self): for idx in range(len(self.onset)): yield self.__getitem__(idx, with_ch_names=with_ch_names) - def __getitem__(self, key, *, with_ch_names=None, with_details=True): + def __getitem__(self, key, *, with_ch_names=None, with_extras=True): """Propagate indexing and slicing to the underlying numpy structure.""" if isinstance(key, int_like): out_keys = ("onset", "duration", "description", "orig_time") @@ -434,9 +436,9 @@ def __getitem__(self, key, *, with_ch_names=None, with_details=True): if with_ch_names or (with_ch_names is None and self._any_ch_names()): out_keys += ("ch_names",) out_vals += (self.ch_names[key],) - if with_details: - out_keys += ("details",) - out_vals += (self.details[key],) + if with_extras: + out_keys += ("extras",) + out_vals += (self.extras[key],) return OrderedDict(zip(out_keys, out_vals)) else: key = list(key) if isinstance(key, tuple) else key @@ -446,11 +448,11 @@ def __getitem__(self, key, *, with_ch_names=None, with_details=True): description=self.description[key], orig_time=self.orig_time, ch_names=self.ch_names[key], - details=[self.details[i] for i in np.arange(len(self.details))[key]], + extras=[self.extras[i] for i in np.arange(len(self.extras))[key]], ) @fill_doc - def append(self, onset, duration, description, ch_names=None, details=None): + def append(self, onset, duration, description, ch_names=None, extras=None): """Add an annotated segment. Operates inplace. Parameters @@ -466,8 +468,8 @@ def append(self, onset, duration, description, ch_names=None, details=None): %(ch_names_annot)s .. versionadded:: 0.23 - details : list[dict | None] | None - Optional list of dicts containing additional details for each annotation. + extras : list[dict | None] | None + Optional list of dicts containing extras fields for each annotation. The number of items must match the number of annotations. .. versionadded:: 1.10.0 @@ -483,14 +485,14 @@ def append(self, onset, duration, description, ch_names=None, details=None): to not only ``list.append``, but also `list.extend `__. """ # noqa: E501 - onset, duration, description, ch_names, details = _check_o_d_s_c_d( - onset, duration, description, ch_names, details + onset, duration, description, ch_names, extras = _check_o_d_s_c_e( + onset, duration, description, ch_names, extras ) self.onset = np.append(self.onset, onset) self.duration = np.append(self.duration, duration) self.description = np.append(self.description, description) self.ch_names = np.append(self.ch_names, ch_names) - self.details.extend(details) + self.extras.extend(extras) self._sort() return self @@ -518,10 +520,10 @@ def delete(self, idx): self.description = np.delete(self.description, idx) self.ch_names = np.delete(self.ch_names, idx) if isinstance(idx, int_like): - del self.details[idx] + del self.extras[idx] elif len(idx) > 0: - for i in np.sort(np.arange(len(self.details))[idx])[::-1]: - del self.details[i] + for i in np.sort(np.arange(len(self.extras))[idx])[::-1]: + del self.extras[i] @fill_doc def to_data_frame(self, time_format="datetime"): @@ -552,7 +554,7 @@ def to_data_frame(self, time_format="datetime"): if self._any_ch_names(): df.update(ch_names=self.ch_names) df = pd.DataFrame(df) - df = pd.concat([df, self.details_data_frame], axis=1) + df = pd.concat([df, self.extras_data_frame], axis=1) return df def count(self): @@ -654,7 +656,7 @@ def _sort(self): self.duration = self.duration[order] self.description = self.description[order] self.ch_names = self.ch_names[order] - self.details = [self.details[i] for i in order] + self.extras = [self.extras[i] for i in order] @verbose def crop( @@ -707,12 +709,10 @@ def crop( ) logger.debug(f"Cropping annotations {absolute_tmin} - {absolute_tmax}") - onsets, durations, descriptions, ch_names, details = [], [], [], [], [] + onsets, durations, descriptions, ch_names, extras = [], [], [], [], [] out_of_bounds, clip_left_elem, clip_right_elem = [], [], [] - for idx, (onset, duration, description, ch, detail) in enumerate( - zip( - self.onset, self.duration, self.description, self.ch_names, self.details - ) + for idx, (onset, duration, description, ch, extra) in enumerate( + zip(self.onset, self.duration, self.description, self.ch_names, self.extras) ): # if duration is NaN behave like a zero if np.isnan(duration): @@ -750,14 +750,14 @@ def crop( ) descriptions.append(description) ch_names.append(ch) - details.append(detail) + extras.append(extra) logger.debug(f"Cropping complete (kept {len(onsets)})") self.onset = np.array(onsets, float) self.duration = np.array(durations, float) assert (self.duration >= 0).all() self.description = np.array(descriptions, dtype=str) self.ch_names = _ndarray_ch_names(ch_names) - self.details = details + self.extras = extras if emit_warning: omitted = np.array(out_of_bounds).sum() @@ -984,7 +984,7 @@ def get_annotations_per_epoch(self): this_annot["onset"] - this_tzero, this_annot["duration"], this_annot["description"], - this_annot["details"], + this_annot["extras"], ) # ...then add it to the correct sublist of `epoch_annot_list` epoch_annot_list[epo_ix].append(annot) @@ -1050,7 +1050,7 @@ def add_annotations_to_metadata(self, overwrite=False): # onsets, durations, and descriptions epoch_annot_list = self.get_annotations_per_epoch() onset, duration, description = [], [], [] - details = {k: [] for k in self.annotations.details_columns} + extras = {k: [] for k in self.annotations.extras_columns} for epoch_annot in epoch_annot_list: for ix, annot_prop in enumerate((onset, duration, description)): entry = [annot[ix] for annot in epoch_annot] @@ -1060,19 +1060,19 @@ def add_annotations_to_metadata(self, overwrite=False): entry = np.round(entry, decimals=12).tolist() annot_prop.append(entry) - for k in details.keys(): + for k in extras.keys(): entry = [ None if annot[3] is None else annot[3].get(k, None) for annot in epoch_annot ] - details[k].append(entry) + extras[k].append(entry) # Create a new Annotations column that is instantiated as an empty # list per Epoch. metadata["annot_onset"] = pd.Series(onset) metadata["annot_duration"] = pd.Series(duration) metadata["annot_description"] = pd.Series(description) - for k, v in details.items(): + for k, v in extras.items(): metadata[f"annot_{k}"] = pd.Series(v) # reset the metadata @@ -1202,8 +1202,8 @@ def _write_annotations(fid, annotations): write_string( fid, FIFF.FIFF_MNE_EPOCHS_DROP_LOG, json.dumps(tuple(annotations.ch_names)) ) - if any(d is not None for d in annotations.details): - write_string(fid, FIFF.FIFF_FREE_LIST, json.dumps(annotations.details)) + if any(d is not None for d in annotations.extras): + write_string(fid, FIFF.FIFF_FREE_LIST, json.dumps(annotations.extras)) end_block(fid, FIFF.FIFFB_MNE_ANNOTATIONS) @@ -1453,7 +1453,7 @@ def _read_annotations_txt(fname): warnings.simplefilter("ignore") out = np.loadtxt(fname, delimiter=",", dtype=np.bytes_, unpack=True) orig_time, columns = _read_annotations_txt_parse_header(fname) - ch_names = details = None + ch_names = extras = None if len(out) == 0: onset, duration, desc = [], [], [] else: @@ -1473,7 +1473,7 @@ def _read_annotations_txt(fname): ch_names = out[i_col] i_col += 1 if len(columns) > i_col: - details = [ + extras = [ {columns[j_col]: out[j_col][i] for j_col in range(i_col, len(columns))} for i in range(len(onset)) ] @@ -1493,7 +1493,7 @@ def _read_annotations_txt(fname): description=desc, orig_time=orig_time, ch_names=ch_names, - details=details, + extras=extras, ) return annotations @@ -1506,7 +1506,7 @@ def _read_annotations_fif(fid, tree): annotations = None else: annot_data = annot_data[0] - orig_time = ch_names = details = None + orig_time = ch_names = extras = None onset, duration, description = list(), list(), list() for ent in annot_data["directory"]: kind = ent.kind @@ -1529,12 +1529,12 @@ def _read_annotations_fif(fid, tree): elif kind == FIFF.FIFF_MNE_EPOCHS_DROP_LOG: ch_names = tuple(tuple(x) for x in json.loads(tag.data)) elif kind == FIFF.FIFF_FREE_LIST: - details = json.loads(tag.data) + extras = json.loads(tag.data) assert len(onset) == len(duration) == len(description) - if details is not None: - assert len(details) == len(onset) + if extras is not None: + assert len(extras) == len(onset) annotations = Annotations( - onset, duration, description, orig_time, ch_names, details + onset, duration, description, orig_time, ch_names, extras ) return annotations diff --git a/mne/epochs.py b/mne/epochs.py index b55f6f3523e..9bb304e66fe 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -3577,16 +3577,16 @@ def __init__( raw, events, event_id, annotations, on_missing ) - # add the annotations.details to the metadata - if not all(d is None for d in annotations.details): + # add the annotations.extras to the metadata + if not all(d is None for d in annotations.extras): if metadata is None: - metadata = annotations.details_data_frame + metadata = annotations.extras_data_frame else: pd = _check_pandas_installed(strict=True) - details_df = annotations.details_data_frame - details_df.set_index(metadata.index, inplace=True) + extras_df = annotations.extras_data_frame + extras_df.set_index(metadata.index, inplace=True) metadata = pd.concat( - [metadata, details_df], axis=1, ignore_index=False + [metadata, extras_df], axis=1, ignore_index=False ) # call BaseEpochs constructor diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index c2a9a69424b..c3d003c8803 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -630,12 +630,12 @@ def test_annotation_epoching(): assert_equal([0, 2, 4], epochs.selection) -@pytest.mark.parametrize("with_details", [True, False]) -def test_annotation_concat(with_details): +@pytest.mark.parametrize("with_extras", [True, False]) +def test_annotation_concat(with_extras): """Test if two Annotations objects can be concatenated.""" - details = None - if with_details: - details = [ + extras = None + if with_extras: + extras = [ {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None}, None, None, @@ -646,7 +646,7 @@ def test_annotation_concat(with_details): [1, 2, 2], ["x", "y", "z"], ch_names=[[], ["3"], []], - details=details, + extras=extras, ) # test + operator (does not modify a or b) @@ -670,9 +670,9 @@ def test_annotation_concat(with_details): assert_equal(len(a), 6) assert_equal(len(b), 3) - if with_details: - all_details = [None] * 3 + details - assert all(c.details[i] == all_details[i] for i in range(len(all_details))) + if with_extras: + all_extras = [None] * 3 + extras + assert all(c.extras[i] == all_extras[i] for i in range(len(all_extras))) # test += operator (modifies a in place) b._orig_time = _handle_meas_date(1038942070.7201) @@ -981,10 +981,10 @@ def _assert_annotations_equal(a, b, tol=0): _ORIG_TIME = datetime.fromtimestamp(1038942071.7201, timezone.utc) -@pytest.fixture(scope="function", params=("ch_names", "fmt", "with_details")) -def dummy_annotation_file(tmp_path_factory, ch_names, fmt, with_details): +@pytest.fixture(scope="function", params=("ch_names", "fmt", "with_extras")) +def dummy_annotation_file(tmp_path_factory, ch_names, fmt, with_extras): """Create csv file for testing.""" - details_row0 = {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None} + extras_row0 = {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None} if fmt == "csv": content = ( "onset,duration,description\n" @@ -1001,9 +1001,9 @@ def dummy_annotation_file(tmp_path_factory, ch_names, fmt, with_details): ) else: assert fmt == "fif" - details = [details_row0, None] if with_details else None + extras = [extras_row0, None] if with_extras else None content = Annotations( - [0, 9], [1, 2.425], ["AA", "BB"], orig_time=_ORIG_TIME, details=details + [0, 9], [1, 2.425], ["AA", "BB"], orig_time=_ORIG_TIME, extras=extras ) if ch_names: @@ -1017,11 +1017,11 @@ def dummy_annotation_file(tmp_path_factory, ch_names, fmt, with_details): content[-1] += ",MEG0111:MEG2563" content = "\n".join(content) - if with_details and not isinstance(content, Annotations): + if with_extras and not isinstance(content, Annotations): content = content.splitlines() - content[-3] += "," + ",".join(details_row0.keys()) - content[-2] += "," + ",".join([str(v) for v in details_row0.values()]) - content[-1] += "," * len(details_row0) + content[-3] += "," + ",".join(extras_row0.keys()) + content[-2] += "," + ",".join([str(v) for v in extras_row0.values()]) + content[-1] += "," * len(extras_row0) content = "\n".join(content) fname = tmp_path_factory.mktemp("data") / f"annotations-annot.{fmt}" @@ -1035,8 +1035,8 @@ def dummy_annotation_file(tmp_path_factory, ch_names, fmt, with_details): @pytest.mark.parametrize("ch_names", (False, True)) @pytest.mark.parametrize("fmt", [pytest.param("csv", marks=needs_pandas), "txt", "fif"]) -@pytest.mark.parametrize("with_details", [True, False]) -def test_io_annotation(dummy_annotation_file, tmp_path, fmt, ch_names, with_details): +@pytest.mark.parametrize("with_extras", [True, False]) +def test_io_annotation(dummy_annotation_file, tmp_path, fmt, ch_names, with_extras): """Test CSV, TXT, and FIF input/output (which support ch_names).""" annot = read_annotations(dummy_annotation_file) assert annot.orig_time == _ORIG_TIME @@ -1208,34 +1208,34 @@ def test_annotations_slices(): NUM_ANNOT = 5 EXPECTED_ONSETS = EXPECTED_DURATIONS = [x for x in range(NUM_ANNOT)] EXPECTED_DESCS = [x.__repr__() for x in range(NUM_ANNOT)] - DETAILS_ROW = {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None} - EXPECTED_DETAILS = [DETAILS_ROW] * NUM_ANNOT + EXTRAS_ROW = {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None} + EXPECTED_EXTRAS = [EXTRAS_ROW] * NUM_ANNOT annot = Annotations( onset=EXPECTED_ONSETS, duration=EXPECTED_DURATIONS, description=EXPECTED_DESCS, orig_time=None, - details=EXPECTED_DETAILS, + extras=EXPECTED_EXTRAS, ) # Indexing returns a copy. So this has no effect in annot annot[0]["onset"] = 42 annot[0]["duration"] = 3.14 annot[0]["description"] = "foobar" - annot[0]["details"] = DETAILS_ROW + annot[0]["extras"] = EXTRAS_ROW annot[:1].onset[0] = 42 annot[:1].duration[0] = 3.14 annot[:1].description[0] = "foobar" - annot[:1].details[0] = DETAILS_ROW + annot[:1].extras[0] = EXTRAS_ROW # Slicing with single element returns a dictionary for ii in EXPECTED_ONSETS: assert annot[ii] == dict( zip( - ["onset", "duration", "description", "orig_time", "details"], - [ii, ii, str(ii), None, DETAILS_ROW], + ["onset", "duration", "description", "orig_time", "extras"], + [ii, ii, str(ii), None, EXTRAS_ROW], ) ) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 29ee5ace846..947b84daec9 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -4917,7 +4917,7 @@ def test_add_channels_picks(): @pytest.mark.parametrize("first_samp", [0, 10]) @pytest.mark.parametrize( - "meas_date, orig_date, with_details", + "meas_date, orig_date, with_extras", [ [None, None, False], [np.pi, None, False], @@ -4925,7 +4925,7 @@ def test_add_channels_picks(): [None, None, True], ], ) -def test_epoch_annotations(first_samp, meas_date, orig_date, with_details, tmp_path): +def test_epoch_annotations(first_samp, meas_date, orig_date, with_extras, tmp_path): """Test Epoch Annotations from RawArray with dates. Tests the following cases crossed with each other: @@ -4948,14 +4948,14 @@ def test_epoch_annotations(first_samp, meas_date, orig_date, with_details, tmp_p if orig_date is not None: orig_date = meas_date + orig_date ant_dur = 0.1 - details_row0 = {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None} - details = [details_row0, None, None] if with_details else None + extras_row0 = {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None} + extras = [extras_row0, None, None] if with_extras else None ants = Annotations( onset=[1.1, 1.2, 2.1], duration=[ant_dur, ant_dur, ant_dur], description=["x", "y", "z"], orig_time=orig_date, - details=details, + extras=extras, ) raw.set_annotations(ants) epochs = make_fixed_length_epochs(raw, duration=1, overlap=0.5) @@ -4966,8 +4966,8 @@ def test_epoch_annotations(first_samp, meas_date, orig_date, with_details, tmp_p assert "annot_onset" in metadata.columns assert "annot_duration" in metadata.columns assert "annot_description" in metadata.columns - if with_details: - assert all(f"annot_{k}" in metadata.columns for k in details_row0.keys()) + if with_extras: + assert all(f"annot_{k}" in metadata.columns for k in extras_row0.keys()) # Test that writing and reading back these new metadata works temp_fname = tmp_path / "test-epo.fif" From 09bb125d9a9c92b4de128c33966fc75e775dd593 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 26 Apr 2025 10:36:25 +0200 Subject: [PATCH 14/57] Fix _AnnotationsExtrasDict and add test --- mne/annotations.py | 4 ++-- mne/tests/test_annotations.py | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 5443423a2b5..1e3b61a6f82 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -5,7 +5,7 @@ import json import re import warnings -from collections import Counter, OrderedDict +from collections import Counter, OrderedDict, UserDict from collections.abc import Iterable from copy import deepcopy from datetime import datetime, timedelta, timezone @@ -58,7 +58,7 @@ _datetime = datetime -class _AnnotationsExtrasDict(dict): +class _AnnotationsExtrasDict(UserDict): """A dictionary for storing extra fields of annotations. The keys of the dictionary are strings, and the values can be diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index c3d003c8803..782f561a79a 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -32,6 +32,7 @@ _handle_meas_date, _read_annotations_txt_parse_header, _sync_onset, + _AnnotationsExtrasDict, ) from mne.datasets import testing from mne.io import RawArray, concatenate_raws, read_raw_fif @@ -1860,3 +1861,22 @@ def test_append_splits_boundary(tmp_path, split_size): assert len(raw.annotations) == 2 assert raw.annotations.description[0] == "BAD boundary" assert_allclose(raw.annotations.onset, [onset] * 2) + + +@pytest.mark.parametrize( + "key, value, expected_error", + ( + ("onset", 1, ValueError), # Reserved key + ("duration", 1, ValueError), # Reserved key + ("description", 1, ValueError), # Reserved key + ("ch_names", 1, ValueError), # Reserved key + ("valid_key", [], TypeError), # Invalid value type + (1, 1, TypeError), # Invalid key type + ), +) +def test_extras_dict_raises(key, value, expected_error): + extras_dict = _AnnotationsExtrasDict() + with pytest.raises(expected_error): + extras_dict[key] = value + with pytest.raises(expected_error): + extras_dict.update({key: value}) From 32641c0660299cc18054a460409af94b86b6bd28 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 26 Apr 2025 11:37:32 +0200 Subject: [PATCH 15/57] fix writers --- mne/annotations.py | 28 ++++++++++++++++++++++++++-- mne/tests/test_annotations.py | 20 +++++++++++++------- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 1e3b61a6f82..bee122b1854 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -353,7 +353,7 @@ def extras_columns(self) -> set[str]: def extras_data_frame(self): """The extras of the Annotations as a DataFrame.""" pd = _check_pandas_installed(strict=True) - return pd.DataFrame([d if d is not None else {} for d in self.extras]) + return pd.DataFrame([d or {} for d in self.extras]) def __eq__(self, other): """Compare to another Annotations instance.""" @@ -1203,11 +1203,22 @@ def _write_annotations(fid, annotations): fid, FIFF.FIFF_MNE_EPOCHS_DROP_LOG, json.dumps(tuple(annotations.ch_names)) ) if any(d is not None for d in annotations.extras): - write_string(fid, FIFF.FIFF_FREE_LIST, json.dumps(annotations.extras)) + write_string( + fid, + FIFF.FIFF_FREE_LIST, + json.dumps( + [None if extra is None else extra.data for extra in annotations.extras] + ), + ) end_block(fid, FIFF.FIFFB_MNE_ANNOTATIONS) def _write_annotations_csv(fname, annot): + if len(annot.extras_columns) > 0: + warn( + "Reading extra annotation fields from CSV is not supported. " + "The extra fields will be written but not loaded when reading." + ) annot = annot.to_data_frame() if "ch_names" in annot: annot["ch_names"] = [ @@ -1232,6 +1243,19 @@ def _write_annotations_txt(fname, annot): for ci, ch in enumerate(annot.ch_names) ] ) + if len(extras_columns := annot.extras_columns) > 0: + warn( + "Reading extra annotation fields from TXT is not supported. " + "The extra fields will be written but not loaded when reading." + ) + for column in extras_columns: + content += f", {column}" + data.append( + [ + None if extra is None else extra.get(column, None) + for extra in annot.extras + ] + ) content += "\n" data = np.array(data, dtype=str).T assert data.ndim == 2 diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 782f561a79a..56989019b09 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -977,6 +977,12 @@ def _assert_annotations_equal(a, b, tol=0): a_orig_time = a.orig_time b_orig_time = b.orig_time assert a_orig_time == b_orig_time, "orig_time" + extras_columns = a.extras_columns.union(b.extras_columns) + for col in extras_columns: + for i, extra in enumerate(a.extras): + assert (extra or {}).get(col, None) == (b.extras[i] or {}).get( + col, None + ), f"extras {col} {i}" _ORIG_TIME = datetime.fromtimestamp(1038942071.7201, timezone.utc) @@ -985,6 +991,8 @@ def _assert_annotations_equal(a, b, tol=0): @pytest.fixture(scope="function", params=("ch_names", "fmt", "with_extras")) def dummy_annotation_file(tmp_path_factory, ch_names, fmt, with_extras): """Create csv file for testing.""" + if with_extras and fmt!= "fif": + pytest.skip("Extras fields io are only supported in FIF format.") extras_row0 = {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None} if fmt == "csv": content = ( @@ -1018,13 +1026,6 @@ def dummy_annotation_file(tmp_path_factory, ch_names, fmt, with_extras): content[-1] += ",MEG0111:MEG2563" content = "\n".join(content) - if with_extras and not isinstance(content, Annotations): - content = content.splitlines() - content[-3] += "," + ",".join(extras_row0.keys()) - content[-2] += "," + ",".join([str(v) for v in extras_row0.values()]) - content[-1] += "," * len(extras_row0) - content = "\n".join(content) - fname = tmp_path_factory.mktemp("data") / f"annotations-annot.{fmt}" if isinstance(content, str): with open(fname, "w") as f: @@ -1044,6 +1045,11 @@ def test_io_annotation(dummy_annotation_file, tmp_path, fmt, ch_names, with_extr kwargs = dict(orig_time=_ORIG_TIME) if ch_names: kwargs["ch_names"] = ((), ("MEG0111", "MEG2563")) + if with_extras: + kwargs["extras"] = [ + {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None}, + None, + ] _assert_annotations_equal( annot, Annotations([0.0, 9.0], [1.0, 2.425], ["AA", "BB"], **kwargs), tol=1e-6 ) From 76da62a222b8089ac1e23ffe20a35974a061fec9 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 26 Apr 2025 11:41:39 +0200 Subject: [PATCH 16/57] Improve type description in docstrings --- mne/annotations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index bee122b1854..b1e25ef5af5 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -188,7 +188,7 @@ class Annotations: %(ch_names_annot)s .. versionadded:: 0.23 - extras : list[dict | None] | None + extras : list[dict[str, int | float | str | None] | None] | None Optional list fo dicts containing extra fields for each annotation. The number of items must match the number of annotations. @@ -468,7 +468,7 @@ def append(self, onset, duration, description, ch_names=None, extras=None): %(ch_names_annot)s .. versionadded:: 0.23 - extras : list[dict | None] | None + extras : list[dict[str, int | float | str | None] | None] | None Optional list of dicts containing extras fields for each annotation. The number of items must match the number of annotations. From 03e1b5a189a602e77c7efa440e33b917209a830b Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 26 Apr 2025 11:54:11 +0200 Subject: [PATCH 17/57] only have a list of dict internally (no None) --- mne/annotations.py | 50 ++++++++++++----------------------- mne/tests/test_annotations.py | 4 +-- 2 files changed, 19 insertions(+), 35 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index b1e25ef5af5..6831f8101e4 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -77,24 +77,18 @@ def __setitem__(self, key: str, value: str | int | float | None) -> None: def _validate_extras(extras, length: int): _validate_type(extras, (None, list), "extras") - if extras is None: - return [None] * length - if len(extras) != length: - raise ValueError( - f"extras must be None or a list of length {length}, got {len(extras)}." - ) - for i, d in enumerate(extras): - _validate_type( - d, (dict, _AnnotationsExtrasDict, None), f"extras[{i}]", "dict or None" - ) - out = [] - for d in extras: - if d is None: - out.append(None) - else: - dd = _AnnotationsExtrasDict() - dd.update(d) - out.append(dd) + out = [_AnnotationsExtrasDict() for _ in range(length)] + if extras is not None: + if len(extras) != length: + raise ValueError( + f"extras must be None or a list of length {length}, got {len(extras)}." + ) + for i, (d, new_d) in enumerate(zip(extras, out)): + _validate_type( + d, (dict, _AnnotationsExtrasDict, None), f"extras[{i}]", "dict or None" + ) + if d is not None: + new_d.update(d) return out @@ -347,13 +341,13 @@ def extras(self, extras): @property def extras_columns(self) -> set[str]: """The set containing all the keys in all extras dicts.""" - return {k for d in self.extras if d is not None for k in d.keys()} + return {k for d in self.extras for k in d.keys()} @property def extras_data_frame(self): """The extras of the Annotations as a DataFrame.""" pd = _check_pandas_installed(strict=True) - return pd.DataFrame([d or {} for d in self.extras]) + return pd.DataFrame(self.extras) def __eq__(self, other): """Compare to another Annotations instance.""" @@ -1061,10 +1055,7 @@ def add_annotations_to_metadata(self, overwrite=False): annot_prop.append(entry) for k in extras.keys(): - entry = [ - None if annot[3] is None else annot[3].get(k, None) - for annot in epoch_annot - ] + entry = [annot[3].get(k, None) for annot in epoch_annot] extras[k].append(entry) # Create a new Annotations column that is instantiated as an empty @@ -1206,9 +1197,7 @@ def _write_annotations(fid, annotations): write_string( fid, FIFF.FIFF_FREE_LIST, - json.dumps( - [None if extra is None else extra.data for extra in annotations.extras] - ), + json.dumps([extra.data for extra in annotations.extras]), ) end_block(fid, FIFF.FIFFB_MNE_ANNOTATIONS) @@ -1250,12 +1239,7 @@ def _write_annotations_txt(fname, annot): ) for column in extras_columns: content += f", {column}" - data.append( - [ - None if extra is None else extra.get(column, None) - for extra in annot.extras - ] - ) + data.append([extra.get(column, None) for extra in annot.extras]) content += "\n" data = np.array(data, dtype=str).T assert data.ndim == 2 diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 56989019b09..749b2cee926 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -980,7 +980,7 @@ def _assert_annotations_equal(a, b, tol=0): extras_columns = a.extras_columns.union(b.extras_columns) for col in extras_columns: for i, extra in enumerate(a.extras): - assert (extra or {}).get(col, None) == (b.extras[i] or {}).get( + assert extra.get(col, None) == b.extras[i].get( col, None ), f"extras {col} {i}" @@ -991,7 +991,7 @@ def _assert_annotations_equal(a, b, tol=0): @pytest.fixture(scope="function", params=("ch_names", "fmt", "with_extras")) def dummy_annotation_file(tmp_path_factory, ch_names, fmt, with_extras): """Create csv file for testing.""" - if with_extras and fmt!= "fif": + if with_extras and fmt != "fif": pytest.skip("Extras fields io are only supported in FIF format.") extras_row0 = {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None} if fmt == "csv": From dfa428f920a88d51e54dc7e047c599d017af5860 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 26 Apr 2025 09:57:33 +0000 Subject: [PATCH 18/57] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/tests/test_annotations.py | 8 ++++---- mne/utils/check.py | 8 ++++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 749b2cee926..078661a14c4 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -29,10 +29,10 @@ read_annotations, ) from mne.annotations import ( + _AnnotationsExtrasDict, _handle_meas_date, _read_annotations_txt_parse_header, _sync_onset, - _AnnotationsExtrasDict, ) from mne.datasets import testing from mne.io import RawArray, concatenate_raws, read_raw_fif @@ -980,9 +980,9 @@ def _assert_annotations_equal(a, b, tol=0): extras_columns = a.extras_columns.union(b.extras_columns) for col in extras_columns: for i, extra in enumerate(a.extras): - assert extra.get(col, None) == b.extras[i].get( - col, None - ), f"extras {col} {i}" + assert extra.get(col, None) == b.extras[i].get(col, None), ( + f"extras {col} {i}" + ) _ORIG_TIME = datetime.fromtimestamp(1038942071.7201, timezone.utc) diff --git a/mne/utils/check.py b/mne/utils/check.py index cd25577dac3..550903b6641 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -612,7 +612,9 @@ def _validate_type(item, types=None, item_name=None, type_name=None, *, extra="" ( (type(None),) if type_ is None - else (type_,) if not isinstance(type_, str) else _multi[type_] + else (type_,) + if not isinstance(type_, str) + else _multi[type_] ) for type_ in types ), @@ -625,7 +627,9 @@ def _validate_type(item, types=None, item_name=None, type_name=None, *, extra="" ( "None" if cls_ is None - else cls_.__name__ if not isinstance(cls_, str) else cls_ + else cls_.__name__ + if not isinstance(cls_, str) + else cls_ ) for cls_ in types ] From 7730c2f789ce044e16ff4cd39bfcf2c79a38f9e3 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 26 Apr 2025 11:59:11 +0200 Subject: [PATCH 19/57] Update test --- mne/tests/test_annotations.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 749b2cee926..1b3b24142cc 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -673,7 +673,14 @@ def test_annotation_concat(with_extras): if with_extras: all_extras = [None] * 3 + extras - assert all(c.extras[i] == all_extras[i] for i in range(len(all_extras))) + assert all( + ( + c.extras[i] == all_extras[i] + if all_extras[i] is not None + else len(c.extras[i]) == 0 + ) + for i in range(len(all_extras)) + ) # test += operator (modifies a in place) b._orig_time = _handle_meas_date(1038942070.7201) From 9b43610b63d65bbade7b7029984c2a7f88ab9bdb Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 26 Apr 2025 12:24:10 +0200 Subject: [PATCH 20/57] Add missing docstring --- mne/tests/test_annotations.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 8ceef621997..b64816e5f67 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -987,9 +987,9 @@ def _assert_annotations_equal(a, b, tol=0): extras_columns = a.extras_columns.union(b.extras_columns) for col in extras_columns: for i, extra in enumerate(a.extras): - assert extra.get(col, None) == b.extras[i].get(col, None), ( - f"extras {col} {i}" - ) + assert extra.get(col, None) == b.extras[i].get( + col, None + ), f"extras {col} {i}" _ORIG_TIME = datetime.fromtimestamp(1038942071.7201, timezone.utc) @@ -1888,6 +1888,7 @@ def test_append_splits_boundary(tmp_path, split_size): ), ) def test_extras_dict_raises(key, value, expected_error): + """Test that _AnnotationsExtrasDict raises errors for invalid keys/values.""" extras_dict = _AnnotationsExtrasDict() with pytest.raises(expected_error): extras_dict[key] = value From dceb48988544fdae2ea66e98f32975b83f69b9a0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 26 Apr 2025 10:42:52 +0000 Subject: [PATCH 21/57] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/tests/test_annotations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index b64816e5f67..6bfe425ca14 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -987,9 +987,9 @@ def _assert_annotations_equal(a, b, tol=0): extras_columns = a.extras_columns.union(b.extras_columns) for col in extras_columns: for i, extra in enumerate(a.extras): - assert extra.get(col, None) == b.extras[i].get( - col, None - ), f"extras {col} {i}" + assert extra.get(col, None) == b.extras[i].get(col, None), ( + f"extras {col} {i}" + ) _ORIG_TIME = datetime.fromtimestamp(1038942071.7201, timezone.utc) From b2c89e996870670dd8338799886d4dabf152acc1 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel <25532709+PierreGtch@users.noreply.github.com> Date: Sat, 26 Apr 2025 16:15:01 +0200 Subject: [PATCH 22/57] Apply review suggestion Co-authored-by: Alexandre Gramfort --- doc/changes/devel/13228.newfeature.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/changes/devel/13228.newfeature.rst b/doc/changes/devel/13228.newfeature.rst index 2681654210d..7cd01dffa8e 100644 --- a/doc/changes/devel/13228.newfeature.rst +++ b/doc/changes/devel/13228.newfeature.rst @@ -1 +1 @@ -Add a ``details`` attribute to :class:`mne.Annotations`, by `Pierre Guetschel`_. \ No newline at end of file +Add an ``extras`` attribute to :class:`mne.Annotations`, by `Pierre Guetschel`_. \ No newline at end of file From 3f2003ad2ed6c4e9862dd62c4e8f182cabe3de6a Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 26 Apr 2025 18:37:17 +0200 Subject: [PATCH 23/57] Fix tests --- mne/annotations.py | 15 +++++++++++---- mne/epochs.py | 2 +- mne/tests/test_epochs.py | 14 +++++++------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 6831f8101e4..d6ceade6a88 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -908,7 +908,7 @@ def set_annotations(self, annotations, on_missing="raise", *, verbose=None): self._annotations = new_annotations return self - def get_annotations_per_epoch(self): + def get_annotations_per_epoch(self, with_extras=False): """Get a list of annotations that occur during each epoch. Returns @@ -920,6 +920,9 @@ def get_annotations_per_epoch(self): duration, description (not as a :class:`~mne.Annotations` object), where the onset is now relative to time=0 of the epoch, rather than time=0 of the original continuous (raw) data. + with_extras : bool + Whether to include the annotations extra fields in the output, + as an additional last element of the tuple. Default is False. """ # create a list of annotations for each epoch epoch_annot_list = [[] for _ in range(len(self.events))] @@ -978,13 +981,14 @@ def get_annotations_per_epoch(self): this_annot["onset"] - this_tzero, this_annot["duration"], this_annot["description"], - this_annot["extras"], ) + if with_extras: + annot += (this_annot["extras"],) # ...then add it to the correct sublist of `epoch_annot_list` epoch_annot_list[epo_ix].append(annot) return epoch_annot_list - def add_annotations_to_metadata(self, overwrite=False): + def add_annotations_to_metadata(self, overwrite=False, with_extras=True): """Add raw annotations into the Epochs metadata data frame. Adds three columns to the ``metadata`` consisting of a list @@ -1001,6 +1005,9 @@ def add_annotations_to_metadata(self, overwrite=False): overwrite : bool Whether to overwrite existing columns in metadata or not. Default is False. + with_extras : bool + Whether to include the annotations extra fields in the output, + as an additional last element of the tuple. Default is True. Returns ------- @@ -1042,7 +1049,7 @@ def add_annotations_to_metadata(self, overwrite=False): # get the Epoch annotations, then convert to separate lists for # onsets, durations, and descriptions - epoch_annot_list = self.get_annotations_per_epoch() + epoch_annot_list = self.get_annotations_per_epoch(with_extras=with_extras) onset, duration, description = [], [], [] extras = {k: [] for k in self.annotations.extras_columns} for epoch_annot in epoch_annot_list: diff --git a/mne/epochs.py b/mne/epochs.py index 9bb304e66fe..605163e3656 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -3578,7 +3578,7 @@ def __init__( ) # add the annotations.extras to the metadata - if not all(d is None for d in annotations.extras): + if not all(len(d) == 0 for d in annotations.extras): if metadata is None: metadata = annotations.extras_data_frame else: diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 947b84daec9..f68fea86cb5 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -479,12 +479,12 @@ def test_average_movements(): def _assert_drop_log_types(drop_log): __tracebackhide__ = True assert isinstance(drop_log, tuple), "drop_log should be tuple" - assert all(isinstance(log, tuple) for log in drop_log), ( - "drop_log[ii] should be tuple" - ) - assert all(isinstance(s, str) for log in drop_log for s in log), ( - "drop_log[ii][jj] should be str" - ) + assert all( + isinstance(log, tuple) for log in drop_log + ), "drop_log[ii] should be tuple" + assert all( + isinstance(s, str) for log in drop_log for s in log + ), "drop_log[ii][jj] should be str" def test_reject(): @@ -4961,7 +4961,7 @@ def test_epoch_annotations(first_samp, meas_date, orig_date, with_extras, tmp_pa epochs = make_fixed_length_epochs(raw, duration=1, overlap=0.5) # add Annotations to Epochs metadata - epochs.add_annotations_to_metadata() + epochs.add_annotations_to_metadata(with_extras=with_extras) metadata = epochs.metadata assert "annot_onset" in metadata.columns assert "annot_duration" in metadata.columns From 1d2b30c4759002d3cbdeeadfcc8a48bf0ee11205 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 26 Apr 2025 16:37:37 +0000 Subject: [PATCH 24/57] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/tests/test_epochs.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index f68fea86cb5..fc1a95c9ba3 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -479,12 +479,12 @@ def test_average_movements(): def _assert_drop_log_types(drop_log): __tracebackhide__ = True assert isinstance(drop_log, tuple), "drop_log should be tuple" - assert all( - isinstance(log, tuple) for log in drop_log - ), "drop_log[ii] should be tuple" - assert all( - isinstance(s, str) for log in drop_log for s in log - ), "drop_log[ii][jj] should be str" + assert all(isinstance(log, tuple) for log in drop_log), ( + "drop_log[ii] should be tuple" + ) + assert all(isinstance(s, str) for log in drop_log for s in log), ( + "drop_log[ii][jj] should be str" + ) def test_reject(): From 1b2b05df110084ad0ac15034df6fdcbb68266f6d Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sun, 27 Apr 2025 11:06:43 +0200 Subject: [PATCH 25/57] Fix docstrings --- mne/annotations.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index d6ceade6a88..bfe7bb5ad3e 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -186,7 +186,7 @@ class Annotations: Optional list fo dicts containing extra fields for each annotation. The number of items must match the number of annotations. - .. versionadded:: 1.10.0 + .. versionadded:: 1.10 See Also -------- @@ -466,7 +466,7 @@ def append(self, onset, duration, description, ch_names=None, extras=None): Optional list of dicts containing extras fields for each annotation. The number of items must match the number of annotations. - .. versionadded:: 1.10.0 + .. versionadded:: 1.10 Returns ------- @@ -911,6 +911,14 @@ def set_annotations(self, annotations, on_missing="raise", *, verbose=None): def get_annotations_per_epoch(self, with_extras=False): """Get a list of annotations that occur during each epoch. + Parameters + ---------- + with_extras : bool + Whether to include the annotations extra fields in the output, + as an additional last element of the tuple. Default is False. + + .. versionadded:: 1.10 + Returns ------- epoch_annots : list @@ -920,9 +928,6 @@ def get_annotations_per_epoch(self, with_extras=False): duration, description (not as a :class:`~mne.Annotations` object), where the onset is now relative to time=0 of the epoch, rather than time=0 of the original continuous (raw) data. - with_extras : bool - Whether to include the annotations extra fields in the output, - as an additional last element of the tuple. Default is False. """ # create a list of annotations for each epoch epoch_annot_list = [[] for _ in range(len(self.events))] @@ -1009,6 +1014,8 @@ def add_annotations_to_metadata(self, overwrite=False, with_extras=True): Whether to include the annotations extra fields in the output, as an additional last element of the tuple. Default is True. + .. versionadded:: 1.10 + Returns ------- self : instance of Epochs From a94dc4f2cc0578e1b9cf666129725b2a598828df Mon Sep 17 00:00:00 2001 From: Pierre Guetschel <25532709+PierreGtch@users.noreply.github.com> Date: Thu, 1 May 2025 11:24:37 +0200 Subject: [PATCH 26/57] Apply suggestions from code review Co-authored-by: Daniel McCloy --- doc/changes/devel/13228.newfeature.rst | 2 +- mne/annotations.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/doc/changes/devel/13228.newfeature.rst b/doc/changes/devel/13228.newfeature.rst index 7cd01dffa8e..a242762d2f6 100644 --- a/doc/changes/devel/13228.newfeature.rst +++ b/doc/changes/devel/13228.newfeature.rst @@ -1 +1 @@ -Add an ``extras`` attribute to :class:`mne.Annotations`, by `Pierre Guetschel`_. \ No newline at end of file +Add an ``extras`` attribute to :class:`mne.Annotations` for storing arbitrary metadata, by `Pierre Guetschel`_. \ No newline at end of file diff --git a/mne/annotations.py b/mne/annotations.py index bfe7bb5ad3e..6cdd41a23c6 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -72,7 +72,7 @@ def __setitem__(self, key: str, value: str | int | float | None) -> None: _validate_type( value, (str, int, float, None), "value", "string, int, float or None" ) - return super().__setitem__(key, value) + super().__setitem__(key, value) def _validate_extras(extras, length: int): @@ -183,7 +183,7 @@ class Annotations: .. versionadded:: 0.23 extras : list[dict[str, int | float | str | None] | None] | None - Optional list fo dicts containing extra fields for each annotation. + Optional list of dicts containing extra fields for each annotation. The number of items must match the number of annotations. .. versionadded:: 1.10 @@ -516,6 +516,7 @@ def delete(self, idx): if isinstance(idx, int_like): del self.extras[idx] elif len(idx) > 0: + # convert slice-like idx to ints, and delete list items in reverse order for i in np.sort(np.arange(len(self.extras))[idx])[::-1]: del self.extras[i] From ab0f067109953c1297690546dd8551c8fafde97b Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 1 May 2025 11:53:42 +0200 Subject: [PATCH 27/57] Improve `test_extras_dict_raises` --- mne/tests/test_annotations.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 6bfe425ca14..d09760a19f3 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -1877,20 +1877,25 @@ def test_append_splits_boundary(tmp_path, split_size): @pytest.mark.parametrize( - "key, value, expected_error", + "key, value, expected_error, match", ( - ("onset", 1, ValueError), # Reserved key - ("duration", 1, ValueError), # Reserved key - ("description", 1, ValueError), # Reserved key - ("ch_names", 1, ValueError), # Reserved key - ("valid_key", [], TypeError), # Invalid value type - (1, 1, TypeError), # Invalid key type + ("onset", 1, ValueError, "reserved"), + ("duration", 1, ValueError, "reserved"), + ("description", 1, ValueError, "reserved"), + ("ch_names", 1, ValueError, "reserved"), + ("valid_key", [], TypeError, "value must be an instance of"), + (1, 1, TypeError, "key must be an instance of"), ), ) -def test_extras_dict_raises(key, value, expected_error): +def test_extras_dict_raises(key, value, expected_error, match): """Test that _AnnotationsExtrasDict raises errors for invalid keys/values.""" extras_dict = _AnnotationsExtrasDict() - with pytest.raises(expected_error): + with pytest.raises(expected_error, match=match): extras_dict[key] = value - with pytest.raises(expected_error): + with pytest.raises(expected_error, match=match): extras_dict.update({key: value}) + with pytest.raises(expected_error, match=match): + _AnnotationsExtrasDict({key: value}) + if isinstance(key, str): + with pytest.raises(expected_error, match=match): + _AnnotationsExtrasDict(**{key: value}) From af8add630f6a63e05d45a2c312987d19c2b2fe78 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 1 May 2025 12:06:48 +0200 Subject: [PATCH 28/57] Make extras_columns private --- mne/annotations.py | 8 ++++---- mne/tests/test_annotations.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 6cdd41a23c6..19ac5fc21a1 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -339,7 +339,7 @@ def extras(self, extras): self._extras = _validate_extras(extras, len(self.onset)) @property - def extras_columns(self) -> set[str]: + def _extras_columns(self) -> set[str]: """The set containing all the keys in all extras dicts.""" return {k for d in self.extras for k in d.keys()} @@ -1059,7 +1059,7 @@ def add_annotations_to_metadata(self, overwrite=False, with_extras=True): # onsets, durations, and descriptions epoch_annot_list = self.get_annotations_per_epoch(with_extras=with_extras) onset, duration, description = [], [], [] - extras = {k: [] for k in self.annotations.extras_columns} + extras = {k: [] for k in self.annotations._extras_columns} for epoch_annot in epoch_annot_list: for ix, annot_prop in enumerate((onset, duration, description)): entry = [annot[ix] for annot in epoch_annot] @@ -1218,7 +1218,7 @@ def _write_annotations(fid, annotations): def _write_annotations_csv(fname, annot): - if len(annot.extras_columns) > 0: + if len(annot._extras_columns) > 0: warn( "Reading extra annotation fields from CSV is not supported. " "The extra fields will be written but not loaded when reading." @@ -1247,7 +1247,7 @@ def _write_annotations_txt(fname, annot): for ci, ch in enumerate(annot.ch_names) ] ) - if len(extras_columns := annot.extras_columns) > 0: + if len(extras_columns := annot._extras_columns) > 0: warn( "Reading extra annotation fields from TXT is not supported. " "The extra fields will be written but not loaded when reading." diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index d09760a19f3..bdc96294e59 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -984,7 +984,7 @@ def _assert_annotations_equal(a, b, tol=0): a_orig_time = a.orig_time b_orig_time = b.orig_time assert a_orig_time == b_orig_time, "orig_time" - extras_columns = a.extras_columns.union(b.extras_columns) + extras_columns = a._extras_columns.union(b._extras_columns) for col in extras_columns: for i, extra in enumerate(a.extras): assert extra.get(col, None) == b.extras[i].get(col, None), ( From f9246237677864beecab631b7237c549d348093d Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 1 May 2025 12:12:21 +0200 Subject: [PATCH 29/57] Remove extras_data_frame attribute --- mne/annotations.py | 16 ++++++++-------- mne/epochs.py | 6 +++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 19ac5fc21a1..28ecf5e2c95 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -331,7 +331,12 @@ def orig_time(self): @property def extras(self): - """The extras of the Annotations.""" + """The extras of the Annotations. + + The ``extras`` attribute is a list of dictionaries. + It can easily be converted to a pandas DataFrame using: + ``pd.DataFrame(extras)``. + """ return self._extras @extras.setter @@ -343,12 +348,6 @@ def _extras_columns(self) -> set[str]: """The set containing all the keys in all extras dicts.""" return {k for d in self.extras for k in d.keys()} - @property - def extras_data_frame(self): - """The extras of the Annotations as a DataFrame.""" - pd = _check_pandas_installed(strict=True) - return pd.DataFrame(self.extras) - def __eq__(self, other): """Compare to another Annotations instance.""" if not isinstance(other, Annotations): @@ -549,7 +548,8 @@ def to_data_frame(self, time_format="datetime"): if self._any_ch_names(): df.update(ch_names=self.ch_names) df = pd.DataFrame(df) - df = pd.concat([df, self.extras_data_frame], axis=1) + extras_df = pd.DataFrame(self.extras) + df = pd.concat([df, extras_df], axis=1) return df def count(self): diff --git a/mne/epochs.py b/mne/epochs.py index 605163e3656..c042715e6ae 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -3579,11 +3579,11 @@ def __init__( # add the annotations.extras to the metadata if not all(len(d) == 0 for d in annotations.extras): + pd = _check_pandas_installed(strict=True) + extras_df = pd.DataFrame(annotations.extras) if metadata is None: - metadata = annotations.extras_data_frame + metadata = extras_df else: - pd = _check_pandas_installed(strict=True) - extras_df = annotations.extras_data_frame extras_df.set_index(metadata.index, inplace=True) metadata = pd.concat( [metadata, extras_df], axis=1, ignore_index=False From 088858da7379ca678c66e442210da7c674e7a3be Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 1 May 2025 15:11:18 +0200 Subject: [PATCH 30/57] Support saving to csv and txt --- mne/annotations.py | 54 ++++++++++++++++++++++++++--------- mne/tests/test_annotations.py | 16 +++++++---- 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 28ecf5e2c95..c8cd061adcf 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -1218,11 +1218,6 @@ def _write_annotations(fid, annotations): def _write_annotations_csv(fname, annot): - if len(annot._extras_columns) > 0: - warn( - "Reading extra annotation fields from CSV is not supported. " - "The extra fields will be written but not loaded when reading." - ) annot = annot.to_data_frame() if "ch_names" in annot: annot["ch_names"] = [ @@ -1238,8 +1233,10 @@ def _write_annotations_txt(fname, annot): # for backward compat, we do not write tzinfo (assumed UTC) content += f"# orig_time : {annot.orig_time.replace(tzinfo=None)}\n" content += "# onset, duration, description" + n_cols = 3 data = [annot.onset, annot.duration, annot.description] if annot._any_ch_names(): + n_cols += 1 content += ", ch_names" data.append( [ @@ -1248,18 +1245,20 @@ def _write_annotations_txt(fname, annot): ] ) if len(extras_columns := annot._extras_columns) > 0: - warn( - "Reading extra annotation fields from TXT is not supported. " - "The extra fields will be written but not loaded when reading." - ) + n_cols += len(extras_columns) for column in extras_columns: content += f", {column}" - data.append([extra.get(column, None) for extra in annot.extras]) + data.append( + [ + val if (val := extra.get(column, None)) is not None else "" + for extra in annot.extras + ] + ) content += "\n" data = np.array(data, dtype=str).T assert data.ndim == 2 assert data.shape[0] == len(annot.onset) - assert data.shape[1] in (3, 4) + assert data.shape[1] == n_cols with open(fname, "wb") as fid: fid.write(content.encode()) np.savetxt(fid, data, delimiter=",", fmt="%s") @@ -1366,6 +1365,20 @@ def read_annotations( return annotations +def _cast_extras_types(val): + """Cast types to int or float.""" + if val == "": + return None + try: + out = int(val) + except (ValueError, TypeError): + try: + out = float(val) + except (ValueError, TypeError): + out = val + return out + + def _read_annotations_csv(fname): """Read annotations from csv. @@ -1402,7 +1415,19 @@ def _read_annotations_csv(fname): _safe_name_list(val, "read", "annotation channel name") for val in df["ch_names"].values ] - return Annotations(onset, duration, description, orig_time, ch_names) + other_columns = list( + df.columns.difference(["onset", "duration", "description", "ch_names"]) + ) + extras = None + if len(other_columns) > 0: + extras = df[other_columns].astype(object).to_dict(orient="records") + # if we try to cast the types within the pandas dataframe, + # it will fail if the column contains mixed types + extras = [ + {k: _cast_extras_types(v) for k, v in extra.items()} for extra in extras + ] + print(extras) + return Annotations(onset, duration, description, orig_time, ch_names, extras) def _read_brainstorm_annotations(fname, orig_time=None): @@ -1497,7 +1522,10 @@ def _read_annotations_txt(fname): i_col += 1 if len(columns) > i_col: extras = [ - {columns[j_col]: out[j_col][i] for j_col in range(i_col, len(columns))} + { + columns[j_col]: _cast_extras_types(out[j_col][i].decode("UTF-8")) + for j_col in range(i_col, len(columns)) + } for i in range(len(onset)) ] diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index bdc96294e59..805a5beb19c 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -987,9 +987,9 @@ def _assert_annotations_equal(a, b, tol=0): extras_columns = a._extras_columns.union(b._extras_columns) for col in extras_columns: for i, extra in enumerate(a.extras): - assert extra.get(col, None) == b.extras[i].get(col, None), ( - f"extras {col} {i}" - ) + assert extra.get(col, None) == b.extras[i].get( + col, None + ), f"extras[{i}][{col}]" _ORIG_TIME = datetime.fromtimestamp(1038942071.7201, timezone.utc) @@ -998,8 +998,6 @@ def _assert_annotations_equal(a, b, tol=0): @pytest.fixture(scope="function", params=("ch_names", "fmt", "with_extras")) def dummy_annotation_file(tmp_path_factory, ch_names, fmt, with_extras): """Create csv file for testing.""" - if with_extras and fmt != "fif": - pytest.skip("Extras fields io are only supported in FIF format.") extras_row0 = {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None} if fmt == "csv": content = ( @@ -1032,6 +1030,14 @@ def dummy_annotation_file(tmp_path_factory, ch_names, fmt, with_extras): content[-2] += "," content[-1] += ",MEG0111:MEG2563" content = "\n".join(content) + if with_extras and fmt != "fif": + content = content.splitlines() + content[-3] += "," + ",".join(extras_row0.keys()) + content[-2] += "," + ",".join( + ["" if v is None else str(v) for v in extras_row0.values()] + ) + content[-1] += ",,,," + content = "\n".join(content) fname = tmp_path_factory.mktemp("data") / f"annotations-annot.{fmt}" if isinstance(content, str): From 74a480a86c54b04a9270f9b75a5db070a22d4577 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 1 May 2025 15:25:09 +0200 Subject: [PATCH 31/57] simplify assert --- mne/tests/test_annotations.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 805a5beb19c..6a032c8cf8d 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -672,15 +672,8 @@ def test_annotation_concat(with_extras): assert_equal(len(b), 3) if with_extras: - all_extras = [None] * 3 + extras - assert all( - ( - c.extras[i] == all_extras[i] - if all_extras[i] is not None - else len(c.extras[i]) == 0 - ) - for i in range(len(all_extras)) - ) + all_extras = [extra or {} for extra in [None] * 3 + extras] + assert all(c.extras[i] == all_extras[i] for i in range(len(all_extras))) # test += operator (modifies a in place) b._orig_time = _handle_meas_date(1038942070.7201) From 2f5a04e5e936ce0d78437c6b638ce37f87fd4a84 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 1 May 2025 15:44:25 +0200 Subject: [PATCH 32/57] Simplify read txt --- mne/annotations.py | 46 +++++++++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index c8cd061adcf..57280be5619 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -1415,12 +1415,12 @@ def _read_annotations_csv(fname): _safe_name_list(val, "read", "annotation channel name") for val in df["ch_names"].values ] - other_columns = list( + extra_columns = list( df.columns.difference(["onset", "duration", "description", "ch_names"]) ) extras = None - if len(other_columns) > 0: - extras = df[other_columns].astype(object).to_dict(orient="records") + if len(extra_columns) > 0: + extras = df[extra_columns].astype(object).to_dict(orient="records") # if we try to cast the types within the pandas dataframe, # it will fail if the column contains mixed types extras = [ @@ -1506,25 +1506,41 @@ def _read_annotations_txt(fname): onset, duration, desc = [], [], [] else: if columns is None: + # No column names were present in the header + # We assume the first three columns are onset, duration, description + # And eventually a fourth column with ch_names _check_option("text header", len(out), (3, 4)) columns = ["onset", "duration", "description"] + ( ["ch_names"] if len(out) == 4 else [] ) - else: - _check_option( - "text header", columns[:3], (["onset", "duration", "description"],) + col_map = {col: i for i, col in enumerate(columns)} + if len(col_map) != len(columns): + raise ValueError( + "Duplicate column names found in header. Please check the file format." + ) + if missing := {"onset", "duration", "description"} - set(col_map.keys()): + raise ValueError( + f"Column(s) {missing} not found in header. Please check the file format." ) - _check_option("text header len", len(out), (len(columns),)) - onset, duration, desc = out[:3] - i_col = 3 - if len(columns) > i_col and columns[i_col] == "ch_names": - ch_names = out[i_col] - i_col += 1 - if len(columns) > i_col: + _check_option("text header len", len(out), (len(columns),)) + onset = out[col_map["onset"]] + duration = out[col_map["duration"]] + desc = out[col_map["description"]] + if "ch_names" in col_map: + ch_names = out[col_map["ch_names"]] + extra_columns = set(col_map.keys()) - { + "onset", + "duration", + "description", + "ch_names", + } + if extra_columns: extras = [ { - columns[j_col]: _cast_extras_types(out[j_col][i].decode("UTF-8")) - for j_col in range(i_col, len(columns)) + col_name: _cast_extras_types( + out[col_map[col_name]][i].decode("UTF-8") + ) + for col_name in extra_columns } for i in range(len(onset)) ] From f2a6b7676376aa0d7a7c85909cbc60e5943f187a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 1 May 2025 13:45:22 +0000 Subject: [PATCH 33/57] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/tests/test_annotations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 6a032c8cf8d..f6b616f18f2 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -980,9 +980,9 @@ def _assert_annotations_equal(a, b, tol=0): extras_columns = a._extras_columns.union(b._extras_columns) for col in extras_columns: for i, extra in enumerate(a.extras): - assert extra.get(col, None) == b.extras[i].get( - col, None - ), f"extras[{i}][{col}]" + assert extra.get(col, None) == b.extras[i].get(col, None), ( + f"extras[{i}][{col}]" + ) _ORIG_TIME = datetime.fromtimestamp(1038942071.7201, timezone.utc) From 061eb582b88f7753343b43e73c0a1e58083f6905 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 1 May 2025 15:47:16 +0200 Subject: [PATCH 34/57] pre-commit --- mne/annotations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne/annotations.py b/mne/annotations.py index 57280be5619..8c412c79cb1 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -1520,7 +1520,8 @@ def _read_annotations_txt(fname): ) if missing := {"onset", "duration", "description"} - set(col_map.keys()): raise ValueError( - f"Column(s) {missing} not found in header. Please check the file format." + f"Column(s) {missing} not found in header. " + "Please check the file format." ) _check_option("text header len", len(out), (len(columns),)) onset = out[col_map["onset"]] From d948ca81c7a0d7718eace6ab7a1e16d741a8e26a Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Fri, 2 May 2025 15:09:52 +0200 Subject: [PATCH 35/57] Add _AnnotationsExtrasList container --- mne/annotations.py | 94 +++++++++++++++++++++++++++++------ mne/tests/test_annotations.py | 38 ++++++++++++-- 2 files changed, 114 insertions(+), 18 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 8c412c79cb1..8739b77bb0d 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -5,7 +5,7 @@ import json import re import warnings -from collections import Counter, OrderedDict, UserDict +from collections import Counter, OrderedDict, UserDict, UserList from collections.abc import Iterable from copy import deepcopy from datetime import datetime, timedelta, timezone @@ -75,21 +75,85 @@ def __setitem__(self, key: str, value: str | int | float | None) -> None: super().__setitem__(key, value) +class _AnnotationsExtrasList(UserList): + """A list of dictionaries for storing extra fields of annotations. + + Each dictionary in the list corresponds to an annotation and contains + extra fields. + The keys of the dictionaries are strings, and the values can be + strings, integers, floats, or None. + """ + + @staticmethod + def _validate_value( + value: dict | _AnnotationsExtrasDict | None, + ) -> _AnnotationsExtrasDict: + _validate_type( + value, + (dict, _AnnotationsExtrasDict, None), + "extras dict value", + "dict or None", + ) + return ( + value + if isinstance(value, _AnnotationsExtrasDict) + else _AnnotationsExtrasDict(value or {}) + ) + + def __init__(self, initlist=None): + if not (isinstance(initlist, _AnnotationsExtrasList) or initlist is None): + initlist = [self._validate_value(v) for v in initlist] + super().__init__(initlist) + + def __setitem__( # type: ignore[override] + self, + key: int | slice, + value: ( + dict + | _AnnotationsExtrasDict + | None + | Iterable[dict | _AnnotationsExtrasDict | None] + ), + ) -> None: + _validate_type(key, (int, slice), "key", "int or slice") + if isinstance(key, int): + iterable = False + value = [value] + else: + _validate_type(value, Iterable, "value", "Iterable when key is a slice") + iterable = True + + new_values = [self._validate_value(v) for v in value] + if not iterable: + new_values = new_values[0] + super().__setitem__(key, new_values) + + def __iadd__(self, other): + if not isinstance(other, _AnnotationsExtrasList): + other = _AnnotationsExtrasList(other) + super().__iadd__(other) + + def append(self, item): + super().append(self._validate_value(item)) + + def insert(self, i, item): + super().insert(i, self._validate_value(item)) + + def extend(self, other): + if not isinstance(other, _AnnotationsExtrasList): + other = _AnnotationsExtrasList(other) + super().extend(other) + + def _validate_extras(extras, length: int): - _validate_type(extras, (None, list), "extras") - out = [_AnnotationsExtrasDict() for _ in range(length)] - if extras is not None: - if len(extras) != length: - raise ValueError( - f"extras must be None or a list of length {length}, got {len(extras)}." - ) - for i, (d, new_d) in enumerate(zip(extras, out)): - _validate_type( - d, (dict, _AnnotationsExtrasDict, None), f"extras[{i}]", "dict or None" - ) - if d is not None: - new_d.update(d) - return out + _validate_type(extras, (None, list, _AnnotationsExtrasList), "extras") + if extras is not None and len(extras) != length: + raise ValueError( + f"extras must be None or a list of length {length}, got {len(extras)}." + ) + if isinstance(extras, _AnnotationsExtrasList): + return extras + return _AnnotationsExtrasList(extras or [None] * length) def _check_o_d_s_c_e(onset, duration, description, ch_names, extras): diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index f6b616f18f2..358f7b3b230 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -30,6 +30,7 @@ ) from mne.annotations import ( _AnnotationsExtrasDict, + _AnnotationsExtrasList, _handle_meas_date, _read_annotations_txt_parse_header, _sync_onset, @@ -980,9 +981,9 @@ def _assert_annotations_equal(a, b, tol=0): extras_columns = a._extras_columns.union(b._extras_columns) for col in extras_columns: for i, extra in enumerate(a.extras): - assert extra.get(col, None) == b.extras[i].get(col, None), ( - f"extras[{i}][{col}]" - ) + assert extra.get(col, None) == b.extras[i].get( + col, None + ), f"extras[{i}][{col}]" _ORIG_TIME = datetime.fromtimestamp(1038942071.7201, timezone.utc) @@ -1898,3 +1899,34 @@ def test_extras_dict_raises(key, value, expected_error, match): if isinstance(key, str): with pytest.raises(expected_error, match=match): _AnnotationsExtrasDict(**{key: value}) + + +@pytest.mark.parametrize( + "key, value, expected_error, match", + ( + ("onset", 1, ValueError, "reserved"), + ("duration", 1, ValueError, "reserved"), + ("description", 1, ValueError, "reserved"), + ("ch_names", 1, ValueError, "reserved"), + ("valid_key", [], TypeError, "value must be an instance of"), + (1, 1, TypeError, "key must be an instance of"), + ), +) +def test_extras_list_raises(key, value, expected_error, match): + """Test that _AnnotationsExtrasList raises errors for invalid keys/values.""" + extras = _AnnotationsExtrasList([None]) + assert all(isinstance(extra, _AnnotationsExtrasDict) for extra in extras) + with pytest.raises(expected_error, match=match): + extras[0] = {key: value} + with pytest.raises(expected_error, match=match): + extras[:1] = [{key: value}] + with pytest.raises(expected_error, match=match): + extras[0].update({key: value}) + with pytest.raises(expected_error, match=match): + _AnnotationsExtrasList([{key: value}]) + with pytest.raises(expected_error, match=match): + extras.append({key: value}) + with pytest.raises(expected_error, match=match): + extras.extend([{key: value}]) + with pytest.raises(expected_error, match=match): + extras += [{key: value}] From d5e823c755e0405e4c91b7dc5dd40ae8c4cbcbe6 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 3 May 2025 09:57:34 +0200 Subject: [PATCH 36/57] simplify read CSV --- mne/annotations.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 8739b77bb0d..2eaf276835f 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -1446,6 +1446,13 @@ def _cast_extras_types(val): def _read_annotations_csv(fname): """Read annotations from csv. + The dtypes of the extra fields will automatically be infered + by pandas. If some fields have heterogeneous types on the + different rows, this automatic inference may return unexpecterd + types. + If you need to save heterogeneous extra dtypes, we recomend + saving to FIF. + Parameters ---------- fname : path-like @@ -1484,13 +1491,7 @@ def _read_annotations_csv(fname): ) extras = None if len(extra_columns) > 0: - extras = df[extra_columns].astype(object).to_dict(orient="records") - # if we try to cast the types within the pandas dataframe, - # it will fail if the column contains mixed types - extras = [ - {k: _cast_extras_types(v) for k, v in extra.items()} for extra in extras - ] - print(extras) + extras = df[extra_columns].to_dict(orient="records") return Annotations(onset, duration, description, orig_time, ch_names, extras) From ad0baa0cb749aae10d39c4d50ca875154260b169 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 3 May 2025 10:21:43 +0200 Subject: [PATCH 37/57] Update test --- mne/tests/test_annotations.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 358f7b3b230..cf72057e2ee 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -969,7 +969,7 @@ def _constant_id(*args, **kwargs): # Test for IO with .csv files -def _assert_annotations_equal(a, b, tol=0): +def _assert_annotations_equal(a, b, tol=0, comp_extras_as_str=False): __tracebackhide__ = True assert_allclose(a.onset, b.onset, rtol=0, atol=tol, err_msg="onset") assert_allclose(a.duration, b.duration, rtol=0, atol=tol, err_msg="duration") @@ -981,9 +981,12 @@ def _assert_annotations_equal(a, b, tol=0): extras_columns = a._extras_columns.union(b._extras_columns) for col in extras_columns: for i, extra in enumerate(a.extras): - assert extra.get(col, None) == b.extras[i].get( - col, None - ), f"extras[{i}][{col}]" + exa = extra.get(col, None) + exb = b.extras[i].get(col, None) + if comp_extras_as_str: + exa = str(exa) if exa is not None else "" + exb = str(exb) if exb is not None else "" + assert exa == exb, f"extras[{i}][{col}]" _ORIG_TIME = datetime.fromtimestamp(1038942071.7201, timezone.utc) @@ -1058,7 +1061,10 @@ def test_io_annotation(dummy_annotation_file, tmp_path, fmt, ch_names, with_extr None, ] _assert_annotations_equal( - annot, Annotations([0.0, 9.0], [1.0, 2.425], ["AA", "BB"], **kwargs), tol=1e-6 + annot, + Annotations([0.0, 9.0], [1.0, 2.425], ["AA", "BB"], **kwargs), + tol=1e-6, + comp_extras_as_str=fmt in ["csv", "txt"], ) # Now test writing From 9e55e70c9aae563bcff28a26930ad9f343f26b50 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 3 May 2025 10:22:10 +0200 Subject: [PATCH 38/57] Warn when writing heterogeneous csv extras --- mne/annotations.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mne/annotations.py b/mne/annotations.py index 2eaf276835f..606f4d75006 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -1288,6 +1288,18 @@ def _write_annotations_csv(fname, annot): _safe_name_list(ch, "write", name=f'annot["ch_names"][{ci}') for ci, ch in enumerate(annot["ch_names"]) ] + extras_columns = set(annot.columns) - { + "onset", + "duration", + "description", + "ch_names", + } + for col in extras_columns: + if len(dtypes := annot[col].apply(type).unique()) > 1: + warn( + f"Extra field '{col}' contains heterogeneous dtypes ({dtypes}). " + "Loading these CSV annotations may not return the original dtypes." + ) annot.to_csv(fname, index=False) From cd9f08ff754b8064286c366934c193bb7f8994b5 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 3 May 2025 11:10:26 +0200 Subject: [PATCH 39/57] Warn when writing heterogenous dtypes in txt --- mne/annotations.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 606f4d75006..1f03f1ab38c 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -1298,7 +1298,7 @@ def _write_annotations_csv(fname, annot): if len(dtypes := annot[col].apply(type).unique()) > 1: warn( f"Extra field '{col}' contains heterogeneous dtypes ({dtypes}). " - "Loading these CSV annotations may not return the original dtypes." + "Loading these CSV annotations may not return the original dtypes." ) annot.to_csv(fname, index=False) @@ -1324,12 +1324,13 @@ def _write_annotations_txt(fname, annot): n_cols += len(extras_columns) for column in extras_columns: content += f", {column}" - data.append( - [ - val if (val := extra.get(column, None)) is not None else "" - for extra in annot.extras - ] - ) + values = [extra.get(column, None) for extra in annot.extras] + if len(dtypes := set(type(v) for v in values)) > 1: + warn( + f"Extra field '{column}' contains heterogeneous dtypes ({dtypes}). " + "Loading these TXT annotations may not return the original dtypes." + ) + data.append([val if val is not None else "" for val in values]) content += "\n" data = np.array(data, dtype=str).T assert data.ndim == 2 From 42447bc5f284653e419e5da21aa2ae459684df96 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 3 May 2025 11:10:32 +0200 Subject: [PATCH 40/57] test warnings --- mne/tests/test_annotations.py | 45 +++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index cf72057e2ee..c9110bcaee0 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -1045,6 +1045,7 @@ def dummy_annotation_file(tmp_path_factory, ch_names, fmt, with_extras): return fname +@pytest.mark.filterwarnings("ignore:.*heterogeneous dtypes.*") @pytest.mark.parametrize("ch_names", (False, True)) @pytest.mark.parametrize("fmt", [pytest.param("csv", marks=needs_pandas), "txt", "fif"]) @pytest.mark.parametrize("with_extras", [True, False]) @@ -1080,6 +1081,50 @@ def test_io_annotation(dummy_annotation_file, tmp_path, fmt, ch_names, with_extr _assert_annotations_equal(annot, annot2) +@pytest.mark.parametrize("fmt", [pytest.param("csv", marks=needs_pandas), "txt"]) +def test_write_annotation_warn_heterogeneous(tmp_path, fmt): + """Test that CSV, and TXT annotation writers warn on heterogeneous dtypes.""" + annot = Annotations( + onset=[0.0, 9.0], + duration=[1.0, 2.425], + description=["AA", "BB"], + orig_time=_ORIG_TIME, + extras=[ + {"foo1": "a", "foo2": "a"}, + {"foo1": 1, "foo2": None}, + ], + ) + fname = tmp_path / f"annotations-annot.{fmt}" + with ( + pytest.warns(RuntimeWarning, match="'foo2' contains heterogeneous dtypes"), + pytest.warns(RuntimeWarning, match="'foo1' contains heterogeneous dtypes"), + ): + annot.save(fname) + + +def test_write_annotation_warn_heterogeneous_b(tmp_path): + """Additional cases of test_write_annotation_warn_heterogeneous + which can only be tested with TXT.""" + fmt = "txt" + annot = Annotations( + onset=[0.0, 9.0], + duration=[1.0, 2.425], + description=["AA", "BB"], + orig_time=_ORIG_TIME, + extras=[ + {"foo3": 1, "foo4": 1, "foo5": 1.0}, + {"foo3": 1.0, "foo4": None, "foo5": None}, + ], + ) + fname = tmp_path / f"annotations-annot.{fmt}" + with ( + pytest.warns(RuntimeWarning, match="'foo5' contains heterogeneous dtypes"), + pytest.warns(RuntimeWarning, match="'foo4' contains heterogeneous dtypes"), + pytest.warns(RuntimeWarning, match="'foo3' contains heterogeneous dtypes"), + ): + annot.save(fname) + + def test_broken_csv(tmp_path): """Test broken .csv that does not use timestamps.""" pytest.importorskip("pandas") From 960a7c2a29e224e0558fcc5bfe826f051e8f8cae Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 3 May 2025 11:22:05 +0200 Subject: [PATCH 41/57] Fix annotations list repr --- mne/annotations.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mne/annotations.py b/mne/annotations.py index 1f03f1ab38c..77a2c1bf16a 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -84,6 +84,9 @@ class _AnnotationsExtrasList(UserList): strings, integers, floats, or None. """ + def __repr__(self): + return repr(self.data) + @staticmethod def _validate_value( value: dict | _AnnotationsExtrasDict | None, From 1e58564d8576a4ff34087fb8865fb29f3b205a04 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 3 May 2025 12:03:35 +0200 Subject: [PATCH 42/57] Infer TXT types using pandas, if possible --- mne/annotations.py | 42 +++++++++++++++++++++++++---------- mne/tests/test_annotations.py | 3 ++- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 77a2c1bf16a..5b041a2ff83 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -1572,8 +1572,10 @@ def is_columns(x): columns = [[c.strip() for c in h[2:].split(",")] for h in header if is_columns(h)] - return None if not orig_values else orig_values[0], ( - None if not columns else columns[0] + return ( + None if not orig_values else orig_values[0], + (None if not columns else columns[0]), + len(header), ) @@ -1581,7 +1583,7 @@ def _read_annotations_txt(fname): with warnings.catch_warnings(record=True): warnings.simplefilter("ignore") out = np.loadtxt(fname, delimiter=",", dtype=np.bytes_, unpack=True) - orig_time, columns = _read_annotations_txt_parse_header(fname) + orig_time, columns, n_rows_header = _read_annotations_txt_parse_header(fname) ch_names = extras = None if len(out) == 0: onset, duration, desc = [], [], [] @@ -1617,15 +1619,31 @@ def _read_annotations_txt(fname): "ch_names", } if extra_columns: - extras = [ - { - col_name: _cast_extras_types( - out[col_map[col_name]][i].decode("UTF-8") - ) - for col_name in extra_columns - } - for i in range(len(onset)) - ] + pd = _check_pandas_installed(strict=False) + if pd: + df = pd.read_csv( + fname, + delimiter=",", + names=columns, + usecols=extra_columns, + skiprows=n_rows_header, + header=None, + keep_default_na=False, + ) + extras = df.to_dict(orient="records") + else: + warn( + "Extra fields found in the header but pandas is not installed. " + "Therefor the dtypes of the extra fields can not automatically " + "be infered so they will be loaded as strings." + ) + extras = [ + { + col_name: out[col_map[col_name]][i].decode("UTF-8") + for col_name in extra_columns + } + for i in range(len(onset)) + ] onset = [float(o.decode()) for o in np.atleast_1d(onset)] duration = [float(d.decode()) for d in np.atleast_1d(duration)] diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index c9110bcaee0..aa6efc13e69 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -1218,9 +1218,10 @@ def test_read_annotation_txt_header(tmp_path): fname = tmp_path / "header.txt" with open(fname, "w") as f: f.write(content) - orig_time, _ = _read_annotations_txt_parse_header(fname) + orig_time, _, n_rows_header = _read_annotations_txt_parse_header(fname) want = datetime.fromtimestamp(1038942071.7201, timezone.utc) assert orig_time == want + assert n_rows_header == 5 def test_read_annotation_txt_one_segment(tmp_path): From 28448aea148d70ebf82edfe5ee5f5dd15728ab19 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 3 May 2025 10:04:20 +0000 Subject: [PATCH 43/57] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/tests/test_annotations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index aa6efc13e69..2472ac78a29 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -1104,7 +1104,8 @@ def test_write_annotation_warn_heterogeneous(tmp_path, fmt): def test_write_annotation_warn_heterogeneous_b(tmp_path): """Additional cases of test_write_annotation_warn_heterogeneous - which can only be tested with TXT.""" + which can only be tested with TXT. + """ fmt = "txt" annot = Annotations( onset=[0.0, 9.0], From 50f5385aafaf7e4b27396f2a9f16199b0ae8cda5 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 3 May 2025 12:07:34 +0200 Subject: [PATCH 44/57] Fix docstring format --- mne/tests/test_annotations.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 2472ac78a29..b4e20976487 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -1103,8 +1103,9 @@ def test_write_annotation_warn_heterogeneous(tmp_path, fmt): def test_write_annotation_warn_heterogeneous_b(tmp_path): - """Additional cases of test_write_annotation_warn_heterogeneous - which can only be tested with TXT. + """Additional cases for test_write_annotation_warn_heterogeneous + + These cases are only compatible with the TXT writer. """ fmt = "txt" annot = Annotations( From cd716b92fec961bb280e6f5c0beedb623e66dc7c Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 3 May 2025 12:08:58 +0200 Subject: [PATCH 45/57] Fix docstring format --- mne/tests/test_annotations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index b4e20976487..7a9a0faea43 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -1103,7 +1103,7 @@ def test_write_annotation_warn_heterogeneous(tmp_path, fmt): def test_write_annotation_warn_heterogeneous_b(tmp_path): - """Additional cases for test_write_annotation_warn_heterogeneous + """Additional cases for test_write_annotation_warn_heterogeneous. These cases are only compatible with the TXT writer. """ From ac07fa5753ffd5d5a248bc4b90ea923c13ff48f7 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 3 May 2025 12:11:06 +0200 Subject: [PATCH 46/57] codespell --- mne/annotations.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 5b041a2ff83..7a2b35ca667 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -1462,11 +1462,11 @@ def _cast_extras_types(val): def _read_annotations_csv(fname): """Read annotations from csv. - The dtypes of the extra fields will automatically be infered + The dtypes of the extra fields will automatically be inferred by pandas. If some fields have heterogeneous types on the different rows, this automatic inference may return unexpecterd types. - If you need to save heterogeneous extra dtypes, we recomend + If you need to save heterogeneous extra dtypes, we recommend saving to FIF. Parameters @@ -1634,8 +1634,8 @@ def _read_annotations_txt(fname): else: warn( "Extra fields found in the header but pandas is not installed. " - "Therefor the dtypes of the extra fields can not automatically " - "be infered so they will be loaded as strings." + "Therefore the dtypes of the extra fields can not automatically " + "be inferred so they will be loaded as strings." ) extras = [ { From 09b43cac0dfcb6713cf5f40a7eed6d9a4c17c4f1 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sat, 3 May 2025 13:25:41 +0200 Subject: [PATCH 47/57] Remove unused function --- mne/annotations.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 7a2b35ca667..5a9807b6573 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -1445,20 +1445,6 @@ def read_annotations( return annotations -def _cast_extras_types(val): - """Cast types to int or float.""" - if val == "": - return None - try: - out = int(val) - except (ValueError, TypeError): - try: - out = float(val) - except (ValueError, TypeError): - out = val - return out - - def _read_annotations_csv(fname): """Read annotations from csv. From 445bf013150d0f4467aaf652494eb20692e70eed Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Sun, 4 May 2025 22:31:48 +0200 Subject: [PATCH 48/57] Skip mypy check --- mne/annotations.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 5a9807b6573..73401515c22 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -111,12 +111,7 @@ def __init__(self, initlist=None): def __setitem__( # type: ignore[override] self, key: int | slice, - value: ( - dict - | _AnnotationsExtrasDict - | None - | Iterable[dict | _AnnotationsExtrasDict | None] - ), + value, ) -> None: _validate_type(key, (int, slice), "key", "int or slice") if isinstance(key, int): From 6de0f4cebc65ea9650314f46f889e8190b075900 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel <25532709+PierreGtch@users.noreply.github.com> Date: Fri, 9 May 2025 08:42:43 +0200 Subject: [PATCH 49/57] Apply review suggestion (spelling) Co-authored-by: Daniel McCloy --- mne/annotations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/annotations.py b/mne/annotations.py index 73401515c22..e7dea3cea71 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -1445,7 +1445,7 @@ def _read_annotations_csv(fname): The dtypes of the extra fields will automatically be inferred by pandas. If some fields have heterogeneous types on the - different rows, this automatic inference may return unexpecterd + different rows, this automatic inference may return unexpected types. If you need to save heterogeneous extra dtypes, we recommend saving to FIF. From 9e67e3b7c8349d198855850e3b29bfa4cf60e649 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel <25532709+PierreGtch@users.noreply.github.com> Date: Wed, 21 May 2025 21:59:25 +0200 Subject: [PATCH 50/57] Apply suggestions from code review [circle full] Co-authored-by: Eric Larson --- mne/annotations.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index e7dea3cea71..e7b1d32e666 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -66,11 +66,11 @@ class _AnnotationsExtrasDict(UserDict): """ def __setitem__(self, key: str, value: str | int | float | None) -> None: - _validate_type(key, str, "key", "string") + _validate_type(key, str, "key") if key in ("onset", "duration", "description", "ch_names"): raise ValueError(f"Key '{key}' is reserved and cannot be used in extras.") _validate_type( - value, (str, int, float, None), "value", "string, int, float or None" + value, (str, int, float, None), "value", ) super().__setitem__(key, value) @@ -378,7 +378,7 @@ class Annotations: """ # noqa: E501 def __init__( - self, onset, duration, description, orig_time=None, ch_names=None, extras=None + self, onset, duration, description, orig_time=None, ch_names=None, *, extras=None ): self._orig_time = _handle_meas_date(orig_time) self.onset, self.duration, self.description, self.ch_names, self._extras = ( @@ -408,7 +408,7 @@ def extras(self, extras): @property def _extras_columns(self) -> set[str]: """The set containing all the keys in all extras dicts.""" - return {k for d in self.extras for k in d.keys()} + return {k for d in self.extras for k in d} def __eq__(self, other): """Compare to another Annotations instance.""" @@ -507,7 +507,7 @@ def __getitem__(self, key, *, with_ch_names=None, with_extras=True): ) @fill_doc - def append(self, onset, duration, description, ch_names=None, extras=None): + def append(self, onset, duration, description, ch_names=None, *, extras=None): """Add an annotated segment. Operates inplace. Parameters @@ -971,7 +971,7 @@ def set_annotations(self, annotations, on_missing="raise", *, verbose=None): self._annotations = new_annotations return self - def get_annotations_per_epoch(self, with_extras=False): + def get_annotations_per_epoch(self, *, with_extras=False): """Get a list of annotations that occur during each epoch. Parameters @@ -1056,7 +1056,7 @@ def get_annotations_per_epoch(self, with_extras=False): epoch_annot_list[epo_ix].append(annot) return epoch_annot_list - def add_annotations_to_metadata(self, overwrite=False, with_extras=True): + def add_annotations_to_metadata(self, overwrite=False, *, with_extras=True): """Add raw annotations into the Epochs metadata data frame. Adds three columns to the ``metadata`` consisting of a list From 4acce8d7f0f6c618b69363dad383f20cb4f02b1c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 May 2025 19:59:43 +0000 Subject: [PATCH 51/57] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/annotations.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index e7b1d32e666..ab70a9bf4da 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -70,7 +70,9 @@ def __setitem__(self, key: str, value: str | int | float | None) -> None: if key in ("onset", "duration", "description", "ch_names"): raise ValueError(f"Key '{key}' is reserved and cannot be used in extras.") _validate_type( - value, (str, int, float, None), "value", + value, + (str, int, float, None), + "value", ) super().__setitem__(key, value) @@ -378,7 +380,14 @@ class Annotations: """ # noqa: E501 def __init__( - self, onset, duration, description, orig_time=None, ch_names=None, *, extras=None + self, + onset, + duration, + description, + orig_time=None, + ch_names=None, + *, + extras=None, ): self._orig_time = _handle_meas_date(orig_time) self.onset, self.duration, self.description, self.ch_names, self._extras = ( From 362159502062f55b5606c386c4721fff750affb8 Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Wed, 21 May 2025 15:05:54 -0500 Subject: [PATCH 52/57] [circle full] From 740694804c0b158def50f07ef838709da6b84c2f Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 22 May 2025 14:45:51 +0200 Subject: [PATCH 53/57] Fix positional arguments [circle full] --- mne/annotations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index ab70a9bf4da..171b9a510fb 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -476,7 +476,7 @@ def __iadd__(self, other): other.duration, other.description, other.ch_names, - other.extras, + extras=other.extras, ) def __iter__(self): @@ -1498,7 +1498,7 @@ def _read_annotations_csv(fname): extras = None if len(extra_columns) > 0: extras = df[extra_columns].to_dict(orient="records") - return Annotations(onset, duration, description, orig_time, ch_names, extras) + return Annotations(onset, duration, description, orig_time, ch_names, extras=extras) def _read_brainstorm_annotations(fname, orig_time=None): @@ -1691,7 +1691,7 @@ def _read_annotations_fif(fid, tree): if extras is not None: assert len(extras) == len(onset) annotations = Annotations( - onset, duration, description, orig_time, ch_names, extras + onset, duration, description, orig_time, ch_names, extras=extras ) return annotations From 8af93a201b65c40032405493e9b81f9016a1cd2b Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Mon, 26 May 2025 16:38:29 +0200 Subject: [PATCH 54/57] Add test for positional args --- mne/tests/test_annotations.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 7a9a0faea43..6c1100a4780 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -7,6 +7,7 @@ from datetime import datetime, timedelta, timezone from itertools import repeat from pathlib import Path +import re import numpy as np import pytest @@ -1984,3 +1985,27 @@ def test_extras_list_raises(key, value, expected_error, match): extras.extend([{key: value}]) with pytest.raises(expected_error, match=match): extras += [{key: value}] + + +def test_annotations_positional_args(): + annot = Annotations([0], [1], ["a"]) + _ = Annotations([0], [1], ["a"], None) + _ = Annotations([0], [1], ["a"], None, None) + with pytest.raises( + TypeError, + match=re.escape( + "Annotations.__init__() takes from 4 to 6 " + "positional arguments but 7 were given" + ), + ): + _ = Annotations([0], [1], ["a"], None, None, [{"foo": "bar"}]) + annot.append([0], [1], ["a"]) + annot.append([0], [1], ["a"], None) + with pytest.raises( + TypeError, + match=re.escape( + "Annotations.append() takes from 4 to 5 " + "positional arguments but 6 were given" + ), + ): + annot.append([0], [1], ["a"], None, [{"foo": "bar"}]) From e7aa63216aaf29fddbd7d4f63bac6133ccb64754 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 May 2025 14:38:51 +0000 Subject: [PATCH 55/57] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/tests/test_annotations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 6c1100a4780..a535bd5dde0 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -2,12 +2,12 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +import re import sys from collections import OrderedDict from datetime import datetime, timedelta, timezone from itertools import repeat from pathlib import Path -import re import numpy as np import pytest From 1668571f849270a2ebe1b365c13f4a07806e14e6 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Mon, 26 May 2025 16:43:51 +0200 Subject: [PATCH 56/57] Add missing docstring --- mne/tests/test_annotations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index a535bd5dde0..96b3644af35 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -1988,6 +1988,7 @@ def test_extras_list_raises(key, value, expected_error, match): def test_annotations_positional_args(): + """Test that Annotations positional arguments work as expected.""" annot = Annotations([0], [1], ["a"]) _ = Annotations([0], [1], ["a"], None) _ = Annotations([0], [1], ["a"], None, None) From efff25ae37b44608ae51941a3f7fb477f46746da Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Mon, 2 Jun 2025 09:31:55 -0500 Subject: [PATCH 57/57] revert new test (unneeded) --- mne/tests/test_annotations.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 96b3644af35..7a9a0faea43 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -2,7 +2,6 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -import re import sys from collections import OrderedDict from datetime import datetime, timedelta, timezone @@ -1985,28 +1984,3 @@ def test_extras_list_raises(key, value, expected_error, match): extras.extend([{key: value}]) with pytest.raises(expected_error, match=match): extras += [{key: value}] - - -def test_annotations_positional_args(): - """Test that Annotations positional arguments work as expected.""" - annot = Annotations([0], [1], ["a"]) - _ = Annotations([0], [1], ["a"], None) - _ = Annotations([0], [1], ["a"], None, None) - with pytest.raises( - TypeError, - match=re.escape( - "Annotations.__init__() takes from 4 to 6 " - "positional arguments but 7 were given" - ), - ): - _ = Annotations([0], [1], ["a"], None, None, [{"foo": "bar"}]) - annot.append([0], [1], ["a"]) - annot.append([0], [1], ["a"], None) - with pytest.raises( - TypeError, - match=re.escape( - "Annotations.append() takes from 4 to 5 " - "positional arguments but 6 were given" - ), - ): - annot.append([0], [1], ["a"], None, [{"foo": "bar"}])