Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 94 additions & 19 deletions naplib/preprocessing/rereference.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _rereference(data_arr, method='avg', return_ref=False):
return data_rereferenced


def make_contact_rereference_arr(channelnames, extent=None):
def make_contact_rereference_arr(channelnames, extent=None, grid_sizes={}):
"""
Create grid which defines re-referencing scheme based on electrodes being on the same contact as
each other.
Expand All @@ -128,13 +128,17 @@ def make_contact_rereference_arr(channelnames, extent=None):
be alphanumeric, with any numbers only being on the right. 2) The numeric portion specifies a
different electrode number, while the character portion in the left of the channelname specifies the
contact name. E.g. ['RT1','RT2','RT3','Ls1','Ls2'] indicates two contacts, the first with 3 electrodes
and the second with 2 electrodes. 3) Electrodes from the same contact must be contiguous.
and the second with 2 electrodes.
extent : int, optional, default=None
If provided, then only contacts from the same group which are within ``extent`` electrodes away
from each other (inclusive) are still grouped together. Only used if ``method='contact'``. For
example, if ``extent=1``, only the nearest electrode on either side of a given electrode on the
same contact is still grouped with it. For example, extent=1 produces the traditional local
average reference scheme.
from each other (inclusive) are still grouped together. For example, if ``extent=1``, only the
nearest electrode on either side of a given electrode on the same contact is still grouped with it.
This ``extent=1`` produces the traditional local average reference scheme.
The default ``extent=None`` produces the traditional common average reference scheme.
grid_sizes : dict, optional, default={}
If provided, contains {'contact_name': (nrow, ncol)} values for any known ECoG grid sizes.
E.g. {'GridA': (8, 16)} indicates that electrodes on contact 'GridA' are arranged in an 8 x 16 grid,
which is needed to determine adjacent electrodes for local average referencing with ``extent >= 1``.

Returns
-------
Expand All @@ -145,18 +149,89 @@ def make_contact_rereference_arr(channelnames, extent=None):
--------
rereference
"""
contact_arrays = pd.Series([x.rstrip('0123456789') for x in channelnames])
connections = np.zeros((len(contact_arrays),) * 2, dtype=float)
for _, inds in contact_arrays.groupby(contact_arrays):
for i in inds.index:
connections[i, inds.index] = 1.0
def _find_adjacent_numbers(a, b, number, extent):
'''
Used to determine electrodes for local averaging ECoG grid"
'''
# Validate if the number is within the valid range
if number < 1 or number > a * b:
raise ValueError("The number is outside the range of the grid.")

# remove longer than extent if desired
if extent is not None:
if extent < 1:
raise ValueError(f'Invalid extent. Must be no less than 1 but got extent={extent}')
connections *= np.tri(*connections.shape, k=extent)
connections *= np.fliplr(np.flipud(np.tri(*connections.shape, k=extent)))
connections = connections

# Calculate the row and column of the given number
row = (number - 1) // b
col = (number - 1) % b

# Find all adjacent numbers within the extent
adjacent_numbers = []
for dr in range(-extent, extent + 1): # Rows within the extent
for dc in range(-extent, extent + 1): # Columns within the extent
if dr == 0 and dc == 0:
continue # Skip the number itself
new_row, new_col = row + dr, col + dc
if 0 <= new_row < a and 0 <= new_col < b:
adjacent_num = new_row * b + new_col + 1
adjacent_numbers.append(adjacent_num)

return np.array(adjacent_numbers, dtype=int)

connections = np.zeros((len(channelnames),) * 2, dtype=float)
channelnames = np.array(channelnames)
contact_arrays = np.array([x.rstrip('0123456789') for x in channelnames])
contacts = np.unique(contact_arrays)
# Determine the channel numbers on each contact
ch_per_contact = {contact:[int(x.replace(contact,'')) for x in channelnames
if x.rstrip('0123456789')==contact]
for contact in contacts}

if extent is None:
# Common average referencing per electrode array (ECoG grid or sEEG shank)
# CAR will end up subtracting parts of channel ch from itself
for contact in contacts:
for ch in ch_per_contact[contact]:
curr = np.where(channelnames==f'{contact}{ch}')[0]
inds = np.where(contact_arrays==contact)[0]
connections[curr,inds] = 1
elif extent < 1:
raise ValueError(f'Invalid extent. Must be no less than 1 but got extent={extent}')
else:
# Local average referencing within each electrode array
# LAR will NOT subtract parts of channel ch from itself
for contact in contacts:
for ch in ch_per_contact[contact]:
# Local referencing for ECoG grids
if 'grid' in contact.lower():
num_ch = len(ch_per_contact[contact])
side = np.sqrt(num_ch)
half_side = np.sqrt(num_ch/2)
# Check grid_sizes dict
if contact in grid_sizes:
nrows, ncols = grid_sizes[contact]
# Assume a square
elif np.isclose(side, int(side)):
nrows, ncols = side, side
# Assume a 1 x 2 rectangle
elif np.isclose(half_side, int(half_side)):
nrows, ncols = half_side, half_side*2
else:
raise Exception(f'Cannot determine {contact} layout. Please include layout in `grid_sizes`')
adjacent = _find_adjacent_numbers(nrows, ncols, ch, extent)
curr = np.where(channelnames==f'{contact}{ch}')[0]
inds = []
for adj in adjacent:
inds.append(np.where(channelnames==f'{contact}{adj}')[0])

# Local referencing for sEEG shanks and strips
else:
curr = np.where(channelnames==f'{contact}{ch}')[0]
inds = []
for cc in range(ch-extent, ch+extent+1):
if cc != ch:
inds.append(np.where(channelnames==f'{contact}{cc}')[0])

inds = np.concatenate(inds)
if len(inds) < 1:
print(f'{contact}{cc} has no re-references.')
else:
connections[curr,inds] = 1

return connections
25 changes: 22 additions & 3 deletions tests/preprocessing/test_rereference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,29 @@


def test_create_contact_rereference_arr():
expected = np.array([[1,1,0,0],[1,1,0,0],[0,0,1,1],[0,0,1,1]])
g = ['LT1','LT2','RT1','RT2']
arr = make_contact_rereference_arr(g)
expected = np.array([[0,1,0,0,0,0,0,0],
[1,0,0,0,0,0,0,0],
[0,0,0,1,0,0,0,0],
[0,0,1,0,0,0,0,0],
[0,0,0,0,0,1,1,1],
[0,0,0,0,1,0,1,1],
[0,0,0,0,1,1,0,1],
[0,0,0,0,1,1,1,0],
])
expected1 = np.array([[1,1,0,0,0,0,0,0],
[1,1,0,0,0,0,0,0],
[0,0,1,1,0,0,0,0],
[0,0,1,1,0,0,0,0],
[0,0,0,0,1,1,1,1],
[0,0,0,0,1,1,1,1],
[0,0,0,0,1,1,1,1],
[0,0,0,0,1,1,1,1],
])
g = ['LT1','LT2','GridA1','GridA2'] + [f'GridB{n}' for n in range(1,5)]
arr = make_contact_rereference_arr(g, extent=1, grid_sizes={'GridA':(1,2)})
arr1 = make_contact_rereference_arr(g)
assert np.allclose(expected, arr)
assert np.allclose(expected1, arr1)

def test_rereference_avg():
arr = np.array([[1,1,0,0],[1,1,0,0],[0,0,1,1],[0,0,1,1]])
Expand Down
Loading