Add adaptive pooling (1D, 2D, 3D) support across JAX, TensorFlow, and PyTorch backends #21820
@@ -1464,3 +1464,368 @@ def _pair(x):
    # ---- reshape -> (N, C*kH*kW, L) ----
    _, CKK, oH, oW = patches.shape
    return patches.reshape(N, CKK, oH * oW)


def get_static_window_sizes(input_dim, output_dim):
    """Calculate small and big window sizes for adaptive pooling."""
    small_window = math.ceil(input_dim / output_dim)
    big_window = small_window + 1
    return small_window, big_window


def compute_static_gather_indices(input_dim, output_size, big_window):
    """Compute gather indices for Two-Pool Gather method."""
    window_starts = jnp.floor(
        (jnp.arange(output_size) * input_dim) / output_size
    ).astype(jnp.int32)
    window_ends = jnp.ceil(
        (jnp.arange(1, output_size + 1) * input_dim) / output_size
    ).astype(jnp.int32)

    window_sizes = window_ends - window_starts
    is_big_window = window_sizes == big_window

    small_window = big_window - 1
    small_pool_len = input_dim - small_window + 1

    small_indices = window_starts
    big_indices = window_starts + small_pool_len

    gather_indices = jnp.where(is_big_window, big_indices, small_indices)
    return gather_indices.astype(jnp.int32)

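# --- Editor's illustration, not part of this diff: a worked example of the
# Two-Pool Gather indexing using the two helpers above. For an input of
# length 5 pooled to 3 bins, the adaptive windows are [0:2], [1:4], [3:5];
# only the middle window has the big size. The helper name below is
# hypothetical and exists only for this sketch.
def _example_two_pool_gather_indices():
    small, big = get_static_window_sizes(5, 3)  # -> (2, 3)
    idx = compute_static_gather_indices(5, 3, big)
    # The small pool has 5 - 2 + 1 = 4 valid positions (indices 0..3) and
    # the big pool is appended after them, so idx == [0, 5, 3]: output 0
    # gathers the small window at 0, output 1 the big window at 1, and
    # output 2 the small window at 3.
    return small, big, idx
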
# ---------- 1D Adaptive Pooling ----------
def adaptive_avg_pool1d(inputs, output_size, data_format="channels_first"):
    """Adaptive Average Pooling 1D using Two-Pool Gather method."""
    if isinstance(output_size, int):
        output_size = (output_size,)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 1))  # NCL -> NLC

    n, l, c = inputs.shape
    out_l = output_size[0]

[Contributor review comment on lines +1504 to +1505 about the variable names, with an example and Style Guide references.]
    small_l, big_l = get_static_window_sizes(l, out_l)
    gather_l = compute_static_gather_indices(l, out_l, big_l)

    small_pool_l = lax.reduce_window(
        inputs, 0.0, lax.add, (1, small_l, 1), (1, 1, 1), "valid"
    )
    small_pool_l = small_pool_l / small_l

    big_pool_l = lax.reduce_window(
        inputs, 0.0, lax.add, (1, big_l, 1), (1, 1, 1), "valid"
    )
    big_pool_l = big_pool_l / big_l

    combined_l = jnp.concatenate([small_pool_l, big_pool_l], axis=1)
    pooled_l = jnp.take(combined_l, gather_l, axis=1)

    if data_format == "channels_first":
        pooled_l = jnp.transpose(pooled_l, (0, 2, 1))  # NLC -> NCL

    return pooled_l

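# --- Editor's sketch, not part of this diff: expected shape behaviour of
# adaptive_avg_pool1d above. The example function name is hypothetical.
def _example_adaptive_avg_pool1d():
    x = jnp.arange(30.0).reshape(2, 3, 5)  # (N, C, L), channels_first
    y = adaptive_avg_pool1d(x, 3, data_format="channels_first")
    # Length 5 pooled to 3 bins via windows [0:2], [1:4], [3:5].
    assert y.shape == (2, 3, 3)
    return y
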
def adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
    """Adaptive Max Pooling 1D using Two-Pool Gather method."""
    if isinstance(output_size, int):
        output_size = (output_size,)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 1))  # NCL -> NLC

    n, l, c = inputs.shape
    out_l = output_size[0]

    small_l, big_l = get_static_window_sizes(l, out_l)
    gather_l = compute_static_gather_indices(l, out_l, big_l)

    small_pool_l = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, small_l, 1), (1, 1, 1), "valid"
    )
    big_pool_l = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, big_l, 1), (1, 1, 1), "valid"
    )

    combined_l = jnp.concatenate([small_pool_l, big_pool_l], axis=1)
    pooled_l = jnp.take(combined_l, gather_l, axis=1)

    if data_format == "channels_first":
        pooled_l = jnp.transpose(pooled_l, (0, 2, 1))  # NLC -> NCL

    return pooled_l

# ---------- 2D Adaptive Pooling ----------
def adaptive_avg_pool2d(inputs, output_size, data_format="channels_first"):
    """Adaptive Average Pooling 2D using Two-Pool Gather method."""
    if isinstance(output_size, int):
        output_size = (output_size, output_size)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 3, 1))  # NCHW -> NHWC

    n, h, w, c = inputs.shape
    out_h, out_w = output_size

    small_h, big_h = get_static_window_sizes(h, out_h)
    gather_h = compute_static_gather_indices(h, out_h, big_h)

    small_w, big_w = get_static_window_sizes(w, out_w)
    gather_w = compute_static_gather_indices(w, out_w, big_w)

    small_pool_h = lax.reduce_window(
        inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
    )
    small_pool_h = small_pool_h / small_h

    big_pool_h = lax.reduce_window(
        inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
    )
    big_pool_h = big_pool_h / big_h

    combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=1)
    pooled_h = jnp.take(combined_h, gather_h, axis=1)

    small_pool_w = lax.reduce_window(
        pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
    )
    small_pool_w = small_pool_w / small_w

    big_pool_w = lax.reduce_window(
        pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
    )
    big_pool_w = big_pool_w / big_w

    combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=2)
    pooled_w = jnp.take(combined_w, gather_w, axis=2)

    if data_format == "channels_first":
        pooled_w = jnp.transpose(pooled_w, (0, 3, 1, 2))  # NHWC -> NCHW

    return pooled_w

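# --- Editor's note, not part of this diff: adaptive_avg_pool2d above pools
# over H first and then over W. This reproduces the full-window average
# because every entry entering a given W-window was averaged over the same
# number of rows (small_h or big_h), so the mean of those means equals the
# mean of the full block. A minimal check (hypothetical helper name):
def _example_separable_average():
    block = jnp.arange(12.0).reshape(3, 4)  # one 3x4 pooling window
    col_means = block.mean(axis=0)  # average over H first
    assert jnp.allclose(col_means.mean(), block.mean())
    return col_means
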
def adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
    """Adaptive Max Pooling 2D using Two-Pool Gather method."""
    if isinstance(output_size, int):
        output_size = (output_size, output_size)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 3, 1))  # NCHW -> NHWC

    n, h, w, c = inputs.shape
    out_h, out_w = output_size

    small_h, big_h = get_static_window_sizes(h, out_h)
    gather_h = compute_static_gather_indices(h, out_h, big_h)

    small_w, big_w = get_static_window_sizes(w, out_w)
    gather_w = compute_static_gather_indices(w, out_w, big_w)

    small_pool_h = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
    )
    big_pool_h = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
    )

    combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=1)
    pooled_h = jnp.take(combined_h, gather_h, axis=1)

    small_pool_w = lax.reduce_window(
        pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
    )
    big_pool_w = lax.reduce_window(
        pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
    )

    combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=2)
    pooled_w = jnp.take(combined_w, gather_w, axis=2)

    if data_format == "channels_first":
        pooled_w = jnp.transpose(pooled_w, (0, 3, 1, 2))  # NHWC -> NCHW

    return pooled_w

# ---------- 3D Adaptive Pooling ----------
def adaptive_avg_pool3d(inputs, output_size, data_format="channels_first"):
    """Adaptive Average Pooling 3D using Two-Pool Gather method."""
    if isinstance(output_size, int):
        output_size = (output_size, output_size, output_size)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))  # NCDHW -> NDHWC

    n, d, h, w, c = inputs.shape
    out_d, out_h, out_w = output_size

    small_d, big_d = get_static_window_sizes(d, out_d)
    gather_d = compute_static_gather_indices(d, out_d, big_d)

    small_h, big_h = get_static_window_sizes(h, out_h)
    gather_h = compute_static_gather_indices(h, out_h, big_h)

    small_w, big_w = get_static_window_sizes(w, out_w)
    gather_w = compute_static_gather_indices(w, out_w, big_w)

    small_pool_d = lax.reduce_window(
        inputs, 0.0, lax.add, (1, small_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
    )
    small_pool_d = small_pool_d / small_d

    big_pool_d = lax.reduce_window(
        inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
    )
    big_pool_d = big_pool_d / big_d

    combined_d = jnp.concatenate([small_pool_d, big_pool_d], axis=1)
    pooled_d = jnp.take(combined_d, gather_d, axis=1)

    small_pool_h = lax.reduce_window(
        pooled_d, 0.0, lax.add, (1, 1, small_h, 1, 1), (1, 1, 1, 1, 1), "valid"
    )
    small_pool_h = small_pool_h / small_h

    big_pool_h = lax.reduce_window(
        pooled_d, 0.0, lax.add, (1, 1, big_h, 1, 1), (1, 1, 1, 1, 1), "valid"
    )
    big_pool_h = big_pool_h / big_h

    combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=2)
    pooled_h = jnp.take(combined_h, gather_h, axis=2)

    small_pool_w = lax.reduce_window(
        pooled_h, 0.0, lax.add, (1, 1, 1, small_w, 1), (1, 1, 1, 1, 1), "valid"
    )
    small_pool_w = small_pool_w / small_w

    big_pool_w = lax.reduce_window(
        pooled_h, 0.0, lax.add, (1, 1, 1, big_w, 1), (1, 1, 1, 1, 1), "valid"
    )
    big_pool_w = big_pool_w / big_w

    combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=3)
    pooled_w = jnp.take(combined_w, gather_w, axis=3)

    if data_format == "channels_first":
        pooled_w = jnp.transpose(pooled_w, (0, 4, 1, 2, 3))  # NDHWC -> NCDHW

    return pooled_w

def adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
    """Adaptive Max Pooling 3D using Two-Pool Gather method."""
    if isinstance(output_size, int):
        output_size = (output_size, output_size, output_size)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))  # NCDHW -> NDHWC

    n, d, h, w, c = inputs.shape
    out_d, out_h, out_w = output_size

    small_d, big_d = get_static_window_sizes(d, out_d)
    gather_d = compute_static_gather_indices(d, out_d, big_d)

    small_h, big_h = get_static_window_sizes(h, out_h)
    gather_h = compute_static_gather_indices(h, out_h, big_h)

    small_w, big_w = get_static_window_sizes(w, out_w)
    gather_w = compute_static_gather_indices(w, out_w, big_w)

    small_pool_d = lax.reduce_window(
        inputs,
        -jnp.inf,
        lax.max,
        (1, small_d, 1, 1, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )
    big_pool_d = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
    )

    combined_d = jnp.concatenate([small_pool_d, big_pool_d], axis=1)
    pooled_d = jnp.take(combined_d, gather_d, axis=1)

    small_pool_h = lax.reduce_window(
        pooled_d,
        -jnp.inf,
        lax.max,
        (1, 1, small_h, 1, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )
    big_pool_h = lax.reduce_window(
        pooled_d,
        -jnp.inf,
        lax.max,
        (1, 1, big_h, 1, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )

    combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=2)
    pooled_h = jnp.take(combined_h, gather_h, axis=2)

    small_pool_w = lax.reduce_window(
        pooled_h,
        -jnp.inf,
        lax.max,
        (1, 1, 1, small_w, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )
    big_pool_w = lax.reduce_window(
        pooled_h,
        -jnp.inf,
        lax.max,
        (1, 1, 1, big_w, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )

    combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=3)
    pooled_w = jnp.take(combined_w, gather_w, axis=3)

    if data_format == "channels_first":
        pooled_w = jnp.transpose(pooled_w, (0, 4, 1, 2, 3))  # NDHWC -> NCDHW

    return pooled_w

# ---------- Dispatcher ----------
def adaptive_avg_pool(inputs, output_size, data_format="channels_first"):
    """Dispatcher for adaptive average pooling (1D, 2D, or 3D)."""
    ndims = inputs.ndim - 2
    if ndims == 1:
        return adaptive_avg_pool1d(inputs, output_size, data_format)
    elif ndims == 2:
        return adaptive_avg_pool2d(inputs, output_size, data_format)
    elif ndims == 3:
        return adaptive_avg_pool3d(inputs, output_size, data_format)
    else:
        raise ValueError(
            "adaptive_avg_pool supports 1D, 2D, or 3D inputs only."
        )


def adaptive_max_pool(inputs, output_size, data_format="channels_first"):
    """Dispatcher for adaptive max pooling (1D, 2D, or 3D)."""
    ndims = inputs.ndim - 2
    if ndims == 1:
        return adaptive_max_pool1d(inputs, output_size, data_format)
    elif ndims == 2:
        return adaptive_max_pool2d(inputs, output_size, data_format)
    elif ndims == 3:
        return adaptive_max_pool3d(inputs, output_size, data_format)
    else:
        raise ValueError(
            "adaptive_max_pool supports 1D, 2D, or 3D inputs only."
        )

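# --- Editor's sketch, not part of this diff: the dispatcher above routes on
# rank, so a 4-D channels_first input goes to the 2-D implementation. The
# example function name is hypothetical.
def _example_adaptive_avg_pool_dispatch():
    x = jnp.ones((1, 3, 32, 32))  # NCHW
    y = adaptive_avg_pool(x, (7, 7), data_format="channels_first")
    assert y.shape == (1, 3, 7, 7)
    return y
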
@@ -1237,3 +1237,19 @@ def _pair(x):
    # ---- reshape -> (N, C*kH*kW, L) ----
    return patches.reshape(N, C * k[0] * k[1], -1)


def adaptive_max_pool(inputs, output_size, data_format=None):
    """Adaptive max pooling - Numpy backend not yet supported."""
    raise NotImplementedError(
        "Adaptive pooling not implemented for Numpy. "
        "Use JAX, Torch or Tensorflow backend."
    )


def adaptive_avg_pool(inputs, output_size, data_format=None):
    """Adaptive average pooling - Numpy backend not yet supported."""
    raise NotImplementedError(
        "Adaptive pooling not implemented for Numpy. "
        "Use JAX, Torch or Tensorflow backend."
    )

[Reviewer comment on this file: Revert this file.]