diff --git a/naplib/preprocessing/rereference.py b/naplib/preprocessing/rereference.py index 97b5293d..6ebf8350 100644 --- a/naplib/preprocessing/rereference.py +++ b/naplib/preprocessing/rereference.py @@ -116,7 +116,7 @@ def _rereference(data_arr, method='avg', return_ref=False): return data_rereferenced -def make_contact_rereference_arr(channelnames, extent=None): +def make_contact_rereference_arr(channelnames, extent=None, grid_sizes={}): """ Create grid which defines re-referencing scheme based on electrodes being on the same contact as each other. @@ -128,13 +128,17 @@ def make_contact_rereference_arr(channelnames, extent=None): be alphanumeric, with any numbers only being on the right. 2) The numeric portion specifies a different electrode number, while the character portion in the left of the channelname specifies the contact name. E.g. ['RT1','RT2','RT3','Ls1','Ls2'] indicates two contacts, the first with 3 electrodes - and the second with 2 electrodes. 3) Electrodes from the same contact must be contiguous. + and the second with 2 electrodes. extent : int, optional, default=None If provided, then only contacts from the same group which are within ``extent`` electrodes away - from each other (inclusive) are still grouped together. Only used if ``method='contact'``. For - example, if ``extent=1``, only the nearest electrode on either side of a given electrode on the - same contact is still grouped with it. For example, extent=1 produces the traditional local - average reference scheme. + from each other (inclusive) are still grouped together. For example, if ``extent=1``, only the + nearest electrode on either side of a given electrode on the same contact is still grouped with it. + This ``extent=1`` produces the traditional local average reference scheme. + The default ``extent=None`` produces the traditional common average reference scheme. + grid_sizes : dict, optional, default={} + If provided, contains {'contact_name': (nrow, ncol)} values for any known ECoG grid sizes. + E.g. {'GridA': (8, 16)} indicates that electrodes on contact 'GridA' are arranged in an 8 x 16 grid, + which is needed to determine adjacent electrodes for local average referencing with ``extent >= 1``. Returns ------- @@ -145,18 +149,89 @@ def make_contact_rereference_arr(channelnames, extent=None): -------- rereference """ - contact_arrays = pd.Series([x.rstrip('0123456789') for x in channelnames]) - connections = np.zeros((len(contact_arrays),) * 2, dtype=float) - for _, inds in contact_arrays.groupby(contact_arrays): - for i in inds.index: - connections[i, inds.index] = 1.0 + def _find_adjacent_numbers(a, b, number, extent): + ''' + Used to determine electrodes for local averaging ECoG grid" + ''' + # Validate if the number is within the valid range + if number < 1 or number > a * b: + raise ValueError("The number is outside the range of the grid.") - # remove longer than extent if desired - if extent is not None: - if extent < 1: - raise ValueError(f'Invalid extent. Must be no less than 1 but got extent={extent}') - connections *= np.tri(*connections.shape, k=extent) - connections *= np.fliplr(np.flipud(np.tri(*connections.shape, k=extent))) - connections = connections - + # Calculate the row and column of the given number + row = (number - 1) // b + col = (number - 1) % b + + # Find all adjacent numbers within the extent + adjacent_numbers = [] + for dr in range(-extent, extent + 1): # Rows within the extent + for dc in range(-extent, extent + 1): # Columns within the extent + if dr == 0 and dc == 0: + continue # Skip the number itself + new_row, new_col = row + dr, col + dc + if 0 <= new_row < a and 0 <= new_col < b: + adjacent_num = new_row * b + new_col + 1 + adjacent_numbers.append(adjacent_num) + + return np.array(adjacent_numbers, dtype=int) + + connections = np.zeros((len(channelnames),) * 2, dtype=float) + channelnames = np.array(channelnames) + contact_arrays = np.array([x.rstrip('0123456789') for x in channelnames]) + contacts = np.unique(contact_arrays) + # Determine the channel numbers on each contact + ch_per_contact = {contact:[int(x.replace(contact,'')) for x in channelnames + if x.rstrip('0123456789')==contact] + for contact in contacts} + + if extent is None: + # Common average referencing per electrode array (ECoG grid or sEEG shank) + # CAR will end up subtracting parts of channel ch from itself + for contact in contacts: + for ch in ch_per_contact[contact]: + curr = np.where(channelnames==f'{contact}{ch}')[0] + inds = np.where(contact_arrays==contact)[0] + connections[curr,inds] = 1 + elif extent < 1: + raise ValueError(f'Invalid extent. Must be no less than 1 but got extent={extent}') + else: + # Local average referencing within each electrode array + # LAR will NOT subtract parts of channel ch from itself + for contact in contacts: + for ch in ch_per_contact[contact]: + # Local referencing for ECoG grids + if 'grid' in contact.lower(): + num_ch = len(ch_per_contact[contact]) + side = np.sqrt(num_ch) + half_side = np.sqrt(num_ch/2) + # Check grid_sizes dict + if contact in grid_sizes: + nrows, ncols = grid_sizes[contact] + # Assume a square + elif np.isclose(side, int(side)): + nrows, ncols = side, side + # Assume a 1 x 2 rectangle + elif np.isclose(half_side, int(half_side)): + nrows, ncols = half_side, half_side*2 + else: + raise Exception(f'Cannot determine {contact} layout. Please include layout in `grid_sizes`') + adjacent = _find_adjacent_numbers(nrows, ncols, ch, extent) + curr = np.where(channelnames==f'{contact}{ch}')[0] + inds = [] + for adj in adjacent: + inds.append(np.where(channelnames==f'{contact}{adj}')[0]) + + # Local referencing for sEEG shanks and strips + else: + curr = np.where(channelnames==f'{contact}{ch}')[0] + inds = [] + for cc in range(ch-extent, ch+extent+1): + if cc != ch: + inds.append(np.where(channelnames==f'{contact}{cc}')[0]) + + inds = np.concatenate(inds) + if len(inds) < 1: + print(f'{contact}{cc} has no re-references.') + else: + connections[curr,inds] = 1 + return connections diff --git a/tests/preprocessing/test_rereference.py b/tests/preprocessing/test_rereference.py index d9fa02a5..39c4bfa3 100644 --- a/tests/preprocessing/test_rereference.py +++ b/tests/preprocessing/test_rereference.py @@ -4,10 +4,29 @@ def test_create_contact_rereference_arr(): - expected = np.array([[1,1,0,0],[1,1,0,0],[0,0,1,1],[0,0,1,1]]) - g = ['LT1','LT2','RT1','RT2'] - arr = make_contact_rereference_arr(g) + expected = np.array([[0,1,0,0,0,0,0,0], + [1,0,0,0,0,0,0,0], + [0,0,0,1,0,0,0,0], + [0,0,1,0,0,0,0,0], + [0,0,0,0,0,1,1,1], + [0,0,0,0,1,0,1,1], + [0,0,0,0,1,1,0,1], + [0,0,0,0,1,1,1,0], + ]) + expected1 = np.array([[1,1,0,0,0,0,0,0], + [1,1,0,0,0,0,0,0], + [0,0,1,1,0,0,0,0], + [0,0,1,1,0,0,0,0], + [0,0,0,0,1,1,1,1], + [0,0,0,0,1,1,1,1], + [0,0,0,0,1,1,1,1], + [0,0,0,0,1,1,1,1], + ]) + g = ['LT1','LT2','GridA1','GridA2'] + [f'GridB{n}' for n in range(1,5)] + arr = make_contact_rereference_arr(g, extent=1, grid_sizes={'GridA':(1,2)}) + arr1 = make_contact_rereference_arr(g) assert np.allclose(expected, arr) + assert np.allclose(expected1, arr1) def test_rereference_avg(): arr = np.array([[1,1,0,0],[1,1,0,0],[0,0,1,1],[0,0,1,1]])