From f6200b59cd19d78059131ffe5ee8f3acf94b8601 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 3 Dec 2025 09:56:45 -0800
Subject: [PATCH] A simplification / fix for DropBlock2d. Handles even kernel
 sizes properly thanks to https://github.com/crutcher. For simplicity, don't
 bother keeping drop blocks within feature maps. Remove 'batchwise' option.
 Remove separate 'fast' option.
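For reference, the reworked interface can be exercised roughly as below. This
is an illustrative sketch only, not part of the patch; the tensor shape and
parameter values are arbitrary:

    import torch
    from timm.layers.drop import drop_block_2d, DropBlock2d

    x = torch.randn(2, 64, 56, 56)

    # Functional form: per-channel masks, an even block size, no rescaling of kept values.
    y = drop_block_2d(x, drop_prob=0.1, block_size=4, couple_channels=False, scale_by_keep=False)

    # Module form: the removed args ('batchwise', 'fast') are still accepted and ignored.
    blk = DropBlock2d(drop_prob=0.1, block_size=7, fast=True)
    blk.train()
    y = blk(x)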
---
 tests/test_layers_drop.py | 251 ++++++++++++++++++++++++++++++++++++++
 timm/layers/drop.py       | 170 +++++++++++++-------------
 2 files changed, 339 insertions(+), 82 deletions(-)
 create mode 100644 tests/test_layers_drop.py

diff --git a/tests/test_layers_drop.py b/tests/test_layers_drop.py
new file mode 100644
index 0000000000..5aae7aef51
--- /dev/null
+++ b/tests/test_layers_drop.py
@@ -0,0 +1,251 @@
+"""Tests for timm.layers.drop module (DropBlock, DropPath)."""
+import torch
+import pytest
+
+from timm.layers.drop import drop_block_2d, DropBlock2d, drop_path, DropPath
+
+
+class TestDropBlock2d:
+    """Test drop_block_2d function and DropBlock2d module."""
+
+    def test_drop_block_2d_output_shape(self):
+        """Test that output shape matches input shape."""
+        for h, w in [(7, 7), (4, 8), (10, 5), (3, 3)]:
+            x = torch.ones((2, 3, h, w))
+            result = drop_block_2d(x, drop_prob=0.1, block_size=3)
+            assert result.shape == x.shape, f"Shape mismatch for input ({h}, {w})"
+
+    def test_drop_block_2d_no_drop_when_prob_zero(self):
+        """Test that no dropping occurs when drop_prob=0."""
+        x = torch.ones((2, 3, 8, 8))
+        result = drop_block_2d(x, drop_prob=0.0, block_size=3)
+        assert torch.allclose(result, x)
+
+    def test_drop_block_2d_approximate_keep_ratio(self):
+        """Test that the drop ratio is approximately correct."""
+        torch.manual_seed(123)
+        # Use large batch for statistical stability
+        x = torch.ones((32, 16, 56, 56))
+        drop_prob = 0.1
+
+        # With scale_by_keep=False, kept values stay at 1.0 and dropped are 0.0
+        # so we can directly measure the drop ratio
+        result = drop_block_2d(x, drop_prob=drop_prob, block_size=7, scale_by_keep=False)
+
+        total_elements = result.numel()
+        dropped_elements = (result == 0).sum().item()
+        actual_drop_ratio = dropped_elements / total_elements
+
+        # Allow some tolerance since it's stochastic
+        assert abs(actual_drop_ratio - drop_prob) < 0.03, \
+            f"Drop ratio {actual_drop_ratio:.3f} not close to expected {drop_prob}"
+
+    def test_drop_block_2d_inplace(self):
+        """Test inplace operation."""
+        x = torch.ones((2, 3, 8, 8))
+        x_clone = x.clone()
+        torch.manual_seed(42)
+        result = drop_block_2d(x_clone, drop_prob=0.3, block_size=3, inplace=True)
+        assert result is x_clone, "Inplace should return the same tensor"
+
+    def test_drop_block_2d_couple_channels_true(self):
+        """Test couple_channels=True uses same mask for all channels."""
+        torch.manual_seed(42)
+        x = torch.ones((2, 4, 16, 16))
+        result = drop_block_2d(x, drop_prob=0.3, block_size=5, couple_channels=True)
+
+        # With couple_channels=True, all channels should have same drop pattern
+        for b in range(x.shape[0]):
+            mask_c0 = (result[b, 0] == 0).float()
+            for c in range(1, x.shape[1]):
+                mask_c = (result[b, c] == 0).float()
+                assert torch.allclose(mask_c0, mask_c), f"Channel {c} has different mask than channel 0"
+
+    def test_drop_block_2d_couple_channels_false(self):
+        """Test couple_channels=False uses independent mask per channel."""
+        torch.manual_seed(42)
+        x = torch.ones((2, 4, 16, 16))
+        result = drop_block_2d(x, drop_prob=0.3, block_size=5, couple_channels=False)
+
+        # With couple_channels=False, channels should have different patterns
+        # (with high probability for reasonable drop_prob)
+        mask_c0 = (result[0, 0] == 0).float()
+        mask_c1 = (result[0, 1] == 0).float()
+        # They might occasionally be the same by chance, but very unlikely
+        assert not torch.allclose(mask_c0, mask_c1), "Channels should have independent masks"
+
+    def test_drop_block_2d_with_noise(self):
+        """Test with_noise option adds gaussian noise to dropped regions."""
+        torch.manual_seed(42)
+        x = torch.ones((2, 3, 16, 16))
+        result = drop_block_2d(x, drop_prob=0.3, block_size=5, with_noise=True)
+
+        # With noise, dropped regions should have non-zero values from gaussian noise
+        # The result should contain values other than the scaled kept values
+        unique_vals = torch.unique(result)
+        assert len(unique_vals) > 2, "With noise should produce varied values"
+
+    def test_drop_block_2d_even_block_size(self):
+        """Test that even block sizes work correctly."""
+        x = torch.ones((2, 3, 16, 16))
+        for block_size in [2, 4, 6]:
+            result = drop_block_2d(x, drop_prob=0.1, block_size=block_size)
+            assert result.shape == x.shape, f"Shape mismatch for block_size={block_size}"
+
+    def test_drop_block_2d_asymmetric_input(self):
+        """Test with asymmetric H != W inputs."""
+        for h, w in [(8, 16), (16, 8), (7, 14), (14, 7)]:
+            x = torch.ones((2, 3, h, w))
+            result = drop_block_2d(x, drop_prob=0.1, block_size=5)
+            assert result.shape == x.shape, f"Shape mismatch for ({h}, {w})"
+
+    def test_drop_block_2d_scale_by_keep(self):
+        """Test scale_by_keep parameter."""
+        torch.manual_seed(42)
+        x = torch.ones((2, 3, 16, 16))
+
+        # With scale_by_keep=True (default), kept values are scaled up
+        result_scaled = drop_block_2d(x.clone(), drop_prob=0.3, block_size=5, scale_by_keep=True)
+        kept_vals_scaled = result_scaled[result_scaled > 0]
+        # Scaled values should be > 1.0 (scaled up to compensate for drops)
+        assert kept_vals_scaled.min() > 1.0, "Scaled values should be > 1.0"
+
+        # With scale_by_keep=False, kept values stay at original
+        torch.manual_seed(42)
+        result_unscaled = drop_block_2d(x.clone(), drop_prob=0.3, block_size=5, scale_by_keep=False)
+        kept_vals_unscaled = result_unscaled[result_unscaled > 0]
+        # Unscaled values should be exactly 1.0
+        assert torch.allclose(kept_vals_unscaled, torch.ones_like(kept_vals_unscaled)), \
+            "Unscaled values should be 1.0"
+
+
+class TestDropBlock2dModule:
+    """Test DropBlock2d nn.Module."""
+
+    def test_deprecated_args_accepted(self):
+        """Test that deprecated args (batchwise, fast) are silently accepted."""
+        # These should not raise
+        module1 = DropBlock2d(drop_prob=0.1, batchwise=True)
+        module2 = DropBlock2d(drop_prob=0.1, fast=False)
+        module3 = DropBlock2d(drop_prob=0.1, batchwise=False, fast=True)
+        assert module1.drop_prob == 0.1
+        assert module2.drop_prob == 0.1
+        assert module3.drop_prob == 0.1
+
+    def test_unknown_args_warned(self):
+        """Test that unknown kwargs emit a warning."""
+        with pytest.warns(UserWarning, match="unexpected keyword argument 'unknown_arg'"):
+            DropBlock2d(drop_prob=0.1, unknown_arg=True)
+
+    def test_training_mode(self):
+        """Test that dropping only occurs in training mode."""
+        module = DropBlock2d(drop_prob=0.5, block_size=3)
+        x = torch.ones((2, 3, 8, 8))
+
+        # In eval mode, should return input unchanged
+        module.eval()
+        result = module(x)
+        assert torch.allclose(result, x), "Should not drop in eval mode"
+
+        # In train mode, should modify input
+        module.train()
+        torch.manual_seed(42)
+        result = module(x)
+        assert not torch.allclose(result, x), "Should drop in train mode"
+
+    def test_couple_channels_parameter(self):
+        """Test couple_channels parameter is passed through."""
+        x = torch.ones((2, 4, 16, 16))
+
+        # couple_channels=True (default)
+        module_coupled = DropBlock2d(drop_prob=0.3, block_size=5, couple_channels=True)
+        module_coupled.train()
+        torch.manual_seed(42)
+        result_coupled = module_coupled(x)
+
+        # All channels should have same pattern
+        mask_c0 = (result_coupled[0, 0] == 0).float()
+        mask_c1 = (result_coupled[0, 1] == 0).float()
+        assert torch.allclose(mask_c0, mask_c1)
+
+        # couple_channels=False
+        module_uncoupled = DropBlock2d(drop_prob=0.3, block_size=5, couple_channels=False)
+        module_uncoupled.train()
+        torch.manual_seed(42)
+        result_uncoupled = module_uncoupled(x)
+
+        # Channels should have different patterns
+        mask_c0 = (result_uncoupled[0, 0] == 0).float()
+        mask_c1 = (result_uncoupled[0, 1] == 0).float()
+        assert not torch.allclose(mask_c0, mask_c1)
+
+
+class TestDropPath:
+    """Test drop_path function and DropPath module."""
+
+    def test_no_drop_when_prob_zero(self):
+        """Test that no dropping occurs when drop_prob=0."""
+        x = torch.ones((4, 8, 16, 16))
+        result = drop_path(x, drop_prob=0.0, training=True)
+        assert torch.allclose(result, x)
+
+    def test_no_drop_when_not_training(self):
+        """Test that no dropping occurs when not training."""
+        x = torch.ones((4, 8, 16, 16))
+        result = drop_path(x, drop_prob=0.5, training=False)
+        assert torch.allclose(result, x)
+
+    def test_drop_path_scaling(self):
+        """Test that scale_by_keep properly scales kept paths."""
+        torch.manual_seed(42)
+        x = torch.ones((100, 8, 4, 4))  # Large batch for statistical stability
+        keep_prob = 0.8
+        drop_prob = 1 - keep_prob
+
+        result = drop_path(x, drop_prob=drop_prob, training=True, scale_by_keep=True)
+
+        # Kept samples should be scaled by 1/keep_prob = 1.25
+        kept_mask = (result[:, 0, 0, 0] != 0)
+        if kept_mask.any():
+            kept_vals = result[kept_mask, 0, 0, 0]
+            expected_scale = 1.0 / keep_prob
+            assert torch.allclose(kept_vals, torch.full_like(kept_vals, expected_scale), atol=1e-5)
+
+    def test_drop_path_no_scaling(self):
+        """Test that scale_by_keep=False does not scale."""
+        torch.manual_seed(42)
+        x = torch.ones((100, 8, 4, 4))
+        result = drop_path(x, drop_prob=0.2, training=True, scale_by_keep=False)
+
+        # Kept samples should remain at 1.0
+        kept_mask = (result[:, 0, 0, 0] != 0)
+        if kept_mask.any():
+            kept_vals = result[kept_mask, 0, 0, 0]
+            assert torch.allclose(kept_vals, torch.ones_like(kept_vals))
+
+
+class TestDropPathModule:
+    """Test DropPath nn.Module."""
+
+    def test_training_mode(self):
+        """Test that dropping only occurs in training mode."""
+        module = DropPath(drop_prob=0.5)
+        x = torch.ones((32, 8, 4, 4))  # Larger batch for statistical reliability
+
+        module.eval()
+        result = module(x)
+        assert torch.allclose(result, x), "Should not drop in eval mode"
+
+        module.train()
+        torch.manual_seed(42)
+        result = module(x)
+        # With 50% drop prob on 32 samples, very unlikely all survive
+        # Check that at least one sample has zeros (was dropped)
+        has_zeros = (result == 0).any()
+        assert has_zeros, "Should drop some paths in train mode"
+
+    def test_extra_repr(self):
+        """Test extra_repr for nice printing."""
+        module = DropPath(drop_prob=0.123)
+        repr_str = module.extra_repr()
+        assert "0.123" in repr_str
diff --git a/timm/layers/drop.py b/timm/layers/drop.py
index 73a6e1dea9..928c6b26e8 100644
--- a/timm/layers/drop.py
+++ b/timm/layers/drop.py
@@ -20,104 +20,97 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .grid import ndgrid
-
 
 def drop_block_2d(
-        x,
+        x: torch.Tensor,
         drop_prob: float = 0.1,
         block_size: int = 7,
         gamma_scale: float = 1.0,
         with_noise: bool = False,
         inplace: bool = False,
-        batchwise: bool = False
+        couple_channels: bool = True,
+        scale_by_keep: bool = True,
 ):
     """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
 
-    DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
-    runs with success, but needs further validation and possibly optimization for lower runtime impact.
-    """
-    B, C, H, W = x.shape
-    total_size = W * H
-    clipped_block_size = min(block_size, min(W, H))
-    # seed_drop_rate, the gamma parameter
-    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
-        (W - block_size + 1) * (H - block_size + 1))
-
-    # Forces the block to be inside the feature map.
-    w_i, h_i = ndgrid(torch.arange(W, device=x.device), torch.arange(H, device=x.device))
-    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
-                  ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
-    valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
-
-    if batchwise:
-        # one mask for whole batch, quite a bit faster
-        uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
-    else:
-        uniform_noise = torch.rand_like(x)
-    block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
-    block_mask = -F.max_pool2d(
-        -block_mask,
-        kernel_size=clipped_block_size,  # block_size,
-        stride=1,
-        padding=clipped_block_size // 2)
+    DropBlock with an experimental gaussian noise option.
 
-    if with_noise:
-        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
-        if inplace:
-            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
-        else:
-            x = x * block_mask + normal_noise * (1 - block_mask)
-    else:
-        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
-        if inplace:
-            x.mul_(block_mask * normalize_scale)
-        else:
-            x = x * block_mask * normalize_scale
-    return x
-
-
-def drop_block_fast_2d(
-        x: torch.Tensor,
-        drop_prob: float = 0.1,
-        block_size: int = 7,
-        gamma_scale: float = 1.0,
-        with_noise: bool = False,
-        inplace: bool = False,
-):
-    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+    Args:
+        x: Input tensor of shape (B, C, H, W).
+        drop_prob: Probability of dropping a block.
+        block_size: Size of the block to drop.
+        gamma_scale: Scale factor for the drop probability.
+        with_noise: If True, add gaussian noise to dropped regions instead of zeros.
+        inplace: If True, perform operation in-place.
+        couple_channels: If True, all channels share the same drop mask (per the original paper).
+            If False, each channel gets an independent mask.
+        scale_by_keep: If True, scale kept activations to maintain expected values.
 
-    DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
-    block mask at edges.
+    Returns:
+        Tensor with dropped blocks, same shape as input.
""" B, C, H, W = x.shape - total_size = W * H - clipped_block_size = min(block_size, min(W, H)) - gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( - (W - block_size + 1) * (H - block_size + 1)) - - block_mask = torch.empty_like(x).bernoulli_(gamma) - block_mask = F.max_pool2d( - block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) + kh, kw = min(block_size, H), min(block_size, W) + + # Compute gamma (seed drop rate) - probability of dropping each spatial location + gamma = float(gamma_scale * drop_prob * H * W) / float(kh * kw) / float((H - kh + 1) * (W - kw + 1)) + + # Generate drop mask: 1 at block centers to drop, 0 elsewhere + # couple_channels=True means all channels share same spatial mask (matches paper) + noise_shape = (B, 1 if couple_channels else C, H, W) + with torch.no_grad(): + block_mask = torch.empty(noise_shape, dtype=x.dtype, device=x.device).bernoulli_(gamma) + + # Expand block centers to full blocks using max pooling + block_mask = F.max_pool2d( + block_mask, + kernel_size=(kh, kw), + stride=1, + padding=(kh // 2, kw // 2), + ) + # Handle even kernel sizes - max_pool2d output is 1 larger in each even dimension + if kh % 2 == 0 or kw % 2 == 0: + # Fix for even kernels proposed by https://github.com/crutcher + block_mask = block_mask[..., (kh + 1) % 2:, (kw + 1) % 2:] + + keep_mask = 1. - block_mask if with_noise: - normal_noise = torch.empty_like(x).normal_() + with torch.no_grad(): + noise = torch.empty_like(keep_mask).normal_() + noise.mul_(block_mask) + if inplace: - x.mul_(1. - block_mask).add_(normal_noise * block_mask) + x.mul_(keep_mask).add_(noise) else: - x = x * (1. - block_mask) + normal_noise * block_mask + x = x * keep_mask + noise else: - block_mask = 1 - block_mask - normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype) + if scale_by_keep: + with torch.no_grad(): + # Normalize to maintain expected values (scale up kept activations) + normalize_scale = keep_mask.numel() / keep_mask.to(dtype=torch.float32).sum().add(1e-7) + keep_mask.mul_(normalize_scale.to(x.dtype)) + if inplace: - x.mul_(block_mask * normalize_scale) + x.mul_(keep_mask) else: - x = x * block_mask * normalize_scale + x = x * keep_mask + return x class DropBlock2d(nn.Module): """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + Args: + drop_prob: Probability of dropping a block. + block_size: Size of the block to drop. + gamma_scale: Scale factor for the drop probability. + with_noise: If True, add gaussian noise to dropped regions instead of zeros. + inplace: If True, perform operation in-place. + couple_channels: If True, all channels share the same drop mask (per the original paper). + If False, each channel gets an independent mask. + scale_by_keep: If True, scale kept activations to maintain expected values. 
""" def __init__( @@ -127,26 +120,39 @@ def __init__( gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, - batchwise: bool = False, - fast: bool = True): + couple_channels: bool = True, + scale_by_keep: bool = True, + **kwargs, + ): super().__init__() self.drop_prob = drop_prob self.gamma_scale = gamma_scale self.block_size = block_size self.with_noise = with_noise self.inplace = inplace - self.batchwise = batchwise - self.fast = fast # FIXME finish comparisons of fast vs not + self.couple_channels = couple_channels + self.scale_by_keep = scale_by_keep + + # Backwards compatibility: silently consume args removed in v1.0.23, warn on unknown + deprecated_args = {'batchwise', 'fast'} + for k in kwargs: + if k not in deprecated_args: + import warnings + warnings.warn(f"DropBlock2d() got unexpected keyword argument '{k}'") def forward(self, x): if not self.training or not self.drop_prob: return x - if self.fast: - return drop_block_fast_2d( - x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace) - else: - return drop_block_2d( - x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) + return drop_block_2d( + x, + drop_prob=self.drop_prob, + block_size=self.block_size, + gamma_scale=self.gamma_scale, + with_noise=self.with_noise, + inplace=self.inplace, + couple_channels=self.couple_channels, + scale_by_keep=self.scale_by_keep, + ) def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):