diff --git a/examples/brain_plotting/plot_intracranial_electrodes.py b/examples/brain_plotting/plot_intracranial_electrodes.py index 7c517e58..f502b263 100644 --- a/examples/brain_plotting/plot_intracranial_electrodes.py +++ b/examples/brain_plotting/plot_intracranial_electrodes.py @@ -77,6 +77,20 @@ dist_from_HG = brain.distance_from_region(coords, isleft, region='pmHG', metric='surf') print(dist_from_HG) +############################################################################### +# Smoothly interpolate values over the brain's surface from electrodes + +# As an example, we will use the y coordinate of the electrode +values_per_electrode = coords[:,1] - coords[:,1].min() + +# Interpolate onto only the temporal lobe, using the 5 nearest neighbor interpolation with +# a maximum distance of 10mm +brain.interpolate_electrodes_onto_brain(coords, values_per_electrode, isleft, roi='temporal', k=5, max_dist=10) + +# Plot the overlay for just the left hemisphere +fig, axes = plot_brain_overlay(brain, cmap='Reds', view='lateral', figsize=(12,6), hemi='lh') +plt.show() + ############################################################################### # Create a brain with the inflated surface for plotting brain = Brain('inflated', subject_dir='./fsaverage/').split_hg('midpoint').split_stg().simplify_labels() diff --git a/naplib/__init__.py b/naplib/__init__.py index 73e426c7..923162b0 100644 --- a/naplib/__init__.py +++ b/naplib/__init__.py @@ -56,5 +56,5 @@ def set_logging(level: Union[int, str]): from .data import Data, join_fields, concat import naplib.naplab -__version__ = "2.2.0" +__version__ = "2.3.0" diff --git a/naplib/localization/freesurfer.py b/naplib/localization/freesurfer.py index 81ea1340..72088aa1 100644 --- a/naplib/localization/freesurfer.py +++ b/naplib/localization/freesurfer.py @@ -111,6 +111,11 @@ } region2num = {v: k for k, v in num2region.items()} +temporal_regions_nums = [33, 34, 35, 36, 74, 41, 43, 72, 73, 38, 37, 75, 76, 77, 78, 79, 80, 81] +temporal_regions_superlist = [num2region[num] for num in temporal_regions_nums] +temporal_regions_superlist += ['alHG','pmHG','HG','TTS','PT','PP','MTG','ITG','mSTG','pSTG','STG','STS','T.Pole'] + + num2region_mni = { 0: 'unknown', 1: 'bankssts', @@ -766,6 +771,85 @@ def paint_overlay(self, labels, value=1): self.has_overlay[verts == 1] = True self.has_overlay_cells[add_overlay == 1] = True return self + + def interpolate_electrodes_onto_brain(self, coords, values, k, max_dist, roi='all'): + """ + Use electrode coordinates to interpolate 1-dimensional values corresponding + to each electrode onto the brain's surface. + + Parameters + ---------- + coords : np.ndarray (elecs, 3) + 3D coordinates of electrodes + values : np.ndarray (elecs,) + Value for each electrode + k : int + Number of nearest neighbors to consider + max_dist : float + Maximum distance outside of which nearest neighbors will be ignored + roi : list of strings, or string in {'all', 'temporal'}, default='all' + Regions to allow interpolation over. By default, the entire brain surface + is allowed. Can also be specified as a list of string labels (drawing from self.label_names) + + Notes + ----- + After running this function, you can use the visualization function ``plot_brain_overlay`` + for a quick matplotlib plot, or you can extract the surface values from the ``self.overlay`` + attribute for plotting with another tool like pysurfer. + """ + + if isinstance(roi, str) and roi == 'all': + roi_list = self.label_names + elif isinstance(roi, str) and roi == 'temporal': + if self.atlas == 'MNI152': + raise ValueError("roi='temporal' is not supported for MNI brain. Must specify list of specific region names") + roi_list = temporal_regions_superlist + else: + roi_list = roi + assert isinstance(roi, list) + + roi_list_subset = [x for x in roi_list if x in self.label_names] + zones_to_include, _, _ = self.zones(roi_list_subset) + + # Euclidean distances from each surface vertex to each coordinate + dists = cdist(self.surf[0], coords) + sorted_dists = np.sort(dists, axis=-1)[:, :k] + indices = np.argsort(dists, axis=-1)[:, :k] # get closest k electrodes to each vertex + + # Mask out distances greater than max_dist + valid_mask = sorted_dists <= max_dist + + # Retrieve the corresponding values using indices + neighbor_values = values[indices] + + # Mask invalid values + masked_values = np.where(valid_mask, neighbor_values, np.nan) + masked_distances = np.where(valid_mask, sorted_dists, np.nan) + + # Compute weights: inverse distance weighting (avoiding division by zero) + weights = np.where(valid_mask, 1 / (masked_distances + 1e-10), 0) + + # # Compute weighted sum and normalize by total weight per vertex + weighted_sum = np.nansum(masked_values * weights, axis=1) + total_weight = np.nansum(weights, axis=1) + + # # Normalize to get final smoothed values + updated_vertices = np.logical_and(total_weight > 0, zones_to_include) + total_weight[~updated_vertices] += 1e-10 # this just gets ride of the division by zero warning, but doesn't affect result since these values are turned to nan anyway + smoothed_values = np.where(updated_vertices, weighted_sum / total_weight, np.nan) + + # update the surface vertices and triangle attributes with the values + verts = updated_vertices.astype('float') + trigs = np.zeros(self.n_trigs, dtype=float) + for i in range(self.n_trigs): + trigs[i] = np.mean([verts[self.trigs[i, j]] != 0 for j in range(3)]) + + self.overlay[updated_vertices] = smoothed_values[updated_vertices] + self.has_overlay[updated_vertices] = True + self.has_overlay_cells[trigs == 1] = True + + return self + def mark_overlay(self, verts, value=1, inner_radius=0.8, taper=True): """ @@ -801,6 +885,11 @@ def set_visible(self, labels, min_alpha=0): self.keep_visible_cells = self.alpha > min_alpha self.alpha = np.maximum(self.alpha, min_alpha) return self + + def reset_overlay_except(self, labels): + keep_visible, self.alpha, _ = self.zones(labels, min_alpha=0) + self.overlay[~keep_visible] = 0 + return self class Brain: @@ -1134,7 +1223,7 @@ def mark_overlay(self, verts, isleft, value=1, inner_radius=0.8, taper=True): def set_visible(self, labels, min_alpha=0): """ - Set certain regions as visible with a float label. + Set certain regions as visible with a float label, and the rest will be invisible. Parameters ---------- @@ -1150,6 +1239,61 @@ def set_visible(self, labels, min_alpha=0): self.lh.set_visible(labels, min_alpha) self.rh.set_visible(labels, min_alpha) return self + + def reset_overlay_except(self, labels): + """ + Keep certain regions and the rest as colorless. + + Parameters + ---------- + labels : str | list[str] + Label(s) to set as visible. + + Returns + ------- + self : instance of self + """ + self.lh.reset_overlay_except(labels) + self.rh.reset_overlay_except(labels) + return self + + def interpolate_electrodes_onto_brain(self, coords, values, isleft=None, k=10, max_dist=10, roi='all', reset_overlay_first=True): + """ + Use electrode coordinates to interpolate 1-dimensional values corresponding + to each electrode onto the brain's surface. + + Parameters + ---------- + coords : np.ndarray (elecs, 3) + 3D coordinates of electrodes + values : np.ndarray (elecs,) + Value for each electrode + isleft : np.ndarray (elecs,), optional + If provided, specifies a boolean which is True for each electrode that is in the left hemisphere. + If not given, this will be inferred from the first dimension of the coords (negative is left). + k : int, default=10 + Number of nearest neighbors to consider + max_dist : float, default=10 + Maximum distance (in mm) outside of which nearest neighbors will be ignored + roi : list of strings, or string in {'all', 'temporal'}, default='all' + Regions to allow interpolation over. By default, the entire brain surface + is allowed. Can also be specified as a list of string labels (drawing from self.lh.label_names) + reset_overlay_first : bool, default=True + If True (default), reset the overlay before creating a new overlay + + Notes + ----- + After running this function, you can use the visualization function ``plot_brain_overlay`` + for a quick matplotlib plot, or you can extract the surface values from the ``self.lh.overlay`` + and ``self.rh.overlay`` attributes, etc, for plotting with another tool like pysurfer or plotly. + """ + if reset_overlay_first: + self.reset_overlay() + if isleft is None: + isleft = coords[:,0] < 0 + self.lh.interpolate_electrodes_onto_brain(coords[isleft], values[isleft], k=k, max_dist=max_dist, roi=roi) + self.rh.interpolate_electrodes_onto_brain(coords[~isleft], values[~isleft], k=k, max_dist=max_dist, roi=roi) + return self def get_nearest_vert_index(coords, isleft, surf_lh, surf_rh, verbose=False): @@ -1190,4 +1334,3 @@ def find_closest_vertices(surface_coords, point_coords): point_coords = np.atleast_2d(point_coords) dists = cdist(surface_coords, point_coords) return np.argmin(dists, axis=0), np.min(dists, axis=0) - diff --git a/naplib/utils/surfdist.py b/naplib/utils/surfdist.py index 8fef3c2b..0db54a93 100644 --- a/naplib/utils/surfdist.py +++ b/naplib/utils/surfdist.py @@ -126,6 +126,8 @@ def surfdist_viz( bg_on_stat=False, figsize=None, ax=None, + vmin=None, + vmax=None, ): """Visualize results on cortical surface using matplotlib. @@ -232,8 +234,10 @@ def surfdist_viz( # Ensure symmetric colour range, based on Nilearn helper function: # https://github.com/nilearn/nilearn/blob/master/nilearn/plotting/img_plotting.py#L52 - vmax = max(-np.nanmin(stat_map_faces), np.nanmax(stat_map_faces)) - vmin = -vmax + if vmax is None: + vmax = max(-np.nanmin(stat_map_faces), np.nanmax(stat_map_faces)) + if vmin is None: + vmin = -vmax if threshold is not None: kept_indices = np.where(abs(stat_map_faces) >= threshold)[0] diff --git a/naplib/visualization/brain_plots.py b/naplib/visualization/brain_plots.py index 1b5c3cde..17dd6560 100644 --- a/naplib/visualization/brain_plots.py +++ b/naplib/visualization/brain_plots.py @@ -74,24 +74,26 @@ def _view(hemi, mode: str = "lateral", backend: str = "mpl"): raise ValueError(f"Unknown `mode`: {mode}.") -def _plot_hemi(hemi, cmap="coolwarm", ax=None, denorm=False, view="best"): +def _plot_hemi(hemi, cmap="coolwarm", ax=None, view="best", thresh=None, vmin=None, vmax=None): surfdist_viz( *hemi.surf, hemi.overlay, *_view(hemi.hemi, mode=view), - cmap=cmap(hemi.overlay.max()) if denorm else cmap, - threshold=0.25, + cmap=cmap, + threshold=thresh, alpha=hemi.alpha, bg_map=hemi.sulc, bg_on_stat=True, ax=ax, + vmin=vmin, + vmax=vmax ) ax.axes.set_axis_off() ax.grid(False) def plot_brain_overlay( - brain, cmap="coolwarm", ax=None, denorm=False, view="best", **kwargs + brain, cmap="coolwarm", ax=None, hemi='both', denorm=False, view="best", cmap_quantile=1.0, **kwargs ): """ Plot brain overlay on the 3D cortical surface using matplotlib. @@ -106,12 +108,21 @@ def plot_brain_overlay( Colormap to use. ax : list | tuple of matplotlib Axes 2 Axes to plot the left and right hemispheres with. + hemi : {'both', 'lh', 'rh'}, default='both' + Hemisphere(s) to plot. If 'both', then 2 subplots are created, one for each hemisphere. + Otherwise only one hemisphere is displayed with its overlay. denorm : bool, default=False Whether to center the overlay labels around 0 or not before sending to the colormap. view : {'lateral','medial','frontal','top','best'}, default='best' Which view to plot for each hemisphere. + cmap_quantile : float | tuple of floats (optional), default=1.0 + If a single float less than 1, will only use the central ``cmap_quantile`` portion of the range + of values to create the vmin and vmax for the colormap. For example, if set to 0.95, + then only the middle 95% of the values will be used to set the range of the colormap. If a tuple, + then it should specify 2 quantiles, one for the vmin and one for the vmax, such as (0.025, 0.975), + which would be equivalent to passing a single value of 0.95. **kwargs : kwargs - Any other kwargs to pass to matplotlib.pyplot.figure + Any other kwargs to pass to matplotlib.pyplot.figure (such as figsize) Returns ------- @@ -121,14 +132,59 @@ def plot_brain_overlay( """ fig = plt.figure(**kwargs) if ax is None: - ax1 = fig.add_subplot(1, 2, 1, projection="3d") - ax2 = fig.add_subplot(1, 2, 2, projection="3d") - ax = (ax1, ax2) + if hemi in ['both', 'b']: + ax1 = fig.add_subplot(1, 2, 1, projection="3d") + ax2 = fig.add_subplot(1, 2, 2, projection="3d") + ax = (ax1, ax2) + else: + ax = fig.add_subplot(1, 1, 1, projection="3d") + if hemi in ['left','lh']: + ax = [ax, None] + elif hemi in ['right','rh']: + ax = [None, ax] else: - ax1, ax2 = ax + if hemi in ['both', 'b']: + assert len(ax) == 2 - _plot_hemi(brain.lh, cmap, ax1, denorm, view=view) - _plot_hemi(brain.rh, cmap, ax2, denorm, view=view) + + if cmap_quantile is not None: + if isinstance(cmap_quantile, float): + assert cmap_quantile <= 1 and cmap_quantile > 0 + cmap_diff = (1.0 - cmap_quantile) / 2. + vmin_l = np.quantile(brain.lh.overlay[brain.lh.overlay!=0], cmap_diff) + vmax_l = np.quantile(brain.lh.overlay[brain.lh.overlay!=0], 1.0 - cmap_diff) + vmin_r = np.quantile(brain.rh.overlay[brain.rh.overlay!=0], cmap_diff) + vmax_r = np.quantile(brain.rh.overlay[brain.rh.overlay!=0], 1.0 - cmap_diff) + elif isinstance(cmap_quantile, tuple): + vmin_l = np.quantile(brain.lh.overlay[brain.lh.overlay!=0], cmap_quantile[0]) + vmax_l = np.quantile(brain.lh.overlay[brain.lh.overlay!=0], cmap_quantile[1]) + vmin_r = np.quantile(brain.rh.overlay[brain.rh.overlay!=0], cmap_quantile[0]) + vmax_r = np.quantile(brain.rh.overlay[brain.rh.overlay!=0], cmap_quantile[1]) + else: + raise ValueError('cmap_quantile must be either a float or a tuple') + else: + vmin_l = brain.lh.overlay[brain.lh.overlay!=0].min() + vmax_l = brain.lh.overlay[brain.lh.overlay!=0].max() + vmin_r = brain.rh.overlay[brain.rh.overlay!=0].min() + vmax_r = brain.rh.overlay[brain.rh.overlay!=0].max() + + + # determine vmin and vmax + if hemi in ['both', 'b']: + vmin = min([vmin_l, vmin_r]) + vmax = max([vmax_l, vmax_r]) + elif hemi in ['left','lh']: + vmin = vmin_l + vmax = vmax_l + elif hemi in ['right','rh']: + vmin = vmin_r + vmax = vmax_r + + + if ax[0] is not None: + _plot_hemi(brain.lh, cmap, ax[0], view=view, vmin=vmin, vmax=vmax) + if ax[1] is not None: + _plot_hemi(brain.rh, cmap, ax[1], view=view, vmin=vmin, vmax=vmax) return fig, ax