Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion flow_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,8 @@ def flow_field(
post_patch_size: int | Sequence[int] | None = None,
pre_targeting_field: np.ndarray | None = None,
pre_targeting_step: int | Sequence[int] | None = None,
post_targeting_field: np.ndarray | None = None,
post_targeting_step: int | Sequence[int] | None = None,
progress_fn: Callable[[list[T]], Iterator[T]] = _silent_fn,
):
"""Computes the flow field from post to pre.
Expand Down Expand Up @@ -516,6 +518,10 @@ def flow_field(
'flow_field'
pre_targeting_step: step size at which 'pre_targeting_field' values were
sampled (same units as 'step', yx order)
post_targeting_field: like 'pre_targeting_field', but for shifting the
'post_image' patches
post_targeting_step: step size at which 'post_targeting_field' values were
sampled (same units as 'step', yx order)
progress_fn: function taking a list of batches of 'post' z[yx] start
positions to process; can be used with tqdm to track progress

Expand Down Expand Up @@ -642,7 +648,36 @@ def flow_field(

pre_starts = pre_starts + tg_offsets

post_offsets = None
if post_targeting_field is not None and post_targeting_step is not None:
post_center = (np.array(post_patch_size) // 2).reshape((1, -1))
tg_step = np.array(post_targeting_step).reshape((1, -1))
query = np.round((post_starts + post_center) / tg_step)
query = query.astype(int) # [b, [z]yx]
q = []
for i in range(query.shape[-1]):
q.append(
np.clip(query[:, i], 0, post_targeting_field.shape[i + 1] - 1)
)

field_indexer = (slice(None),) + tuple(q)
post_offsets = np.nan_to_num((post_targeting_field[field_indexer].T))
post_offsets = post_offsets.astype(int)[
:, ::-1
] # [b, xy[z]] -> [b, [z]yx]
new_starts = post_starts + post_offsets

# Clip offsets that would cause the 'post' patch to go out of bounds.
post_offsets = post_offsets - np.minimum(new_starts, 0)
img_shape = np.array(post_image.shape)[None, ...]
new_ends = new_starts + np.array(post_patch_size)[None, ...]
overshoot = np.maximum(new_ends, img_shape) - img_shape
post_offsets = post_offsets - overshoot

post_starts = post_starts + post_offsets

pre_starts = np.clip(pre_starts, 0, np.inf).astype(int)
post_starts = np.clip(post_starts, 0, np.inf).astype(int)

logging.info('.. estimating %d patches.', len(pos_zyx))
peaks = np.array(
Expand All @@ -666,7 +701,11 @@ def flow_field(
for i, coord in enumerate(pos_zyx):
v = peaks[i]
if tg_offsets is not None:
v[:d] = v[:d] + tg_offsets[i, ::-1]
v[:d] = v[:d] + tg_offsets[i, ::-1] # xy[z]

if post_offsets is not None:
v[:d] = v[:d] - post_offsets[i, ::-1] # xy[z]

output[np.index_exp[:] + tuple(coord)] = v

logging.info('Flow field estimation complete.')
Expand Down
4 changes: 2 additions & 2 deletions mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,8 +522,8 @@ def fire_step(t, state):


def relax_mesh(
x: jnp.ndarray,
prev: jnp.ndarray | None,
x: jax.Array,
prev: jax.Array | None,
config: IntegrationConfig,
mesh_force=inplane_force,
prev_fn=None,
Expand Down
31 changes: 31 additions & 0 deletions tests/flow_field_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,37 @@ def test_jax_peak(self):
self.assertEqual(peaks[0, 2], peak_max / peak_support) # sharpness
self.assertEqual(peaks[0, 3], 0) # peak ratio

def test_post_targeting(self):
pre_image = np.zeros((120, 120), dtype=np.uint8)
post_image = np.zeros((120, 120), dtype=np.uint8)

pre_image[50, 55] = 255
post_image[100, 100] = 255

calculator = flow_field.JAXMaskedXCorrWithStatsCalculator()

# Without targeting, the features are too far apart to be picked up.
field = calculator.flow_field(
pre_image, post_image, patch_size=80, step=40, batch_size=4)
np.testing.assert_array_equal(np.isnan(field[:, 0, 0]), True)

post_targeting_field = np.full((2, 2, 2), 40.0, dtype=np.float32)

# With targeting, a flow field of magnitude larger than the
# normally possible max of patch_size // 2 can be estimated.
field = calculator.flow_field(
pre_image,
post_image,
patch_size=80,
step=40,
batch_size=4,
post_targeting_field=post_targeting_field,
post_targeting_step=40)

np.testing.assert_array_equal([4, 2, 2], field.shape)
np.testing.assert_array_equal(-45 * np.ones((2, 2)), field[0, ...])
np.testing.assert_array_equal(-50 * np.ones((2, 2)), field[1, ...])


if __name__ == '__main__':
absltest.main()