From cb199b60b43f172e7598b397d23ecd3a7d4c4c72 Mon Sep 17 00:00:00 2001 From: Michal Januszewski Date: Mon, 12 Jan 2026 03:07:56 -0800 Subject: [PATCH] Add support for post_targeting_field in flow estimation. PiperOrigin-RevId: 855147310 --- flow_field.py | 41 +++++++++++++++++++++++++++++++++++++++- mesh.py | 4 ++-- tests/flow_field_test.py | 31 ++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 3 deletions(-) diff --git a/flow_field.py b/flow_field.py index 94f0788..da565d1 100644 --- a/flow_field.py +++ b/flow_field.py @@ -486,6 +486,8 @@ def flow_field( post_patch_size: int | Sequence[int] | None = None, pre_targeting_field: np.ndarray | None = None, pre_targeting_step: int | Sequence[int] | None = None, + post_targeting_field: np.ndarray | None = None, + post_targeting_step: int | Sequence[int] | None = None, progress_fn: Callable[[list[T]], Iterator[T]] = _silent_fn, ): """Computes the flow field from post to pre. @@ -516,6 +518,10 @@ def flow_field( 'flow_field' pre_targeting_step: step size at which 'pre_targeting_field' values were sampled (same units as 'step', yx order) + post_targeting_field: like 'pre_targeting_field', but for shifting the + 'post_image' patches + post_targeting_step: step size at which 'post_targeting_field' values were + sampled (same units as 'step', yx order) progress_fn: function taking a list of batches of 'post' z[yx] start positions to process; can be used with tqdm to track progress @@ -642,7 +648,36 @@ def flow_field( pre_starts = pre_starts + tg_offsets + post_offsets = None + if post_targeting_field is not None and post_targeting_step is not None: + post_center = (np.array(post_patch_size) // 2).reshape((1, -1)) + tg_step = np.array(post_targeting_step).reshape((1, -1)) + query = np.round((post_starts + post_center) / tg_step) + query = query.astype(int) # [b, [z]yx] + q = [] + for i in range(query.shape[-1]): + q.append( + np.clip(query[:, i], 0, post_targeting_field.shape[i + 1] - 1) + ) + + field_indexer = (slice(None),) + tuple(q) + post_offsets = np.nan_to_num((post_targeting_field[field_indexer].T)) + post_offsets = post_offsets.astype(int)[ + :, ::-1 + ] # [b, xy[z]] -> [b, [z]yx] + new_starts = post_starts + post_offsets + + # Clip offsets that would cause the 'post' patch to go out of bounds. + post_offsets = post_offsets - np.minimum(new_starts, 0) + img_shape = np.array(post_image.shape)[None, ...] + new_ends = new_starts + np.array(post_patch_size)[None, ...] + overshoot = np.maximum(new_ends, img_shape) - img_shape + post_offsets = post_offsets - overshoot + + post_starts = post_starts + post_offsets + pre_starts = np.clip(pre_starts, 0, np.inf).astype(int) + post_starts = np.clip(post_starts, 0, np.inf).astype(int) logging.info('.. estimating %d patches.', len(pos_zyx)) peaks = np.array( @@ -666,7 +701,11 @@ def flow_field( for i, coord in enumerate(pos_zyx): v = peaks[i] if tg_offsets is not None: - v[:d] = v[:d] + tg_offsets[i, ::-1] + v[:d] = v[:d] + tg_offsets[i, ::-1] # xy[z] + + if post_offsets is not None: + v[:d] = v[:d] - post_offsets[i, ::-1] # xy[z] + output[np.index_exp[:] + tuple(coord)] = v logging.info('Flow field estimation complete.') diff --git a/mesh.py b/mesh.py index b0f616f..432f6f1 100644 --- a/mesh.py +++ b/mesh.py @@ -522,8 +522,8 @@ def fire_step(t, state): def relax_mesh( - x: jnp.ndarray, - prev: jnp.ndarray | None, + x: jax.Array, + prev: jax.Array | None, config: IntegrationConfig, mesh_force=inplane_force, prev_fn=None, diff --git a/tests/flow_field_test.py b/tests/flow_field_test.py index 07f6088..500a723 100644 --- a/tests/flow_field_test.py +++ b/tests/flow_field_test.py @@ -93,6 +93,37 @@ def test_jax_peak(self): self.assertEqual(peaks[0, 2], peak_max / peak_support) # sharpness self.assertEqual(peaks[0, 3], 0) # peak ratio + def test_post_targeting(self): + pre_image = np.zeros((120, 120), dtype=np.uint8) + post_image = np.zeros((120, 120), dtype=np.uint8) + + pre_image[50, 55] = 255 + post_image[100, 100] = 255 + + calculator = flow_field.JAXMaskedXCorrWithStatsCalculator() + + # Without targeting, the features are too far apart to be picked up. + field = calculator.flow_field( + pre_image, post_image, patch_size=80, step=40, batch_size=4) + np.testing.assert_array_equal(np.isnan(field[:, 0, 0]), True) + + post_targeting_field = np.full((2, 2, 2), 40.0, dtype=np.float32) + + # With targeting, a flow field of magnitude larger than the + # normally possible max of patch_size // 2 can be estimated. + field = calculator.flow_field( + pre_image, + post_image, + patch_size=80, + step=40, + batch_size=4, + post_targeting_field=post_targeting_field, + post_targeting_step=40) + + np.testing.assert_array_equal([4, 2, 2], field.shape) + np.testing.assert_array_equal(-45 * np.ones((2, 2)), field[0, ...]) + np.testing.assert_array_equal(-50 * np.ones((2, 2)), field[1, ...]) + if __name__ == '__main__': absltest.main()