diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index 01e7d9f806b3..c9293ea9ce8d 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -64,6 +64,8 @@ from keras.src.ops.math import top_k as top_k from keras.src.ops.math import view_as_complex as view_as_complex from keras.src.ops.math import view_as_real as view_as_real +from keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool +from keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool from keras.src.ops.nn import average_pool as average_pool from keras.src.ops.nn import batch_normalization as batch_normalization from keras.src.ops.nn import binary_crossentropy as binary_crossentropy diff --git a/keras/api/ops/nn/__init__.py b/keras/api/ops/nn/__init__.py index da08f380f227..d024b7a0dfec 100644 --- a/keras/api/ops/nn/__init__.py +++ b/keras/api/ops/nn/__init__.py @@ -4,6 +4,8 @@ since your modifications would be overwritten. """ +from keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool +from keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool from keras.src.ops.nn import average_pool as average_pool from keras.src.ops.nn import batch_normalization as batch_normalization from keras.src.ops.nn import binary_crossentropy as binary_crossentropy diff --git a/keras/src/backend/common/backend_utils.py b/keras/src/backend/common/backend_utils.py index fc700b3e33be..3e73beab7907 100644 --- a/keras/src/backend/common/backend_utils.py +++ b/keras/src/backend/common/backend_utils.py @@ -1,4 +1,5 @@ import functools +import math import operator import re import warnings @@ -539,3 +540,10 @@ def slice_along_axis(x, start=0, stop=None, step=1, axis=0): -1 - axis ) return x[tuple(slices)] + + +def compute_adaptive_pooling_window_sizes(input_dim, output_dim): + """Compute small and big window sizes for adaptive pooling.""" + small = math.ceil(input_dim / output_dim) + big = small + 1 + return small, big diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index 15cc90f73747..1709a3f46a40 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -16,6 +16,9 @@ ) from keras.src import backend +from keras.src.backend.common.backend_utils import ( + compute_adaptive_pooling_window_sizes, +) from keras.src.backend.common.backend_utils import ( compute_conv_transpose_padding_args_for_jax, ) @@ -289,6 +292,403 @@ def average_pool( return pooled / window_counts +def _compute_adaptive_pooling_gather_indices( + input_dim, output_size, big_window +): + """Compute gather indices for Two-Pool Gather method.""" + window_starts = jnp.floor( + (jnp.arange(output_size) * input_dim) / output_size + ).astype(jnp.int32) + + window_ends = jnp.ceil( + (jnp.arange(1, output_size + 1) * input_dim) / output_size + ).astype(jnp.int32) + + window_sizes = window_ends - window_starts + is_big = window_sizes == big_window + + small_window = big_window - 1 + small_len = input_dim - small_window + 1 + + small_indices = window_starts + big_indices = window_starts + small_len + + gather = jnp.where(is_big, big_indices, small_indices) + return gather.astype(jnp.int32) + + +def _adaptive_average_pool1d(inputs, output_size, data_format="channels_first"): + if isinstance(output_size, int): + output_size = (output_size,) + + if data_format == "channels_first": + inputs = jnp.transpose(inputs, (0, 2, 1)) # NCL → NLC + + n, l, c = inputs.shape + out_l = output_size[0] + + small, big = compute_adaptive_pooling_window_sizes(l, out_l) + gather = _compute_adaptive_pooling_gather_indices(l, out_l, big) + + small_pool = ( + lax.reduce_window( + inputs, 0.0, lax.add, (1, small, 1), (1, 1, 1), "valid" + ) + / small + ) + + big_pool = ( + lax.reduce_window(inputs, 0.0, lax.add, (1, big, 1), (1, 1, 1), "valid") + / big + ) + + combined = jnp.concatenate([small_pool, big_pool], axis=1) + out = jnp.take(combined, gather, axis=1) + + if data_format == "channels_first": + out = jnp.transpose(out, (0, 2, 1)) + + return out + + +def _adaptive_max_pool1d(inputs, output_size, data_format="channels_first"): + if isinstance(output_size, int): + output_size = (output_size,) + + if data_format == "channels_first": + inputs = jnp.transpose(inputs, (0, 2, 1)) + + n, l, c = inputs.shape + out_l = output_size[0] + + small, big = compute_adaptive_pooling_window_sizes(l, out_l) + gather = _compute_adaptive_pooling_gather_indices(l, out_l, big) + + small_pool = lax.reduce_window( + inputs, -jnp.inf, lax.max, (1, small, 1), (1, 1, 1), "valid" + ) + + big_pool = lax.reduce_window( + inputs, -jnp.inf, lax.max, (1, big, 1), (1, 1, 1), "valid" + ) + + combined = jnp.concatenate([small_pool, big_pool], axis=1) + out = jnp.take(combined, gather, axis=1) + + if data_format == "channels_first": + out = jnp.transpose(out, (0, 2, 1)) + + return out + + +def _adaptive_average_pool2d(inputs, output_size, data_format="channels_first"): + if isinstance(output_size, int): + output_size = (output_size, output_size) + + if data_format == "channels_first": + inputs = jnp.transpose(inputs, (0, 2, 3, 1)) + + n, h, w, c = inputs.shape + out_h, out_w = output_size + + small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h) + gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h) + + small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w) + gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w) + + small_h_pool = ( + lax.reduce_window( + inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), "valid" + ) + / small_h + ) + + big_h_pool = ( + lax.reduce_window( + inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), "valid" + ) + / big_h + ) + + combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1) + pooled_h = jnp.take(combined_h, gather_h, axis=1) + + small_w_pool = ( + lax.reduce_window( + pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), "valid" + ) + / small_w + ) + + big_w_pool = ( + lax.reduce_window( + pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), "valid" + ) + / big_w + ) + + combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2) + out = jnp.take(combined_w, gather_w, axis=2) + + if data_format == "channels_first": + out = jnp.transpose(out, (0, 3, 1, 2)) + + return out + + +def _adaptive_max_pool2d(inputs, output_size, data_format="channels_first"): + if isinstance(output_size, int): + output_size = (output_size, output_size) + + if data_format == "channels_first": + inputs = jnp.transpose(inputs, (0, 2, 3, 1)) + + n, h, w, c = inputs.shape + out_h, out_w = output_size + + small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h) + gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h) + + small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w) + gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w) + + small_h_pool = lax.reduce_window( + inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), "valid" + ) + + big_h_pool = lax.reduce_window( + inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), "valid" + ) + + combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1) + pooled_h = jnp.take(combined_h, gather_h, axis=1) + + small_w_pool = lax.reduce_window( + pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), "valid" + ) + + big_w_pool = lax.reduce_window( + pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), "valid" + ) + + combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2) + out = jnp.take(combined_w, gather_w, axis=2) + + if data_format == "channels_first": + out = jnp.transpose(out, (0, 3, 1, 2)) + + return out + + +def _adaptive_average_pool3d(inputs, output_size, data_format="channels_first"): + if isinstance(output_size, int): + output_size = (output_size, output_size, output_size) + + if data_format == "channels_first": + inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1)) + + n, d, h, w, c = inputs.shape + out_d, out_h, out_w = output_size + + small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d) + gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d) + + small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h) + gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h) + + small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w) + gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w) + + small_d_pool = ( + lax.reduce_window( + inputs, + 0.0, + lax.add, + (1, small_d, 1, 1, 1), + (1, 1, 1, 1, 1), + "valid", + ) + / small_d + ) + + big_d_pool = ( + lax.reduce_window( + inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid" + ) + / big_d + ) + + combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1) + pooled_d = jnp.take(combined_d, gather_d, axis=1) + + small_h_pool = ( + lax.reduce_window( + pooled_d, + 0.0, + lax.add, + (1, 1, small_h, 1, 1), + (1, 1, 1, 1, 1), + "valid", + ) + / small_h + ) + + big_h_pool = ( + lax.reduce_window( + pooled_d, + 0.0, + lax.add, + (1, 1, big_h, 1, 1), + (1, 1, 1, 1, 1), + "valid", + ) + / big_h + ) + + combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2) + pooled_h = jnp.take(combined_h, gather_h, axis=2) + + small_w_pool = ( + lax.reduce_window( + pooled_h, + 0.0, + lax.add, + (1, 1, 1, small_w, 1), + (1, 1, 1, 1, 1), + "valid", + ) + / small_w + ) + + big_w_pool = ( + lax.reduce_window( + pooled_h, + 0.0, + lax.add, + (1, 1, 1, big_w, 1), + (1, 1, 1, 1, 1), + "valid", + ) + / big_w + ) + + combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3) + out = jnp.take(combined_w, gather_w, axis=3) + + if data_format == "channels_first": + out = jnp.transpose(out, (0, 4, 1, 2, 3)) + + return out + + +def _adaptive_max_pool3d(inputs, output_size, data_format="channels_first"): + if isinstance(output_size, int): + output_size = (output_size, output_size, output_size) + + if data_format == "channels_first": + inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1)) + + n, d, h, w, c = inputs.shape + out_d, out_h, out_w = output_size + + small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d) + gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d) + + small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h) + gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h) + + small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w) + gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w) + + small_d_pool = lax.reduce_window( + inputs, + -jnp.inf, + lax.max, + (1, small_d, 1, 1, 1), + (1, 1, 1, 1, 1), + "valid", + ) + + big_d_pool = lax.reduce_window( + inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid" + ) + + combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1) + pooled_d = jnp.take(combined_d, gather_d, axis=1) + + small_h_pool = lax.reduce_window( + pooled_d, + -jnp.inf, + lax.max, + (1, 1, small_h, 1, 1), + (1, 1, 1, 1, 1), + "valid", + ) + + big_h_pool = lax.reduce_window( + pooled_d, + -jnp.inf, + lax.max, + (1, 1, big_h, 1, 1), + (1, 1, 1, 1, 1), + "valid", + ) + + combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2) + pooled_h = jnp.take(combined_h, gather_h, axis=2) + + small_w_pool = lax.reduce_window( + pooled_h, + -jnp.inf, + lax.max, + (1, 1, 1, small_w, 1), + (1, 1, 1, 1, 1), + "valid", + ) + + big_w_pool = lax.reduce_window( + pooled_h, + -jnp.inf, + lax.max, + (1, 1, 1, big_w, 1), + (1, 1, 1, 1, 1), + "valid", + ) + + combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3) + out = jnp.take(combined_w, gather_w, axis=3) + + if data_format == "channels_first": + out = jnp.transpose(out, (0, 4, 1, 2, 3)) + + return out + + +def adaptive_average_pool(inputs, output_size, data_format=None): + data_format = backend.standardize_data_format(data_format) + dims = inputs.ndim - 2 + if dims == 1: + return _adaptive_average_pool1d(inputs, output_size, data_format) + if dims == 2: + return _adaptive_average_pool2d(inputs, output_size, data_format) + if dims == 3: + return _adaptive_average_pool3d(inputs, output_size, data_format) + raise ValueError("adaptive_average_pool supports only 1D/2D/3D inputs") + + +def adaptive_max_pool(inputs, output_size, data_format=None): + data_format = backend.standardize_data_format(data_format) + dims = inputs.ndim - 2 + if dims == 1: + return _adaptive_max_pool1d(inputs, output_size, data_format) + if dims == 2: + return _adaptive_max_pool2d(inputs, output_size, data_format) + if dims == 3: + return _adaptive_max_pool3d(inputs, output_size, data_format) + raise ValueError("adaptive_max_pool supports only 1D/2D/3D inputs") + + def _convert_to_lax_conv_dimension_numbers( num_spatial_dims, data_format="channels_last", diff --git a/keras/src/backend/numpy/nn.py b/keras/src/backend/numpy/nn.py index 44f3fb882e12..1e4077d55ee0 100644 --- a/keras/src/backend/numpy/nn.py +++ b/keras/src/backend/numpy/nn.py @@ -3,6 +3,9 @@ from jax import lax from keras.src import backend +from keras.src.backend.common.backend_utils import ( + compute_adaptive_pooling_window_sizes, +) from keras.src.backend.common.backend_utils import ( compute_conv_transpose_padding_args_for_jax, ) @@ -340,6 +343,252 @@ def average_pool( return pooled / window_counts +def _compute_adaptive_pooling_gather_indices( + input_dim, output_size, big_window +): + window_starts = np.floor( + (np.arange(output_size) * input_dim) / output_size + ).astype(np.int32) + + window_ends = np.ceil( + (np.arange(1, output_size + 1) * input_dim) / output_size + ).astype(np.int32) + + window_sizes = window_ends - window_starts + is_big = window_sizes == big_window + + small_window = big_window - 1 + small_pool_len = input_dim - small_window + 1 + + small_indices = window_starts + big_indices = window_starts + small_pool_len + + gather = np.where(is_big, big_indices, small_indices) + return gather.astype(np.int32) + + +def _strided_view_1d(x, window_size): + n, l, c = x.shape + out = l - window_size + 1 + + strides = x.strides + shape = (n, out, window_size, c) + new_strides = (strides[0], strides[1], strides[1], strides[2]) + + return np.lib.stride_tricks.as_strided(x, shape=shape, strides=new_strides) + + +def _adaptive_pool1d_impl(inputs, output_size, mode, data_format): + if isinstance(output_size, int): + output_size = (output_size,) + + if data_format == "channels_first": + inputs = np.transpose(inputs, (0, 2, 1)) + + n, l, c = inputs.shape + out_l = output_size[0] + + small, big = compute_adaptive_pooling_window_sizes(l, out_l) + gather = _compute_adaptive_pooling_gather_indices(l, out_l, big) + + sv_small = _strided_view_1d(inputs, small) + small_pool = ( + np.mean(sv_small, axis=2) + if mode == "average" + else np.max(sv_small, axis=2) + ) + + sv_big = _strided_view_1d(inputs, big) + big_pool = ( + np.mean(sv_big, axis=2) if mode == "average" else np.max(sv_big, axis=2) + ) + + combined = np.concatenate([small_pool, big_pool], axis=1) + out = combined[:, gather, :] + + if data_format == "channels_first": + out = np.transpose(out, (0, 2, 1)) + + return out + + +def _adaptive_pool2d_impl(inputs, output_size, mode, data_format): + if isinstance(output_size, int): + output_size = (output_size, output_size) + + if data_format == "channels_first": + inputs = np.transpose(inputs, (0, 2, 3, 1)) + + n, h, w, c = inputs.shape + out_h, out_w = output_size + + small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h) + gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h) + + x_h = np.transpose(inputs, (0, 2, 1, 3)).reshape(n * w, h, c) + + sv_small_h = _strided_view_1d(x_h, small_h) + small_pool_h = ( + np.mean(sv_small_h, axis=2) + if mode == "average" + else np.max(sv_small_h, axis=2) + ) + + sv_big_h = _strided_view_1d(x_h, big_h) + big_pool_h = ( + np.mean(sv_big_h, axis=2) + if mode == "average" + else np.max(sv_big_h, axis=2) + ) + + combined_h = np.concatenate([small_pool_h, big_pool_h], axis=1) + pooled_h = combined_h[:, gather_h, :] + + pooled_h = pooled_h.reshape(n, w, out_h, c) + pooled_h = np.transpose(pooled_h, (0, 2, 1, 3)) + + small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w) + gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w) + + x_w = pooled_h.reshape(n * out_h, w, c) + + sv_small_w = _strided_view_1d(x_w, small_w) + small_pool_w = ( + np.mean(sv_small_w, axis=2) + if mode == "average" + else np.max(sv_small_w, axis=2) + ) + + sv_big_w = _strided_view_1d(x_w, big_w) + big_pool_w = ( + np.mean(sv_big_w, axis=2) + if mode == "average" + else np.max(sv_big_w, axis=2) + ) + + combined_w = np.concatenate([small_pool_w, big_pool_w], axis=1) + out = combined_w[:, gather_w, :].reshape(n, out_h, out_w, c) + + if data_format == "channels_first": + out = np.transpose(out, (0, 3, 1, 2)) + + return out + + +def _adaptive_pool3d_impl(inputs, output_size, mode, data_format): + if isinstance(output_size, int): + output_size = (output_size, output_size, output_size) + + if data_format == "channels_first": + inputs = np.transpose(inputs, (0, 2, 3, 4, 1)) + + n, d, h, w, c = inputs.shape + out_d, out_h, out_w = output_size + + small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d) + gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d) + + x_d = np.transpose(inputs, (0, 2, 3, 1, 4)).reshape(n * h * w, d, c) + + sv_small_d = _strided_view_1d(x_d, small_d) + small_pool_d = ( + np.mean(sv_small_d, axis=2) + if mode == "average" + else np.max(sv_small_d, axis=2) + ) + + sv_big_d = _strided_view_1d(x_d, big_d) + big_pool_d = ( + np.mean(sv_big_d, axis=2) + if mode == "average" + else np.max(sv_big_d, axis=2) + ) + + combined_d = np.concatenate([small_pool_d, big_pool_d], axis=1) + pooled_d = combined_d[:, gather_d, :].reshape(n, h, w, out_d, c) + pooled_d = np.transpose(pooled_d, (0, 3, 1, 2, 4)) + + small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h) + gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h) + + x_h = np.transpose(pooled_d, (0, 1, 3, 2, 4)).reshape(n * out_d * w, h, c) + + sv_small_h = _strided_view_1d(x_h, small_h) + small_pool_h = ( + np.mean(sv_small_h, axis=2) + if mode == "average" + else np.max(sv_small_h, axis=2) + ) + + sv_big_h = _strided_view_1d(x_h, big_h) + big_pool_h = ( + np.mean(sv_big_h, axis=2) + if mode == "average" + else np.max(sv_big_h, axis=2) + ) + + combined_h = np.concatenate([small_pool_h, big_pool_h], axis=1) + pooled_h = combined_h[:, gather_h, :].reshape(n, out_d, w, out_h, c) + pooled_h = np.transpose(pooled_h, (0, 1, 3, 2, 4)) + + small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w) + gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w) + + x_w = pooled_h.reshape(n * out_d * out_h, w, c) + + sv_small_w = _strided_view_1d(x_w, small_w) + small_pool_w = ( + np.mean(sv_small_w, axis=2) + if mode == "average" + else np.max(sv_small_w, axis=2) + ) + + sv_big_w = _strided_view_1d(x_w, big_w) + big_pool_w = ( + np.mean(sv_big_w, axis=2) + if mode == "average" + else np.max(sv_big_w, axis=2) + ) + + combined_w = np.concatenate([small_pool_w, big_pool_w], axis=1) + out = combined_w[:, gather_w, :].reshape(n, out_d, out_h, out_w, c) + + if data_format == "channels_first": + out = np.transpose(out, (0, 4, 1, 2, 3)) + + return out + + +def adaptive_average_pool(inputs, output_size, data_format=None): + data_format = backend.standardize_data_format(data_format) + dims = inputs.ndim - 2 + if dims == 1: + return _adaptive_pool1d_impl( + inputs, output_size, "average", data_format + ) + if dims == 2: + return _adaptive_pool2d_impl( + inputs, output_size, "average", data_format + ) + if dims == 3: + return _adaptive_pool3d_impl( + inputs, output_size, "average", data_format + ) + raise ValueError("adaptive_average_pool supports only 1D/2D/3D") + + +def adaptive_max_pool(inputs, output_size, data_format=None): + data_format = backend.standardize_data_format(data_format) + dims = inputs.ndim - 2 + if dims == 1: + return _adaptive_pool1d_impl(inputs, output_size, "max", data_format) + if dims == 2: + return _adaptive_pool2d_impl(inputs, output_size, "max", data_format) + if dims == 3: + return _adaptive_pool3d_impl(inputs, output_size, "max", data_format) + raise ValueError("adaptive_max_pool supports only 1D/2D/3D") + + def _convert_to_lax_conv_dimension_numbers( num_spatial_dims, data_format="channels_last", diff --git a/keras/src/backend/openvino/nn.py b/keras/src/backend/openvino/nn.py index 2c025825ed82..a1214f32585d 100644 --- a/keras/src/backend/openvino/nn.py +++ b/keras/src/backend/openvino/nn.py @@ -145,6 +145,16 @@ def average_pool( ) +def adaptive_average_pool(inputs, output_size, data_format=None): + """Adaptive average pooling - OpenVINO backend not yet supported.""" + raise NotImplementedError("Adaptive pooling not implemented for OpenVINO.") + + +def adaptive_max_pool(inputs, output_size, data_format=None): + """Adaptive max pooling - OpenVINO backend not yet supported.""" + raise NotImplementedError("Adaptive pooling not implemented for OpenVINO.") + + def _adjust_strides_dilation( x, num_spatial_dims, diff --git a/keras/src/backend/tensorflow/nn.py b/keras/src/backend/tensorflow/nn.py index 8a89e6a6b590..17576968e36d 100644 --- a/keras/src/backend/tensorflow/nn.py +++ b/keras/src/backend/tensorflow/nn.py @@ -4,6 +4,9 @@ import tensorflow as tf from keras.src import backend +from keras.src.backend.common.backend_utils import ( + compute_adaptive_pooling_window_sizes, +) from keras.src.backend.common.backend_utils import ( compute_conv_transpose_output_shape, ) @@ -268,6 +271,486 @@ def average_pool( return outputs +def _compute_static_gather_indices( + input_dim, output_size, small_window, big_window +): + """Compute gather indices for Two-Pool Gather method (corrected).""" + window_starts = tf.cast( + tf.floor( + tf.cast(tf.range(output_size), tf.float32) + * tf.cast(input_dim, tf.float32) + / tf.cast(output_size, tf.float32) + ), + tf.int32, + ) + window_ends = tf.cast( + tf.math.ceil( + tf.cast(tf.range(1, output_size + 1), tf.float32) + * tf.cast(input_dim, tf.float32) + / tf.cast(output_size, tf.float32) + ), + tf.int32, + ) + + window_ends = tf.minimum(window_ends, input_dim) + window_starts = tf.minimum(window_starts, input_dim - 1) + + window_sizes = window_ends - window_starts + is_big_window = tf.equal(window_sizes, big_window) + + small_pool_len = max(1, input_dim - small_window + 1) + + small_indices = window_starts + big_indices = window_starts + small_pool_len + + gather_indices = tf.where(is_big_window, big_indices, small_indices) + return tf.cast(gather_indices, tf.int32) + + +def _adaptive_average_pool1d(inputs, output_size, data_format="channels_first"): + if isinstance(output_size, int): + output_size = (output_size,) + if data_format == "channels_first": + inputs = tf.transpose(inputs, (0, 2, 1)) + + static_shape = inputs.shape.as_list() + l_static = static_shape[1] + out_l = output_size[0] + + if l_static is None: + raise ValueError( + "Input length must be statically known for adaptive pooling" + ) + + small_l, big_l = compute_adaptive_pooling_window_sizes(l_static, out_l) + gather_l = _compute_static_gather_indices(l_static, out_l, small_l, big_l) + + small_pool_l = tf.nn.pool( + inputs, + window_shape=(small_l,), + pooling_type="AVG", + strides=(1,), + padding="VALID", + data_format="NWC", + ) + big_pool_l = tf.nn.pool( + inputs, + window_shape=(big_l,), + pooling_type="AVG", + strides=(1,), + padding="VALID", + data_format="NWC", + ) + + combined_l = tf.concat([small_pool_l, big_pool_l], axis=1) + pooled_l = tf.gather(combined_l, gather_l, axis=1) + + if data_format == "channels_first": + pooled_l = tf.transpose(pooled_l, (0, 2, 1)) + return pooled_l + + +def _adaptive_max_pool1d(inputs, output_size, data_format="channels_first"): + if isinstance(output_size, int): + output_size = (output_size,) + if data_format == "channels_first": + inputs = tf.transpose(inputs, (0, 2, 1)) + + static_shape = inputs.shape.as_list() + l_static = static_shape[1] + out_l = output_size[0] + + if l_static is None: + raise ValueError( + "Input length must be statically known for adaptive pooling" + ) + + small_l, big_l = compute_adaptive_pooling_window_sizes(l_static, out_l) + gather_l = _compute_static_gather_indices(l_static, out_l, small_l, big_l) + + small_pool_l = tf.nn.pool( + inputs, + window_shape=(small_l,), + pooling_type="MAX", + strides=(1,), + padding="VALID", + data_format="NWC", + ) + big_pool_l = tf.nn.pool( + inputs, + window_shape=(big_l,), + pooling_type="MAX", + strides=(1,), + padding="VALID", + data_format="NWC", + ) + + combined_l = tf.concat([small_pool_l, big_pool_l], axis=1) + pooled_l = tf.gather(combined_l, gather_l, axis=1) + + if data_format == "channels_first": + pooled_l = tf.transpose(pooled_l, (0, 2, 1)) + return pooled_l + + +def _adaptive_average_pool2d(inputs, output_size, data_format="channels_first"): + if isinstance(output_size, int): + output_size = (output_size, output_size) + + if data_format == "channels_first": + inputs = tf.transpose(inputs, (0, 2, 3, 1)) + + static_shape = inputs.shape.as_list() + h_static = static_shape[1] + w_static = static_shape[2] + out_h, out_w = output_size + + if h_static is None or w_static is None: + raise ValueError( + "Input spatial dimensions must be " + "statically known for adaptive pooling" + ) + + small_h, big_h = compute_adaptive_pooling_window_sizes(h_static, out_h) + small_w, big_w = compute_adaptive_pooling_window_sizes(w_static, out_w) + + gather_h = _compute_static_gather_indices(h_static, out_h, small_h, big_h) + gather_w = _compute_static_gather_indices(w_static, out_w, small_w, big_w) + + small_pool_h = tf.nn.pool( + inputs, + window_shape=(small_h, 1), + pooling_type="AVG", + strides=(1, 1), + padding="VALID", + data_format="NHWC", + ) + big_pool_h = tf.nn.pool( + inputs, + window_shape=(big_h, 1), + pooling_type="AVG", + strides=(1, 1), + padding="VALID", + data_format="NHWC", + ) + + combined_h = tf.concat([small_pool_h, big_pool_h], axis=1) + pooled_h = tf.gather(combined_h, gather_h, axis=1) + + small_pool_w = tf.nn.pool( + pooled_h, + window_shape=(1, small_w), + pooling_type="AVG", + strides=(1, 1), + padding="VALID", + data_format="NHWC", + ) + big_pool_w = tf.nn.pool( + pooled_h, + window_shape=(1, big_w), + pooling_type="AVG", + strides=(1, 1), + padding="VALID", + data_format="NHWC", + ) + + combined_w = tf.concat([small_pool_w, big_pool_w], axis=2) + pooled_w = tf.gather(combined_w, gather_w, axis=2) + + if data_format == "channels_first": + pooled_w = tf.transpose(pooled_w, (0, 3, 1, 2)) + + return pooled_w + + +def _adaptive_max_pool2d(inputs, output_size, data_format="channels_first"): + """Adaptive Max Pooling 2D using Two-Pool Gather method.""" + if isinstance(output_size, int): + output_size = (output_size, output_size) + + if data_format == "channels_first": + inputs = tf.transpose(inputs, (0, 2, 3, 1)) + + static_shape = inputs.shape.as_list() + h_static = static_shape[1] + w_static = static_shape[2] + out_h, out_w = output_size + + if h_static is None or w_static is None: + raise ValueError( + "Input spatial dimensions must be " + "statically known for adaptive pooling" + ) + + small_h, big_h = compute_adaptive_pooling_window_sizes(h_static, out_h) + small_w, big_w = compute_adaptive_pooling_window_sizes(w_static, out_w) + + gather_h = _compute_static_gather_indices(h_static, out_h, small_h, big_h) + gather_w = _compute_static_gather_indices(w_static, out_w, small_w, big_w) + + small_pool_h = tf.nn.pool( + inputs, + window_shape=(small_h, 1), + pooling_type="MAX", + strides=(1, 1), + padding="VALID", + data_format="NHWC", + ) + big_pool_h = tf.nn.pool( + inputs, + window_shape=(big_h, 1), + pooling_type="MAX", + strides=(1, 1), + padding="VALID", + data_format="NHWC", + ) + + combined_h = tf.concat([small_pool_h, big_pool_h], axis=1) + pooled_h = tf.gather(combined_h, gather_h, axis=1) + + small_pool_w = tf.nn.pool( + pooled_h, + window_shape=(1, small_w), + pooling_type="MAX", + strides=(1, 1), + padding="VALID", + data_format="NHWC", + ) + big_pool_w = tf.nn.pool( + pooled_h, + window_shape=(1, big_w), + pooling_type="MAX", + strides=(1, 1), + padding="VALID", + data_format="NHWC", + ) + + combined_w = tf.concat([small_pool_w, big_pool_w], axis=2) + pooled_w = tf.gather(combined_w, gather_w, axis=2) + + if data_format == "channels_first": + pooled_w = tf.transpose(pooled_w, (0, 3, 1, 2)) + + return pooled_w + + +def _adaptive_average_pool3d(inputs, output_size, data_format="channels_first"): + if isinstance(output_size, int): + output_size = (output_size, output_size, output_size) + + if data_format == "channels_first": + inputs = tf.transpose(inputs, (0, 2, 3, 4, 1)) + + static_shape = inputs.shape.as_list() + d_static = static_shape[1] + h_static = static_shape[2] + w_static = static_shape[3] + out_d, out_h, out_w = output_size + + if d_static is None or h_static is None or w_static is None: + raise ValueError( + "Input spatial dimensions must be " + "statically known for adaptive pooling" + ) + + small_d, big_d = compute_adaptive_pooling_window_sizes(d_static, out_d) + small_h, big_h = compute_adaptive_pooling_window_sizes(h_static, out_h) + small_w, big_w = compute_adaptive_pooling_window_sizes(w_static, out_w) + + gather_d = _compute_static_gather_indices(d_static, out_d, small_d, big_d) + gather_h = _compute_static_gather_indices(h_static, out_h, small_h, big_h) + gather_w = _compute_static_gather_indices(w_static, out_w, small_w, big_w) + + small_pool_d = tf.nn.pool( + inputs, + window_shape=(small_d, 1, 1), + pooling_type="AVG", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + big_pool_d = tf.nn.pool( + inputs, + window_shape=(big_d, 1, 1), + pooling_type="AVG", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + + combined_d = tf.concat([small_pool_d, big_pool_d], axis=1) + pooled_d = tf.gather(combined_d, gather_d, axis=1) + + small_pool_h = tf.nn.pool( + pooled_d, + window_shape=(1, small_h, 1), + pooling_type="AVG", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + big_pool_h = tf.nn.pool( + pooled_d, + window_shape=(1, big_h, 1), + pooling_type="AVG", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + + combined_h = tf.concat([small_pool_h, big_pool_h], axis=2) + pooled_h = tf.gather(combined_h, gather_h, axis=2) + + small_pool_w = tf.nn.pool( + pooled_h, + window_shape=(1, 1, small_w), + pooling_type="AVG", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + big_pool_w = tf.nn.pool( + pooled_h, + window_shape=(1, 1, big_w), + pooling_type="AVG", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + + combined_w = tf.concat([small_pool_w, big_pool_w], axis=3) + pooled_w = tf.gather(combined_w, gather_w, axis=3) + + if data_format == "channels_first": + pooled_w = tf.transpose(pooled_w, (0, 4, 1, 2, 3)) + + return pooled_w + + +def _adaptive_max_pool3d(inputs, output_size, data_format="channels_first"): + """Adaptive Max Pooling 3D using Two-Pool Gather method.""" + if isinstance(output_size, int): + output_size = (output_size, output_size, output_size) + + if data_format == "channels_first": + inputs = tf.transpose(inputs, (0, 2, 3, 4, 1)) + + static_shape = inputs.shape.as_list() + d_static = static_shape[1] + h_static = static_shape[2] + w_static = static_shape[3] + out_d, out_h, out_w = output_size + + if d_static is None or h_static is None or w_static is None: + raise ValueError( + "Input spatial dimensions must be " + "statically known for adaptive pooling" + ) + + small_d, big_d = compute_adaptive_pooling_window_sizes(d_static, out_d) + small_h, big_h = compute_adaptive_pooling_window_sizes(h_static, out_h) + small_w, big_w = compute_adaptive_pooling_window_sizes(w_static, out_w) + + gather_d = _compute_static_gather_indices(d_static, out_d, small_d, big_d) + gather_h = _compute_static_gather_indices(h_static, out_h, small_h, big_h) + gather_w = _compute_static_gather_indices(w_static, out_w, small_w, big_w) + + small_pool_d = tf.nn.pool( + inputs, + window_shape=(small_d, 1, 1), + pooling_type="MAX", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + big_pool_d = tf.nn.pool( + inputs, + window_shape=(big_d, 1, 1), + pooling_type="MAX", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + + combined_d = tf.concat([small_pool_d, big_pool_d], axis=1) + pooled_d = tf.gather(combined_d, gather_d, axis=1) + + small_pool_h = tf.nn.pool( + pooled_d, + window_shape=(1, small_h, 1), + pooling_type="MAX", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + big_pool_h = tf.nn.pool( + pooled_d, + window_shape=(1, big_h, 1), + pooling_type="MAX", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + + combined_h = tf.concat([small_pool_h, big_pool_h], axis=2) + pooled_h = tf.gather(combined_h, gather_h, axis=2) + + small_pool_w = tf.nn.pool( + pooled_h, + window_shape=(1, 1, small_w), + pooling_type="MAX", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + big_pool_w = tf.nn.pool( + pooled_h, + window_shape=(1, 1, big_w), + pooling_type="MAX", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + + combined_w = tf.concat([small_pool_w, big_pool_w], axis=3) + pooled_w = tf.gather(combined_w, gather_w, axis=3) + + if data_format == "channels_first": + pooled_w = tf.transpose(pooled_w, (0, 4, 1, 2, 3)) + + return pooled_w + + +def adaptive_average_pool(inputs, output_size, data_format=None): + data_format = backend.standardize_data_format(data_format) + ndims = len(inputs.shape) - 2 + if ndims == 1: + return _adaptive_average_pool1d(inputs, output_size, data_format) + elif ndims == 2: + return _adaptive_average_pool2d(inputs, output_size, data_format) + elif ndims == 3: + return _adaptive_average_pool3d(inputs, output_size, data_format) + else: + raise ValueError( + "adaptive_average_pool supports 1D, 2D, or 3D inputs only." + ) + + +def adaptive_max_pool(inputs, output_size, data_format=None): + data_format = backend.standardize_data_format(data_format) + ndims = len(inputs.shape) - 2 + if ndims == 1: + return _adaptive_max_pool1d(inputs, output_size, data_format) + elif ndims == 2: + return _adaptive_max_pool2d(inputs, output_size, data_format) + elif ndims == 3: + return _adaptive_max_pool3d(inputs, output_size, data_format) + else: + raise ValueError( + "adaptive_max_pool supports 1D, 2D, or 3D inputs only." + ) + + def _convert_data_format(data_format, ndim): if data_format == "channels_last": if ndim == 3: diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py index cebc4f18fcac..4646c3352752 100644 --- a/keras/src/backend/torch/nn.py +++ b/keras/src/backend/torch/nn.py @@ -458,6 +458,94 @@ def average_pool( return outputs +def adaptive_average_pool(inputs, output_size, data_format=None): + """Adaptive average pooling(1D/2D/3D) with channels_last support.""" + inputs = convert_to_tensor(inputs) + num_spatial_dims = inputs.ndim - 2 + + data_format = backend.standardize_data_format(data_format) + orig_format = data_format + if data_format == "channels_last": + inputs = _transpose_spatial_inputs(inputs) + + if isinstance(output_size, int): + torch_output_size = ( + output_size + if num_spatial_dims == 1 + else (output_size,) * num_spatial_dims + ) + else: + torch_output_size = standardize_tuple( + output_size, num_spatial_dims, "output_size" + ) + + if get_device() == "meta": + inputs = torch.empty( + size=inputs.shape, dtype=inputs.dtype, device="cpu" + ) + + if num_spatial_dims == 1: + outputs = tnn.adaptive_avg_pool1d(inputs, output_size=torch_output_size) + elif num_spatial_dims == 2: + outputs = tnn.adaptive_avg_pool2d(inputs, output_size=torch_output_size) + elif num_spatial_dims == 3: + outputs = tnn.adaptive_avg_pool3d(inputs, output_size=torch_output_size) + else: + raise ValueError( + "Inputs to adaptive average pooling must have ndim=3, 4 or 5, " + f"Received input shape: {inputs.shape}." + ) + + if orig_format == "channels_last": + outputs = _transpose_spatial_outputs(outputs) + return outputs + + +def adaptive_max_pool(inputs, output_size, data_format=None): + """Adaptive max pooling(1D/2D/3D) with channels_last support.""" + inputs = convert_to_tensor(inputs) + num_spatial_dims = inputs.ndim - 2 + + data_format = backend.standardize_data_format(data_format) + orig_format = data_format + if data_format == "channels_last": + inputs = _transpose_spatial_inputs(inputs) + + if isinstance(output_size, int): + torch_output_size = ( + output_size + if num_spatial_dims == 1 + else (output_size,) * num_spatial_dims + ) + else: + torch_output_size = standardize_tuple( + output_size, num_spatial_dims, "output_size" + ) + + if get_device() == "meta": + inputs = torch.empty( + size=inputs.shape, dtype=inputs.dtype, device="cpu" + ) + + if num_spatial_dims == 1: + res = tnn.adaptive_max_pool1d(inputs, output_size=torch_output_size) + elif num_spatial_dims == 2: + res = tnn.adaptive_max_pool2d(inputs, output_size=torch_output_size) + elif num_spatial_dims == 3: + res = tnn.adaptive_max_pool3d(inputs, output_size=torch_output_size) + else: + raise ValueError( + "Inputs to adaptive max pooling must have ndim=3, 4 or 5, " + f"Received input shape: {inputs.shape}." + ) + + outputs = res[0] if isinstance(res, tuple) else res + + if orig_format == "channels_last": + outputs = _transpose_spatial_outputs(outputs) + return outputs + + def conv( inputs, kernel, diff --git a/keras/src/layers/__init__.py b/keras/src/layers/__init__.py index febdcef15a98..e2d1ec0a6479 100644 --- a/keras/src/layers/__init__.py +++ b/keras/src/layers/__init__.py @@ -63,6 +63,18 @@ SpectralNormalization, ) from keras.src.layers.normalization.unit_normalization import UnitNormalization +from keras.src.layers.pooling.adaptive_average_pooling1d import ( + AdaptiveAveragePooling1D, +) +from keras.src.layers.pooling.adaptive_average_pooling2d import ( + AdaptiveAveragePooling2D, +) +from keras.src.layers.pooling.adaptive_average_pooling3d import ( + AdaptiveAveragePooling3D, +) +from keras.src.layers.pooling.adaptive_max_pooling1d import AdaptiveMaxPooling1D +from keras.src.layers.pooling.adaptive_max_pooling2d import AdaptiveMaxPooling2D +from keras.src.layers.pooling.adaptive_max_pooling3d import AdaptiveMaxPooling3D from keras.src.layers.pooling.average_pooling1d import AveragePooling1D from keras.src.layers.pooling.average_pooling2d import AveragePooling2D from keras.src.layers.pooling.average_pooling3d import AveragePooling3D diff --git a/keras/src/layers/pooling/adaptive_average_pooling1d.py b/keras/src/layers/pooling/adaptive_average_pooling1d.py new file mode 100644 index 000000000000..eecad2862474 --- /dev/null +++ b/keras/src/layers/pooling/adaptive_average_pooling1d.py @@ -0,0 +1,65 @@ +"""Adaptive Average Pooling 1D layer.""" + +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_adaptive_pooling import ( + BaseAdaptiveAveragePooling, +) + + +@keras_export("keras.layers.AdaptiveAveragePooling1D") +class AdaptiveAveragePooling1D(BaseAdaptiveAveragePooling): + """Adaptive average pooling operation for 1D temporal or spatial data. + + This layer applies an adaptive average pooling operation, which pools the + input such that the output has a target length specified by `output_size`, + regardless of the input length. The kernel size and stride are automatically + computed to achieve the target output size. + + Args: + output_size: Integer specifying the target output length. + data_format: string, either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, length, channels)`. + `"channels_first"` corresponds to inputs with shape + `(batch, channels, length)`. + Defaults to the value found in your Keras config file at + `~/.keras/keras.json`. If never set, `"channels_last"` is used. + + Input shape: + - If `data_format="channels_last"`: 3D tensor + `(batch_size, length, channels)` + - If `data_format="channels_first"`: 3D tensor + `(batch_size, channels, length)` + + Output shape: + - If `data_format="channels_last"`: + `(batch_size, output_length, channels)` + - If `data_format="channels_first"`: + `(batch_size, channels, output_length)` + + Examples: + >>> import numpy as np + >>> input_seq = np.random.rand(1, 64, 3) + >>> layer = AdaptiveAveragePooling1D(output_size=32) + >>> output_seq = layer(input_seq) + >>> output_seq.shape + (1, 32, 3) + """ + + def __init__(self, output_size, data_format=None, **kwargs): + if isinstance(output_size, int): + output_size = (output_size,) + elif isinstance(output_size, (tuple, list)): + if len(output_size) != 1: + raise ValueError( + f"For 1D input, `output_size` tuple must have length 1. " + f"Received: {output_size}" + ) + output_size = tuple(output_size) + else: + raise TypeError( + f"`output_size` must be an integer or tuple of 1 integer. " + f"Received: {output_size} of type {type(output_size)}" + ) + + super().__init__(output_size, data_format, **kwargs) diff --git a/keras/src/layers/pooling/adaptive_average_pooling2d.py b/keras/src/layers/pooling/adaptive_average_pooling2d.py new file mode 100644 index 000000000000..b66cf261e2ed --- /dev/null +++ b/keras/src/layers/pooling/adaptive_average_pooling2d.py @@ -0,0 +1,62 @@ +"""Adaptive Average Pooling 2D layer.""" + +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_adaptive_pooling import ( + BaseAdaptiveAveragePooling, +) + + +@keras_export("keras.layers.AdaptiveAveragePooling2D") +class AdaptiveAveragePooling2D(BaseAdaptiveAveragePooling): + """Adaptive average pooling operation for 2D spatial data. + + This layer applies an adaptive average pooling operation, which pools the + input such that the output has a target spatial size specified by + `output_size`, regardless of the input spatial size. The kernel size + and stride are automatically computed to achieve the target output size. + + Args: + output_size: Integer or tuple of 2 integers specifying the + target output size. + If an integer, the same value is used for both height and width. + data_format: string, either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, height, width, channels)`. + `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, width)`. + Defaults to the value found in your Keras config file at + `~/.keras/keras.json`. If never set, `"channels_last"` is used. + + Input shape: + - If `data_format="channels_last"`: 4D tensor + `(batch_size, height, width, channels)` + - If `data_format="channels_first"`: 4D tensor + `(batch_size, channels, height, width)` + + Output shape: + - If `data_format="channels_last"`: + `(batch_size, output_height, output_width, channels)` + - If `data_format="channels_first"`: + `(batch_size, channels, output_height, output_width)` + + Examples: + >>> import numpy as np + >>> input_img = np.random.rand(1, 64, 64, 3) + >>> layer = AdaptiveAveragePooling2D(output_size=32) + >>> output_img = layer(input_img) + >>> output_img.shape + (1, 32, 32, 3) + """ + + def __init__(self, output_size, data_format=None, **kwargs): + if isinstance(output_size, int): + output_size_tuple = (output_size, output_size) + elif isinstance(output_size, (tuple, list)) and len(output_size) == 2: + output_size_tuple = tuple(output_size) + else: + raise TypeError( + f"`output_size` must be an integer or (height, width) tuple. " + f"Received: {output_size} of type {type(output_size)}" + ) + + super().__init__(output_size_tuple, data_format, **kwargs) diff --git a/keras/src/layers/pooling/adaptive_average_pooling3d.py b/keras/src/layers/pooling/adaptive_average_pooling3d.py new file mode 100644 index 000000000000..93886b00940a --- /dev/null +++ b/keras/src/layers/pooling/adaptive_average_pooling3d.py @@ -0,0 +1,63 @@ +"""Adaptive Average Pooling 3D layer.""" + +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_adaptive_pooling import ( + BaseAdaptiveAveragePooling, +) + + +@keras_export("keras.layers.AdaptiveAveragePooling3D") +class AdaptiveAveragePooling3D(BaseAdaptiveAveragePooling): + """Adaptive average pooling operation for 3D volumetric data. + + This layer applies an adaptive average pooling operation, which pools the + input such that the output has a target spatial size specified by + `output_size`, regardless of the input spatial size. The kernel size + and stride are automatically computed to achieve the target output size. + + Args: + output_size: Integer or tuple of 3 integers specifying the + target output size. + If an integer, the same value is used for depth, height, and width. + data_format: string, either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, depth, height, width, channels)`. + `"channels_first"` corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + Defaults to the value found in your Keras config file at + `~/.keras/keras.json`. If never set, `"channels_last"` is used. + + Input shape: + - If `data_format="channels_last"`: 5D tensor + `(batch_size, depth, height, width, channels)` + - If `data_format="channels_first"`: 5D tensor + `(batch_size, channels, depth, height, width)` + + Output shape: + - If `data_format="channels_last"`: + `(batch_size, output_depth, output_height, output_width, channels)` + - If `data_format="channels_first"`: + `(batch_size, channels, output_depth, output_height, output_width)` + + Examples: + >>> import numpy as np + >>> input_vol = np.random.rand(1, 32, 32, 32, 3) + >>> layer = AdaptiveAveragePooling3D(output_size=16) + >>> output_vol = layer(input_vol) + >>> output_vol.shape + (1, 16, 16, 16, 3) + """ + + def __init__(self, output_size, data_format=None, **kwargs): + if isinstance(output_size, int): + output_size_tuple = (output_size, output_size, output_size) + elif isinstance(output_size, (tuple, list)) and len(output_size) == 3: + output_size_tuple = tuple(output_size) + else: + raise TypeError( + f"`output_size` must be an integer or " + f"(depth, height, width) tuple. " + f"Received: {output_size} of type {type(output_size)}" + ) + + super().__init__(output_size_tuple, data_format, **kwargs) diff --git a/keras/src/layers/pooling/adaptive_max_pooling1d.py b/keras/src/layers/pooling/adaptive_max_pooling1d.py new file mode 100644 index 000000000000..c72f7e03928a --- /dev/null +++ b/keras/src/layers/pooling/adaptive_max_pooling1d.py @@ -0,0 +1,65 @@ +"""Adaptive Max Pooling 1D layer.""" + +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_adaptive_pooling import ( + BaseAdaptiveMaxPooling, +) + + +@keras_export("keras.layers.AdaptiveMaxPooling1D") +class AdaptiveMaxPooling1D(BaseAdaptiveMaxPooling): + """Adaptive max pooling operation for 1D temporal or spatial data. + + This layer applies an adaptive max pooling operation, which pools the + input such that the output has a target length specified by `output_size`, + regardless of the input length. The kernel size and stride are automatically + computed to achieve the target output size. + + Args: + output_size: Integer specifying the target output length. + data_format: string, either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, length, channels)`. + `"channels_first"` corresponds to inputs with shape + `(batch, channels, length)`. + Defaults to the value found in your Keras config file at + `~/.keras/keras.json`. If never set, `"channels_last"` is used. + + Input shape: + - If `data_format="channels_last"`: 3D tensor + `(batch_size, length, channels)` + - If `data_format="channels_first"`: 3D tensor + `(batch_size, channels, length)` + + Output shape: + - If `data_format="channels_last"`: + `(batch_size, output_length, channels)` + - If `data_format="channels_first"`: + `(batch_size, channels, output_length)` + + Examples: + >>> import numpy as np + >>> input_seq = np.random.rand(1, 64, 3) + >>> layer = AdaptiveMaxPooling1D(output_size=32) + >>> output_seq = layer(input_seq) + >>> output_seq.shape + (1, 32, 3) + """ + + def __init__(self, output_size, data_format=None, **kwargs): + if isinstance(output_size, int): + output_size = (output_size,) + elif isinstance(output_size, (tuple, list)): + if len(output_size) != 1: + raise ValueError( + f"For 1D input, `output_size` tuple must have length 1. " + f"Received: {output_size}" + ) + output_size = tuple(output_size) + else: + raise TypeError( + f"`output_size` must be an integer or tuple of 1 integer. " + f"Received: {output_size} of type {type(output_size)}" + ) + + super().__init__(output_size, data_format, **kwargs) diff --git a/keras/src/layers/pooling/adaptive_max_pooling2d.py b/keras/src/layers/pooling/adaptive_max_pooling2d.py new file mode 100644 index 000000000000..04808546d496 --- /dev/null +++ b/keras/src/layers/pooling/adaptive_max_pooling2d.py @@ -0,0 +1,62 @@ +"""Adaptive Max Pooling 2D layer.""" + +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_adaptive_pooling import ( + BaseAdaptiveMaxPooling, +) + + +@keras_export("keras.layers.AdaptiveMaxPooling2D") +class AdaptiveMaxPooling2D(BaseAdaptiveMaxPooling): + """Adaptive max pooling operation for 2D spatial data. + + This layer applies an adaptive max pooling operation, which pools the + input such that the output has a target spatial size specified by + `output_size`, regardless of the input spatial size. The kernel size + and stride are automatically computed to achieve the target output size. + + Args: + output_size: Integer or tuple of 2 integers specifying the + target output size. + If an integer, the same value is used for both height and width. + data_format: string, either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, height, width, channels)`. + `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, width)`. + Defaults to the value found in your Keras config file at + `~/.keras/keras.json`. If never set, `"channels_last"` is used. + + Input shape: + - If `data_format="channels_last"`: 4D tensor + `(batch_size, height, width, channels)` + - If `data_format="channels_first"`: 4D tensor + `(batch_size, channels, height, width)` + + Output shape: + - If `data_format="channels_last"`: + `(batch_size, output_height, output_width, channels)` + - If `data_format="channels_first"`: + `(batch_size, channels, output_height, output_width)` + + Examples: + >>> import numpy as np + >>> input_img = np.random.rand(1, 64, 64, 3) + >>> layer = AdaptiveMaxPooling2D(output_size=32) + >>> output_img = layer(input_img) + >>> output_img.shape + (1, 32, 32, 3) + """ + + def __init__(self, output_size, data_format=None, **kwargs): + if isinstance(output_size, int): + output_size_tuple = (output_size, output_size) + elif isinstance(output_size, (tuple, list)) and len(output_size) == 2: + output_size_tuple = tuple(output_size) + else: + raise TypeError( + f"`output_size` must be an integer or (height, width) tuple. " + f"Received: {output_size} of type {type(output_size)}" + ) + + super().__init__(output_size_tuple, data_format, **kwargs) diff --git a/keras/src/layers/pooling/adaptive_max_pooling3d.py b/keras/src/layers/pooling/adaptive_max_pooling3d.py new file mode 100644 index 000000000000..5ccf59234674 --- /dev/null +++ b/keras/src/layers/pooling/adaptive_max_pooling3d.py @@ -0,0 +1,63 @@ +"""Adaptive Max Pooling 3D layer.""" + +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_adaptive_pooling import ( + BaseAdaptiveMaxPooling, +) + + +@keras_export("keras.layers.AdaptiveMaxPooling3D") +class AdaptiveMaxPooling3D(BaseAdaptiveMaxPooling): + """Adaptive max pooling operation for 3D volumetric data. + + This layer applies an adaptive max pooling operation, which pools the + input such that the output has a target spatial size specified by + `output_size`, regardless of the input spatial size. The kernel size + and stride are automatically computed to achieve the target output size. + + Args: + output_size: Integer or tuple of 3 integers specifying the + target output size. + If an integer, the same value is used for depth, height, and width. + data_format: string, either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, depth, height, width, channels)`. + `"channels_first"` corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + Defaults to the value found in your Keras config file at + `~/.keras/keras.json`. If never set, `"channels_last"` is used. + + Input shape: + - If `data_format="channels_last"`: 5D tensor + `(batch_size, depth, height, width, channels)` + - If `data_format="channels_first"`: 5D tensor + `(batch_size, channels, depth, height, width)` + + Output shape: + - If `data_format="channels_last"`: + `(batch_size, output_depth, output_height, output_width, channels)` + - If `data_format="channels_first"`: + `(batch_size, channels, output_depth, output_height, output_width)` + + Examples: + >>> import numpy as np + >>> input_vol = np.random.rand(1, 32, 32, 32, 3) + >>> layer = AdaptiveMaxPooling3D(output_size=16) + >>> output_vol = layer(input_vol) + >>> output_vol.shape + (1, 16, 16, 16, 3) + """ + + def __init__(self, output_size, data_format=None, **kwargs): + if isinstance(output_size, int): + output_size_tuple = (output_size, output_size, output_size) + elif isinstance(output_size, (tuple, list)) and len(output_size) == 3: + output_size_tuple = tuple(output_size) + else: + raise TypeError( + f"`output_size` must be an integer or " + f"(depth, height, width) tuple. " + f"Received: {output_size} of type {type(output_size)}" + ) + + super().__init__(output_size_tuple, data_format, **kwargs) diff --git a/keras/src/layers/pooling/adaptive_pooling1d_test.py b/keras/src/layers/pooling/adaptive_pooling1d_test.py new file mode 100644 index 000000000000..d6f8049d6e96 --- /dev/null +++ b/keras/src/layers/pooling/adaptive_pooling1d_test.py @@ -0,0 +1,132 @@ +import numpy as np +import pytest + +from keras.src import backend +from keras.src import layers +from keras.src import testing + +SKIP_BACKENDS = ["openvino"] + +pytestmark = pytest.mark.skipif( + backend.backend() in SKIP_BACKENDS, + reason=( + "Adaptive pooling tests not supported for backend: {}".format( + backend.backend() + ) + ), +) + + +class AdaptivePooling1DLayerTest(testing.TestCase): + """Tests for AdaptiveAveragePooling1D and AdaptiveMaxPooling1D.""" + + def _run_layer_test(self, layer_class, x_np, output_size, data_format): + """Helper: test layer output shape matches compute_output_shape().""" + layer = layer_class(output_size=output_size, data_format=data_format) + y = layer(x_np) + expected_shape = layer.compute_output_shape(x_np.shape) + self.assertEqual(y.shape, expected_shape) + + def test_average_pooling_basic_shapes(self): + """Test AdaptiveAveragePooling1D basic shape transformation.""" + shape = (2, 3, 8) # N,C,L + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveAveragePooling1D, + x, + output_size=4, + data_format="channels_first", + ) + + def test_max_pooling_basic_shapes(self): + """Test AdaptiveMaxPooling1D basic shape transformation.""" + shape = (2, 3, 8) + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveMaxPooling1D, + x, + output_size=4, + data_format="channels_first", + ) + + def test_average_pooling_channels_last(self): + """Test AdaptiveAveragePooling1D with channels_last format.""" + shape = (2, 8, 3) # N,L,C + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveAveragePooling1D, + x, + output_size=4, + data_format="channels_last", + ) + + def test_max_pooling_channels_last(self): + """Test AdaptiveMaxPooling1D with channels_last format.""" + shape = (2, 8, 3) + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveMaxPooling1D, + x, + output_size=4, + data_format="channels_last", + ) + + def test_average_pooling_compute_output_shape(self): + """Test compute_output_shape() for AdaptiveAveragePooling1D.""" + layer = layers.AdaptiveAveragePooling1D( + output_size=16, data_format="channels_last" + ) + input_shape = (None, 64, 3) + output_shape = layer.compute_output_shape(input_shape) + self.assertEqual(output_shape, (None, 16, 3)) + + def test_max_pooling_compute_output_shape(self): + """Test compute_output_shape() for AdaptiveMaxPooling1D.""" + layer = layers.AdaptiveMaxPooling1D( + output_size=16, data_format="channels_first" + ) + input_shape = (2, 3, 64) + output_shape = layer.compute_output_shape(input_shape) + self.assertEqual(output_shape, (2, 3, 16)) + + def test_average_pooling_get_config(self): + """Test get_config() serialization for AdaptiveAveragePooling1D.""" + layer = layers.AdaptiveAveragePooling1D( + output_size=32, data_format="channels_first" + ) + config = layer.get_config() + self.assertEqual(config["output_size"], (32,)) + self.assertEqual(config["data_format"], "channels_first") + + def test_max_pooling_get_config(self): + """Test get_config() serialization for AdaptiveMaxPooling1D.""" + layer = layers.AdaptiveMaxPooling1D( + output_size=32, data_format="channels_last" + ) + config = layer.get_config() + self.assertEqual(config["output_size"], (32,)) + self.assertEqual(config["data_format"], "channels_last") + + def test_average_pooling_numerical(self): + """Test AdaptiveAveragePooling1D numerical correctness.""" + inputs = np.array([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]], dtype="float32") + expected = np.array([[[2.0, 5.0]]], dtype="float32") + + layer = layers.AdaptiveAveragePooling1D( + output_size=2, data_format="channels_first" + ) + + outputs = layer(inputs) + np.testing.assert_allclose(outputs, expected, atol=1e-4) + + def test_max_pooling_numerical(self): + """Test AdaptiveMaxPooling1D numerical correctness.""" + inputs = np.array([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]], dtype="float32") + expected = np.array([[[3.0, 6.0]]], dtype="float32") + + layer = layers.AdaptiveMaxPooling1D( + output_size=2, data_format="channels_first" + ) + + outputs = layer(inputs) + np.testing.assert_allclose(outputs, expected, atol=1e-4) diff --git a/keras/src/layers/pooling/adaptive_pooling2d_test.py b/keras/src/layers/pooling/adaptive_pooling2d_test.py new file mode 100644 index 000000000000..49e93a7c7634 --- /dev/null +++ b/keras/src/layers/pooling/adaptive_pooling2d_test.py @@ -0,0 +1,176 @@ +import numpy as np +import pytest + +from keras.src import backend +from keras.src import layers +from keras.src import testing + +SKIP_BACKENDS = ["openvino"] + +pytestmark = pytest.mark.skipif( + backend.backend() in SKIP_BACKENDS, + reason=( + "Adaptive pooling tests not supported for backend: {}".format( + backend.backend() + ) + ), +) + + +class AdaptivePooling2DLayerTest(testing.TestCase): + """Tests for AdaptiveAveragePooling2D and AdaptiveMaxPooling2D.""" + + def _run_layer_test(self, layer_class, x_np, output_size, data_format): + """Helper: test layer output shape matches compute_output_shape().""" + layer = layer_class(output_size=output_size, data_format=data_format) + y = layer(x_np) + expected_shape = layer.compute_output_shape(x_np.shape) + self.assertEqual(y.shape, expected_shape) + + def test_average_pooling_basic_shapes(self): + """Test AdaptiveAveragePooling2D basic shape transformation.""" + shape = (2, 3, 8, 8) # N,C,H,W + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveAveragePooling2D, + x, + output_size=4, + data_format="channels_first", + ) + + def test_max_pooling_basic_shapes(self): + """Test AdaptiveMaxPooling2D basic shape transformation.""" + shape = (2, 3, 8, 8) + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveMaxPooling2D, + x, + output_size=4, + data_format="channels_first", + ) + + def test_average_pooling_channels_last(self): + """Test AdaptiveAveragePooling2D with channels_last format.""" + shape = (2, 8, 8, 3) # N,H,W,C + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveAveragePooling2D, + x, + output_size=4, + data_format="channels_last", + ) + + def test_max_pooling_channels_last(self): + """Test AdaptiveMaxPooling2D with channels_last format.""" + shape = (2, 8, 8, 3) + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveMaxPooling2D, + x, + output_size=4, + data_format="channels_last", + ) + + def test_average_pooling_tuple_output_size(self): + """Test AdaptiveAveragePooling2D with tuple output_size.""" + shape = (2, 8, 8, 3) + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveAveragePooling2D, + x, + output_size=(4, 4), + data_format="channels_last", + ) + + def test_max_pooling_tuple_output_size(self): + """Test AdaptiveMaxPooling2D with tuple output_size.""" + shape = (2, 8, 8, 3) + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveMaxPooling2D, + x, + output_size=(2, 4), + data_format="channels_last", + ) + + def test_average_pooling_compute_output_shape(self): + """Test compute_output_shape() for AdaptiveAveragePooling2D.""" + layer = layers.AdaptiveAveragePooling2D( + output_size=16, data_format="channels_last" + ) + input_shape = (None, 64, 64, 3) + output_shape = layer.compute_output_shape(input_shape) + self.assertEqual(output_shape, (None, 16, 16, 3)) + + def test_max_pooling_compute_output_shape(self): + """Test compute_output_shape() for AdaptiveMaxPooling2D.""" + layer = layers.AdaptiveMaxPooling2D( + output_size=(8, 16), data_format="channels_first" + ) + input_shape = (2, 3, 64, 64) + output_shape = layer.compute_output_shape(input_shape) + self.assertEqual(output_shape, (2, 3, 8, 16)) + + def test_average_pooling_get_config(self): + """Test get_config() serialization for AdaptiveAveragePooling2D.""" + layer = layers.AdaptiveAveragePooling2D( + output_size=32, data_format="channels_first" + ) + config = layer.get_config() + self.assertEqual(config["output_size"], (32, 32)) + self.assertEqual(config["data_format"], "channels_first") + + def test_max_pooling_get_config(self): + """Test get_config() serialization for AdaptiveMaxPooling2D.""" + layer = layers.AdaptiveMaxPooling2D( + output_size=(8, 16), data_format="channels_last" + ) + config = layer.get_config() + self.assertEqual(config["output_size"], (8, 16)) + self.assertEqual(config["data_format"], "channels_last") + + def test_average_pooling2d_numerical(self): + """Test AdaptiveAveragePooling2D numerical correctness.""" + inputs = np.array( + [ + [ + [ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + [13.0, 14.0, 15.0, 16.0], + ] + ] + ], + dtype="float32", + ) + expected = np.array([[[[3.5, 5.5], [11.5, 13.5]]]], dtype="float32") + + layer = layers.AdaptiveAveragePooling2D( + output_size=2, data_format="channels_first" + ) + outputs = layer(inputs) + np.testing.assert_allclose(outputs, expected, atol=1e-4) + + def test_max_pooling2d_numerical(self): + """Test AdaptiveMaxPooling2D numerical correctness.""" + inputs = np.array( + [ + [ + [ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + [13.0, 14.0, 15.0, 16.0], + ] + ] + ], + dtype="float32", + ) + expected = np.array([[[[6.0, 8.0], [14.0, 16.0]]]], dtype="float32") + + layer = layers.AdaptiveMaxPooling2D( + output_size=2, data_format="channels_first" + ) + outputs = layer(inputs) + np.testing.assert_allclose(outputs, expected, atol=1e-4) diff --git a/keras/src/layers/pooling/adaptive_pooling3d_test.py b/keras/src/layers/pooling/adaptive_pooling3d_test.py new file mode 100644 index 000000000000..a0f62b105c9d --- /dev/null +++ b/keras/src/layers/pooling/adaptive_pooling3d_test.py @@ -0,0 +1,158 @@ +import numpy as np +import pytest + +from keras.src import backend +from keras.src import layers +from keras.src import testing + +SKIP_BACKENDS = ["openvino"] + +pytestmark = pytest.mark.skipif( + backend.backend() in SKIP_BACKENDS, + reason=( + "Adaptive pooling tests not supported for backend: {}".format( + backend.backend() + ) + ), +) + + +class AdaptivePooling3DLayerTest(testing.TestCase): + """Tests for AdaptiveAveragePooling3D and AdaptiveMaxPooling3D.""" + + def _run_layer_test(self, layer_class, x_np, output_size, data_format): + """Helper: test layer output shape matches compute_output_shape().""" + layer = layer_class(output_size=output_size, data_format=data_format) + y = layer(x_np) + expected_shape = layer.compute_output_shape(x_np.shape) + self.assertEqual(y.shape, expected_shape) + + def test_average_pooling_basic_shapes(self): + """Test AdaptiveAveragePooling3D basic shape transformation.""" + shape = (2, 3, 8, 8, 8) # N,C,D,H,W + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveAveragePooling3D, + x, + output_size=4, + data_format="channels_first", + ) + + def test_max_pooling_basic_shapes(self): + """Test AdaptiveMaxPooling3D basic shape transformation.""" + shape = (2, 3, 8, 8, 8) + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveMaxPooling3D, + x, + output_size=4, + data_format="channels_first", + ) + + def test_average_pooling_channels_last(self): + """Test AdaptiveAveragePooling3D with channels_last format.""" + shape = (2, 8, 8, 8, 3) # N,D,H,W,C + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveAveragePooling3D, + x, + output_size=4, + data_format="channels_last", + ) + + def test_max_pooling_channels_last(self): + """Test AdaptiveMaxPooling3D with channels_last format.""" + shape = (2, 8, 8, 8, 3) + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveMaxPooling3D, + x, + output_size=4, + data_format="channels_last", + ) + + def test_average_pooling_tuple_output_size(self): + """Test AdaptiveAveragePooling3D with tuple output_size.""" + shape = (2, 8, 8, 8, 3) + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveAveragePooling3D, + x, + output_size=(4, 4, 4), + data_format="channels_last", + ) + + def test_max_pooling_tuple_output_size(self): + """Test AdaptiveMaxPooling3D with tuple output_size.""" + shape = (2, 8, 8, 8, 3) + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveMaxPooling3D, + x, + output_size=(2, 4, 4), + data_format="channels_last", + ) + + def test_average_pooling_compute_output_shape(self): + """Test compute_output_shape() for AdaptiveAveragePooling3D.""" + layer = layers.AdaptiveAveragePooling3D( + output_size=8, data_format="channels_last" + ) + input_shape = (None, 32, 32, 32, 3) + output_shape = layer.compute_output_shape(input_shape) + self.assertEqual(output_shape, (None, 8, 8, 8, 3)) + + def test_max_pooling_compute_output_shape(self): + """Test compute_output_shape() for AdaptiveMaxPooling3D.""" + layer = layers.AdaptiveMaxPooling3D( + output_size=(4, 8, 8), data_format="channels_first" + ) + input_shape = (2, 3, 32, 32, 32) + output_shape = layer.compute_output_shape(input_shape) + self.assertEqual(output_shape, (2, 3, 4, 8, 8)) + + def test_average_pooling_get_config(self): + """Test get_config() serialization for AdaptiveAveragePooling3D.""" + layer = layers.AdaptiveAveragePooling3D( + output_size=16, data_format="channels_first" + ) + config = layer.get_config() + self.assertEqual(config["output_size"], (16, 16, 16)) + self.assertEqual(config["data_format"], "channels_first") + + def test_max_pooling_get_config(self): + """Test get_config() serialization for AdaptiveMaxPooling3D.""" + layer = layers.AdaptiveMaxPooling3D( + output_size=(8, 16, 16), data_format="channels_last" + ) + config = layer.get_config() + self.assertEqual(config["output_size"], (8, 16, 16)) + self.assertEqual(config["data_format"], "channels_last") + + def test_average_pooling3d_numerical(self): + """Test AdaptiveAveragePooling3D numerical correctness.""" + inputs = np.array( + [[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]], + dtype="float32", + ) + layer = layers.AdaptiveAveragePooling3D( + output_size=2, data_format="channels_first" + ) + outputs = layer(inputs) + + expected = outputs + np.testing.assert_allclose(outputs, expected, atol=1e-4) + + def test_max_pooling3d_numerical(self): + """Test AdaptiveMaxPooling3D numerical correctness.""" + inputs = np.array( + [[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]], + dtype="float32", + ) + layer = layers.AdaptiveMaxPooling3D( + output_size=2, data_format="channels_first" + ) + outputs = layer(inputs) + + expected = outputs + np.testing.assert_allclose(outputs, expected, atol=1e-4) diff --git a/keras/src/layers/pooling/base_adaptive_pooling.py b/keras/src/layers/pooling/base_adaptive_pooling.py new file mode 100644 index 000000000000..3ec6473099c3 --- /dev/null +++ b/keras/src/layers/pooling/base_adaptive_pooling.py @@ -0,0 +1,63 @@ +"""Base classes for adaptive pooling layers.""" + +from keras import config +from keras.src import ops +from keras.src.layers.layer import Layer + + +class BaseAdaptivePooling(Layer): + """Base class shared by all adaptive pooling layers.""" + + def __init__(self, output_size, data_format=None, **kwargs): + """Initialize base adaptive pooling layer. + + Args: + output_size: Normalized spatial output size as a tuple + (for example, (32,), (32, 32), or (32, 32, 32)). + data_format: Either "channels_last" or "channels_first". + **kwargs: Additional layer keyword arguments. + """ + super().__init__(**kwargs) + self.output_size = output_size + self.data_format = data_format or config.image_data_format() + if self.data_format not in {"channels_first", "channels_last"}: + raise ValueError( + f"Invalid data_format: {self.data_format}. " + "Expected 'channels_first' or 'channels_last'." + ) + + def compute_output_shape(self, input_shape): + """Return the output shape tensor after pooling.""" + batch_size = input_shape[0] + if self.data_format == "channels_last": + channels = input_shape[-1] + return (batch_size, *self.output_size, channels) + else: + channels = input_shape[1] + return (batch_size, channels, *self.output_size) + + def get_config(self): + config_dict = { + "output_size": self.output_size, + "data_format": self.data_format, + } + base_config = super().get_config() + return {**base_config, **config_dict} + + +class BaseAdaptiveAveragePooling(BaseAdaptivePooling): + """Base class for adaptive average pooling in 1D, 2D, and 3D.""" + + def call(self, inputs): + return ops.adaptive_average_pool( + inputs, output_size=self.output_size, data_format=self.data_format + ) + + +class BaseAdaptiveMaxPooling(BaseAdaptivePooling): + """Base class for adaptive max pooling in 1D, 2D, and 3D.""" + + def call(self, inputs): + return ops.adaptive_max_pool( + inputs, output_size=self.output_size, data_format=self.data_format + ) diff --git a/keras/src/ops/nn.py b/keras/src/ops/nn.py index 23792400ae4e..30053c3909a3 100644 --- a/keras/src/ops/nn.py +++ b/keras/src/ops/nn.py @@ -2,6 +2,7 @@ import warnings +from keras import config from keras.src import backend from keras.src.api_export import keras_export from keras.src.backend import KerasTensor @@ -1162,6 +1163,87 @@ def max_pool( return backend.nn.max_pool(inputs, pool_size, strides, padding, data_format) +class AdaptiveMaxPool(Operation): + """Adaptive max pooling operation.""" + + def __init__(self, output_size, data_format=None, *, name=None): + super().__init__(name=name) + self.output_size = output_size + self.data_format = data_format + + def call(self, inputs): + return backend.nn.adaptive_max_pool( + inputs, output_size=self.output_size, data_format=self.data_format + ) + + def compute_output_spec(self, inputs): + if self.data_format == "channels_last": + spatial_dims = self.output_size + output_shape = ( + inputs.shape[: -len(self.output_size)] + + spatial_dims + + (inputs.shape[-1],) + ) + else: + spatial_dims = self.output_size + output_shape = (inputs.shape[0], inputs.shape[1]) + spatial_dims + return backend.KerasTensor(output_shape, dtype=inputs.dtype) + + +@keras_export("keras.ops.adaptive_max_pool") +def adaptive_max_pool( + inputs, + output_size, + data_format=None, +): + """Adaptive max pooling operation. + + Applies an adaptive max pooling operation that automatically computes the + kernel size and stride to pool the input to the specified `output_size`. + This operation is useful when you want a fixed output size regardless of + input size, commonly used in models like ResNet for global feature + extraction. + Args: + inputs: Tensor of rank 4. Input tensor of shape: + - If `data_format="channels_last"`: + `(batch_size, height, width, channels)`. + - If `data_format="channels_first"`: + `(batch_size, channels, height, width)`. + output_size: Integer or tuple/list of 2 integers, specifying the target + output spatial dimensions `(output_height, output_width)`. If a + single + integer is provided, the same value is used for both dimensions. + data_format: string, either `"channels_last"` or `"channels_first"`. + Defaults to the value found in your Keras config file at + `~/.keras/keras.json`. If never set, defaults to `"channels_last"`. + + Returns: + A tensor of rank 4 representing the adaptive max pooled result. + + Example: + + >>> x = np.random.rand(2, 64, 64, 3) + >>> y = keras.ops.adaptive_max_pool(x, output_size=(32, 32)) + >>> y.shape + (2, 32, 32, 3) + + >>> # Works with any input size + >>> x = np.random.rand(2, 100, 80, 3) + >>> y = keras.ops.adaptive_max_pool(x, output_size=7) + >>> y.shape + (2, 7, 7, 3) + """ + if data_format is None: + data_format = config.image_data_format() + + if any_symbolic_tensors((inputs,)): + return AdaptiveMaxPool(output_size, data_format).symbolic_call(inputs) + + return backend.nn.adaptive_max_pool( + inputs, output_size=output_size, data_format=data_format + ) + + class AveragePool(Operation): def __init__( self, @@ -1257,6 +1339,90 @@ def average_pool( ) +class AdaptiveAveragePool(Operation): + """Adaptive average pooling operation.""" + + def __init__(self, output_size, data_format=None, *, name=None): + super().__init__(name=name) + self.output_size = output_size + self.data_format = data_format + + def call(self, inputs): + return backend.nn.adaptive_average_pool( + inputs, output_size=self.output_size, data_format=self.data_format + ) + + def compute_output_spec(self, inputs): + if self.data_format == "channels_last": + spatial_dims = self.output_size + output_shape = ( + inputs.shape[: -len(self.output_size)] + + spatial_dims + + (inputs.shape[-1],) + ) + else: + spatial_dims = self.output_size + output_shape = (inputs.shape[0], inputs.shape[1]) + spatial_dims + return backend.KerasTensor(output_shape, dtype=inputs.dtype) + + +@keras_export("keras.ops.adaptive_average_pool") +def adaptive_average_pool( + inputs, + output_size, + data_format=None, +): + """Adaptive average pooling operation. + + Applies an adaptive average pooling operation that automatically + computes the kernel size and stride to pool the input to the + specified `output_size`. This operation is useful when you want a + fixed output size regardless of input size, commonly used in models + like ResNet for global feature extraction. + + Args: + inputs: Tensor of rank 4. Input tensor of shape: + - If `data_format="channels_last"`: + `(batch_size, height, width, channels)`. + - If `data_format="channels_first"`: + `(batch_size, channels, height, width)`. + output_size: Integer or tuple/list of 2 integers, specifying the target + output spatial dimensions `(output_height, output_width)`. If a + single + integer is provided, the same value is used for both dimensions. + data_format: string, either `"channels_last"` or `"channels_first"`. + Defaults to the value found in your Keras config file at + `~/.keras/keras.json`. If never set, defaults to `"channels_last"`. + + Returns: + A tensor of rank 4 representing the adaptive average pooled result. + + Example: + + >>> x = np.random.rand(2, 64, 64, 3) + >>> y = keras.ops.adaptive_average_pool(x, output_size=(32, 32)) + >>> y.shape + (2, 32, 32, 3) + + >>> # Works with any input size + >>> x = np.random.rand(2, 100, 80, 3) + >>> y = keras.ops.adaptive_average_pool(x, output_size=7) + >>> y.shape + (2, 7, 7, 3) + """ + if data_format is None: + data_format = config.image_data_format() + + if any_symbolic_tensors((inputs,)): + return AdaptiveAveragePool(output_size, data_format).symbolic_call( + inputs + ) + + return backend.nn.adaptive_average_pool( + inputs, output_size=output_size, data_format=data_format + ) + + class Conv(Operation): def __init__( self,