Skip to content

Commit f6200b5

Browse files
committed
A simplification / fix for DropBlock2d. Handles even kernel sizes properly thanks to https://github.com/crutcher. For simplicty don't bother keeping drop blocks within feature maps. Remove 'batchwise' option. Remove separate 'fast' option.
1 parent 33ec6d7 commit f6200b5

File tree

2 files changed

+339
-82
lines changed

2 files changed

+339
-82
lines changed

tests/test_layers_drop.py

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
"""Tests for timm.layers.drop module (DropBlock, DropPath)."""
2+
import torch
3+
import pytest
4+
5+
from timm.layers.drop import drop_block_2d, DropBlock2d, drop_path, DropPath
6+
7+
8+
class TestDropBlock2d:
9+
"""Test drop_block_2d function and DropBlock2d module."""
10+
11+
def test_drop_block_2d_output_shape(self):
12+
"""Test that output shape matches input shape."""
13+
for h, w in [(7, 7), (4, 8), (10, 5), (3, 3)]:
14+
x = torch.ones((2, 3, h, w))
15+
result = drop_block_2d(x, drop_prob=0.1, block_size=3)
16+
assert result.shape == x.shape, f"Shape mismatch for input ({h}, {w})"
17+
18+
def test_drop_block_2d_no_drop_when_prob_zero(self):
19+
"""Test that no dropping occurs when drop_prob=0."""
20+
x = torch.ones((2, 3, 8, 8))
21+
result = drop_block_2d(x, drop_prob=0.0, block_size=3)
22+
assert torch.allclose(result, x)
23+
24+
def test_drop_block_2d_approximate_keep_ratio(self):
25+
"""Test that the drop ratio is approximately correct."""
26+
torch.manual_seed(123)
27+
# Use large batch for statistical stability
28+
x = torch.ones((32, 16, 56, 56))
29+
drop_prob = 0.1
30+
31+
# With scale_by_keep=False, kept values stay at 1.0 and dropped are 0.0
32+
# so we can directly measure the drop ratio
33+
result = drop_block_2d(x, drop_prob=drop_prob, block_size=7, scale_by_keep=False)
34+
35+
total_elements = result.numel()
36+
dropped_elements = (result == 0).sum().item()
37+
actual_drop_ratio = dropped_elements / total_elements
38+
39+
# Allow some tolerance since it's stochastic
40+
assert abs(actual_drop_ratio - drop_prob) < 0.03, \
41+
f"Drop ratio {actual_drop_ratio:.3f} not close to expected {drop_prob}"
42+
43+
def test_drop_block_2d_inplace(self):
44+
"""Test inplace operation."""
45+
x = torch.ones((2, 3, 8, 8))
46+
x_clone = x.clone()
47+
torch.manual_seed(42)
48+
result = drop_block_2d(x_clone, drop_prob=0.3, block_size=3, inplace=True)
49+
assert result is x_clone, "Inplace should return the same tensor"
50+
51+
def test_drop_block_2d_couple_channels_true(self):
52+
"""Test couple_channels=True uses same mask for all channels."""
53+
torch.manual_seed(42)
54+
x = torch.ones((2, 4, 16, 16))
55+
result = drop_block_2d(x, drop_prob=0.3, block_size=5, couple_channels=True)
56+
57+
# With couple_channels=True, all channels should have same drop pattern
58+
for b in range(x.shape[0]):
59+
mask_c0 = (result[b, 0] == 0).float()
60+
for c in range(1, x.shape[1]):
61+
mask_c = (result[b, c] == 0).float()
62+
assert torch.allclose(mask_c0, mask_c), f"Channel {c} has different mask than channel 0"
63+
64+
def test_drop_block_2d_couple_channels_false(self):
65+
"""Test couple_channels=False uses independent mask per channel."""
66+
torch.manual_seed(42)
67+
x = torch.ones((2, 4, 16, 16))
68+
result = drop_block_2d(x, drop_prob=0.3, block_size=5, couple_channels=False)
69+
70+
# With couple_channels=False, channels should have different patterns
71+
# (with high probability for reasonable drop_prob)
72+
mask_c0 = (result[0, 0] == 0).float()
73+
mask_c1 = (result[0, 1] == 0).float()
74+
# They might occasionally be the same by chance, but very unlikely
75+
assert not torch.allclose(mask_c0, mask_c1), "Channels should have independent masks"
76+
77+
def test_drop_block_2d_with_noise(self):
78+
"""Test with_noise option adds gaussian noise to dropped regions."""
79+
torch.manual_seed(42)
80+
x = torch.ones((2, 3, 16, 16))
81+
result = drop_block_2d(x, drop_prob=0.3, block_size=5, with_noise=True)
82+
83+
# With noise, dropped regions should have non-zero values from gaussian noise
84+
# The result should contain values other than the scaled kept values
85+
unique_vals = torch.unique(result)
86+
assert len(unique_vals) > 2, "With noise should produce varied values"
87+
88+
def test_drop_block_2d_even_block_size(self):
89+
"""Test that even block sizes work correctly."""
90+
x = torch.ones((2, 3, 16, 16))
91+
for block_size in [2, 4, 6]:
92+
result = drop_block_2d(x, drop_prob=0.1, block_size=block_size)
93+
assert result.shape == x.shape, f"Shape mismatch for block_size={block_size}"
94+
95+
def test_drop_block_2d_asymmetric_input(self):
96+
"""Test with asymmetric H != W inputs."""
97+
for h, w in [(8, 16), (16, 8), (7, 14), (14, 7)]:
98+
x = torch.ones((2, 3, h, w))
99+
result = drop_block_2d(x, drop_prob=0.1, block_size=5)
100+
assert result.shape == x.shape, f"Shape mismatch for ({h}, {w})"
101+
102+
def test_drop_block_2d_scale_by_keep(self):
103+
"""Test scale_by_keep parameter."""
104+
torch.manual_seed(42)
105+
x = torch.ones((2, 3, 16, 16))
106+
107+
# With scale_by_keep=True (default), kept values are scaled up
108+
result_scaled = drop_block_2d(x.clone(), drop_prob=0.3, block_size=5, scale_by_keep=True)
109+
kept_vals_scaled = result_scaled[result_scaled > 0]
110+
# Scaled values should be > 1.0 (scaled up to compensate for drops)
111+
assert kept_vals_scaled.min() > 1.0, "Scaled values should be > 1.0"
112+
113+
# With scale_by_keep=False, kept values stay at original
114+
torch.manual_seed(42)
115+
result_unscaled = drop_block_2d(x.clone(), drop_prob=0.3, block_size=5, scale_by_keep=False)
116+
kept_vals_unscaled = result_unscaled[result_unscaled > 0]
117+
# Unscaled values should be exactly 1.0
118+
assert torch.allclose(kept_vals_unscaled, torch.ones_like(kept_vals_unscaled)), \
119+
"Unscaled values should be 1.0"
120+
121+
122+
class TestDropBlock2dModule:
123+
"""Test DropBlock2d nn.Module."""
124+
125+
def test_deprecated_args_accepted(self):
126+
"""Test that deprecated args (batchwise, fast) are silently accepted."""
127+
# These should not raise
128+
module1 = DropBlock2d(drop_prob=0.1, batchwise=True)
129+
module2 = DropBlock2d(drop_prob=0.1, fast=False)
130+
module3 = DropBlock2d(drop_prob=0.1, batchwise=False, fast=True)
131+
assert module1.drop_prob == 0.1
132+
assert module2.drop_prob == 0.1
133+
assert module3.drop_prob == 0.1
134+
135+
def test_unknown_args_warned(self):
136+
"""Test that unknown kwargs emit a warning."""
137+
with pytest.warns(UserWarning, match="unexpected keyword argument 'unknown_arg'"):
138+
DropBlock2d(drop_prob=0.1, unknown_arg=True)
139+
140+
def test_training_mode(self):
141+
"""Test that dropping only occurs in training mode."""
142+
module = DropBlock2d(drop_prob=0.5, block_size=3)
143+
x = torch.ones((2, 3, 8, 8))
144+
145+
# In eval mode, should return input unchanged
146+
module.eval()
147+
result = module(x)
148+
assert torch.allclose(result, x), "Should not drop in eval mode"
149+
150+
# In train mode, should modify input
151+
module.train()
152+
torch.manual_seed(42)
153+
result = module(x)
154+
assert not torch.allclose(result, x), "Should drop in train mode"
155+
156+
def test_couple_channels_parameter(self):
157+
"""Test couple_channels parameter is passed through."""
158+
x = torch.ones((2, 4, 16, 16))
159+
160+
# couple_channels=True (default)
161+
module_coupled = DropBlock2d(drop_prob=0.3, block_size=5, couple_channels=True)
162+
module_coupled.train()
163+
torch.manual_seed(42)
164+
result_coupled = module_coupled(x)
165+
166+
# All channels should have same pattern
167+
mask_c0 = (result_coupled[0, 0] == 0).float()
168+
mask_c1 = (result_coupled[0, 1] == 0).float()
169+
assert torch.allclose(mask_c0, mask_c1)
170+
171+
# couple_channels=False
172+
module_uncoupled = DropBlock2d(drop_prob=0.3, block_size=5, couple_channels=False)
173+
module_uncoupled.train()
174+
torch.manual_seed(42)
175+
result_uncoupled = module_uncoupled(x)
176+
177+
# Channels should have different patterns
178+
mask_c0 = (result_uncoupled[0, 0] == 0).float()
179+
mask_c1 = (result_uncoupled[0, 1] == 0).float()
180+
assert not torch.allclose(mask_c0, mask_c1)
181+
182+
183+
class TestDropPath:
184+
"""Test drop_path function and DropPath module."""
185+
186+
def test_no_drop_when_prob_zero(self):
187+
"""Test that no dropping occurs when drop_prob=0."""
188+
x = torch.ones((4, 8, 16, 16))
189+
result = drop_path(x, drop_prob=0.0, training=True)
190+
assert torch.allclose(result, x)
191+
192+
def test_no_drop_when_not_training(self):
193+
"""Test that no dropping occurs when not training."""
194+
x = torch.ones((4, 8, 16, 16))
195+
result = drop_path(x, drop_prob=0.5, training=False)
196+
assert torch.allclose(result, x)
197+
198+
def test_drop_path_scaling(self):
199+
"""Test that scale_by_keep properly scales kept paths."""
200+
torch.manual_seed(42)
201+
x = torch.ones((100, 8, 4, 4)) # Large batch for statistical stability
202+
keep_prob = 0.8
203+
drop_prob = 1 - keep_prob
204+
205+
result = drop_path(x, drop_prob=drop_prob, training=True, scale_by_keep=True)
206+
207+
# Kept samples should be scaled by 1/keep_prob = 1.25
208+
kept_mask = (result[:, 0, 0, 0] != 0)
209+
if kept_mask.any():
210+
kept_vals = result[kept_mask, 0, 0, 0]
211+
expected_scale = 1.0 / keep_prob
212+
assert torch.allclose(kept_vals, torch.full_like(kept_vals, expected_scale), atol=1e-5)
213+
214+
def test_drop_path_no_scaling(self):
215+
"""Test that scale_by_keep=False does not scale."""
216+
torch.manual_seed(42)
217+
x = torch.ones((100, 8, 4, 4))
218+
result = drop_path(x, drop_prob=0.2, training=True, scale_by_keep=False)
219+
220+
# Kept samples should remain at 1.0
221+
kept_mask = (result[:, 0, 0, 0] != 0)
222+
if kept_mask.any():
223+
kept_vals = result[kept_mask, 0, 0, 0]
224+
assert torch.allclose(kept_vals, torch.ones_like(kept_vals))
225+
226+
227+
class TestDropPathModule:
228+
"""Test DropPath nn.Module."""
229+
230+
def test_training_mode(self):
231+
"""Test that dropping only occurs in training mode."""
232+
module = DropPath(drop_prob=0.5)
233+
x = torch.ones((32, 8, 4, 4)) # Larger batch for statistical reliability
234+
235+
module.eval()
236+
result = module(x)
237+
assert torch.allclose(result, x), "Should not drop in eval mode"
238+
239+
module.train()
240+
torch.manual_seed(42)
241+
result = module(x)
242+
# With 50% drop prob on 32 samples, very unlikely all survive
243+
# Check that at least one sample has zeros (was dropped)
244+
has_zeros = (result == 0).any()
245+
assert has_zeros, "Should drop some paths in train mode"
246+
247+
def test_extra_repr(self):
248+
"""Test extra_repr for nice printing."""
249+
module = DropPath(drop_prob=0.123)
250+
repr_str = module.extra_repr()
251+
assert "0.123" in repr_str

0 commit comments

Comments
 (0)