
Commit 480a2b2

Remove SlotPool, was expensive and hard to work with. Tweak flatten for simpool2d and lse
1 parent 31f6651 commit 480a2b2

File tree

4 files changed, +17 -353 lines changed

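In practical terms, the pooled output shapes change: LsePlus2d now flattens by default, SimPool2d loses its flatten argument and always returns a flattened tensor, and SlotPool/SlotPool2d are removed entirely. A minimal sketch of the post-commit behaviour, inferred from the updated tests (shapes only, not an authoritative API reference):

```python
import torch
from timm.layers import LsePlus2d, SimPool2d

x = torch.randn(2, 64, 7, 7)  # NCHW feature map

# LsePlus2d: flatten now defaults to True -> (B, C)
print(LsePlus2d()(x).shape)               # torch.Size([2, 64])
print(LsePlus2d(flatten=False)(x).shape)  # torch.Size([2, 64, 1, 1])

# SimPool2d: no flatten argument any more, output is always (B, C)
print(SimPool2d(dim=64)(x).shape)         # torch.Size([2, 64])
```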

tests/test_layers_pool.py

Lines changed: 14 additions & 91 deletions
@@ -148,14 +148,15 @@ def test_lse_plus_2d_basic(self):
         x = torch.randn(2, 64, 7, 7, device=torch_device)
         pool = LsePlus2d().to(torch_device)
         out = pool(x)
-        assert out.shape == (2, 64, 1, 1)
+        # Default is flatten=True
+        assert out.shape == (2, 64)

-    def test_lse_plus_2d_flatten(self):
+    def test_lse_plus_2d_no_flatten(self):
         from timm.layers import LsePlus2d
         x = torch.randn(2, 64, 7, 7, device=torch_device)
-        pool = LsePlus2d(flatten=True).to(torch_device)
+        pool = LsePlus2d(flatten=False).to(torch_device)
         out = pool(x)
-        assert out.shape == (2, 64)
+        assert out.shape == (2, 64, 1, 1)

     def test_lse_plus_1d_basic(self):
         from timm.layers import LsePlus1d
@@ -169,15 +170,15 @@ def test_lse_high_r_approximates_max(self):
         x = torch.randn(2, 64, 7, 7, device=torch_device)
         pool = LsePlus2d(r=100.0, r_learnable=False).to(torch_device)
         out = pool(x)
-        out_max = x.amax(dim=(2, 3), keepdim=True)
+        out_max = x.amax(dim=(2, 3))
         assert torch.allclose(out, out_max, atol=0.1)

     def test_lse_low_r_approximates_avg(self):
         from timm.layers import LsePlus2d
         x = torch.randn(2, 64, 7, 7, device=torch_device)
         pool = LsePlus2d(r=0.01, r_learnable=False).to(torch_device)
         out = pool(x)
-        out_avg = x.mean(dim=(2, 3), keepdim=True)
+        out_avg = x.mean(dim=(2, 3))
         assert torch.allclose(out, out_avg, atol=0.1)

     def test_lse_learnable_r_gradient(self):
@@ -200,13 +201,6 @@ def test_simpool_2d_basic(self):
         x = torch.randn(2, 64, 7, 7, device=torch_device)
         pool = SimPool2d(dim=64).to(torch_device)
         out = pool(x)
-        assert out.shape == (2, 1, 64)
-
-    def test_simpool_2d_flatten(self):
-        from timm.layers import SimPool2d
-        x = torch.randn(2, 64, 7, 7, device=torch_device)
-        pool = SimPool2d(dim=64, flatten=True).to(torch_device)
-        out = pool(x)
         assert out.shape == (2, 64)

     def test_simpool_1d_basic(self):
@@ -220,89 +214,25 @@ def test_simpool_multi_head(self):
         from timm.layers import SimPool2d
         x = torch.randn(2, 64, 7, 7, device=torch_device)
         for num_heads in [1, 2, 4, 8]:
-            pool = SimPool2d(dim=64, num_heads=num_heads, flatten=True).to(torch_device)
+            pool = SimPool2d(dim=64, num_heads=num_heads).to(torch_device)
             out = pool(x)
             assert out.shape == (2, 64)

     def test_simpool_with_gamma(self):
         from timm.layers import SimPool2d
         x = torch.randn(2, 64, 7, 7, device=torch_device)
-        pool = SimPool2d(dim=64, gamma=2.0, flatten=True).to(torch_device)
+        pool = SimPool2d(dim=64, gamma=2.0).to(torch_device)
         out = pool(x)
         assert out.shape == (2, 64)
         assert not torch.isnan(out).any()

     def test_simpool_qk_norm(self):
         from timm.layers import SimPool2d
         x = torch.randn(2, 64, 7, 7, device=torch_device)
-        pool = SimPool2d(dim=64, qk_norm=True, flatten=True).to(torch_device)
-        out = pool(x)
-        assert out.shape == (2, 64)
-
-
-# Slot Pool Tests
-
-class TestSlotPool:
-    """Test Slot Attention pooling layers."""
-
-    def test_slot_pool_basic(self):
-        from timm.layers import SlotPool
-        x = torch.randn(2, 49, 64, device=torch_device)
-        pool = SlotPool(dim=64).to(torch_device)
-        out = pool(x)
-        assert out.shape == (2, 64)
-
-    def test_slot_pool_2d_basic(self):
-        from timm.layers import SlotPool2d
-        x = torch.randn(2, 64, 7, 7, device=torch_device)
-        pool = SlotPool2d(dim=64).to(torch_device)
+        pool = SimPool2d(dim=64, qk_norm=True).to(torch_device)
         out = pool(x)
         assert out.shape == (2, 64)

-    def test_slot_pool_multi_slot(self):
-        from timm.layers import SlotPool
-        x = torch.randn(2, 49, 64, device=torch_device)
-        for num_slots in [1, 2, 4, 8]:
-            pool = SlotPool(dim=64, num_slots=num_slots).to(torch_device)
-            out = pool(x)
-            assert out.shape == (2, 64)
-
-    def test_slot_pool_iterations(self):
-        from timm.layers import SlotPool
-        x = torch.randn(2, 49, 64, device=torch_device)
-        for iters in [1, 2, 3, 5]:
-            pool = SlotPool(dim=64, iters=iters).to(torch_device)
-            out = pool(x)
-            assert out.shape == (2, 64)
-
-    def test_slot_pool_pool_types(self):
-        from timm.layers import SlotPool
-        x = torch.randn(2, 49, 64, device=torch_device)
-        for pool_type in ['max', 'avg', 'first']:
-            pool = SlotPool(dim=64, num_slots=4, pool_type=pool_type).to(torch_device)
-            out = pool(x)
-            assert out.shape == (2, 64)
-
-    def test_slot_pool_stochastic_train_mode(self):
-        from timm.layers import SlotPool
-        x = torch.randn(2, 49, 64, device=torch_device)
-        pool = SlotPool(dim=64, stochastic_init=True).to(torch_device)
-        pool.train()
-        out1 = pool(x)
-        out2 = pool(x)
-        # Should differ in train mode with stochastic init
-        assert not torch.allclose(out1, out2)
-
-    def test_slot_pool_stochastic_eval_mode(self):
-        from timm.layers import SlotPool
-        x = torch.randn(2, 49, 64, device=torch_device)
-        pool = SlotPool(dim=64, stochastic_init=True).to(torch_device)
-        pool.eval()
-        out1 = pool(x)
-        out2 = pool(x)
-        # Should be deterministic in eval mode
-        assert torch.allclose(out1, out2)
-

 # Common Tests (Gradient, JIT, dtype)

@@ -314,8 +244,6 @@ class TestPoolingCommon:
         ('LsePlus1d', {}, (2, 49, 64)),
         ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)),
         ('SimPool1d', {'dim': 64}, (2, 49, 64)),
-        ('SlotPool', {'dim': 64}, (2, 49, 64)),
-        ('SlotPool2d', {'dim': 64}, (2, 64, 7, 7)),
         ('SelectAdaptivePool2d', {'pool_type': 'avg', 'flatten': True}, (2, 64, 7, 7)),
         ('AttentionPoolLatent', {'in_features': 64, 'num_heads': 4}, (2, 49, 64)),
         ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
@@ -336,8 +264,6 @@ def test_gradient_flow(self, pool_cls, kwargs, input_shape):
         ('LsePlus1d', {}, (2, 49, 64)),
         ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)),
         ('SimPool1d', {'dim': 64}, (2, 49, 64)),
-        ('SlotPool', {'dim': 64, 'iters': 2}, (2, 49, 64)),
-        ('SlotPool2d', {'dim': 64, 'iters': 2}, (2, 64, 7, 7)),
         ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
         ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
     ])
@@ -352,12 +278,10 @@ def test_torchscript(self, pool_cls, kwargs, input_shape):
         assert torch.allclose(out_orig, out_script, atol=1e-5)

     @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [
-        ('LsePlus2d', {'flatten': True}, (2, 64, 7, 7)),
+        ('LsePlus2d', {}, (2, 64, 7, 7)),
         ('LsePlus1d', {}, (2, 49, 64)),
-        ('SimPool2d', {'dim': 64, 'flatten': True}, (2, 64, 7, 7)),
+        ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)),
         ('SimPool1d', {'dim': 64}, (2, 49, 64)),
-        ('SlotPool', {'dim': 64}, (2, 49, 64)),
-        ('SlotPool2d', {'dim': 64}, (2, 64, 7, 7)),
         ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
         ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
     ])
@@ -372,9 +296,8 @@ def test_eval_deterministic(self, pool_cls, kwargs, input_shape):
         assert torch.allclose(out1, out2)

     @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [
-        ('LsePlus2d', {'flatten': True}, (2, 64, 7, 7)),
-        ('SimPool2d', {'dim': 64, 'flatten': True}, (2, 64, 7, 7)),
-        ('SlotPool2d', {'dim': 64}, (2, 64, 7, 7)),
+        ('LsePlus2d', {}, (2, 64, 7, 7)),
+        ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)),
         ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
     ])
     def test_different_spatial_sizes(self, pool_cls, kwargs, input_shape):
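The two r-sweep tests above lean on the standard log-sum-exp pooling identity, LSE_r(x) = (1/r) * log(mean(exp(r * x))), which approaches the spatial max as r grows large and the spatial mean as r shrinks toward zero. A self-contained sketch of that property in plain torch (not the LsePlus2d module itself, whose exact parameterization may differ):

```python
import math
import torch

def lse_pool(x: torch.Tensor, r: float) -> torch.Tensor:
    """(1/r) * log(mean(exp(r * x))) over the spatial dims of an NCHW tensor."""
    n = x.shape[2] * x.shape[3]
    return (torch.logsumexp(r * x.flatten(2), dim=-1) - math.log(n)) / r

x = torch.randn(2, 64, 7, 7)
# Large r pushes LSE toward the max, small r toward the mean (same tolerances as the tests).
assert torch.allclose(lse_pool(x, 100.0), x.amax(dim=(2, 3)), atol=0.1)
assert torch.allclose(lse_pool(x, 0.01), x.mean(dim=(2, 3)), atol=0.1)
```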

timm/layers/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -109,7 +109,6 @@
 from .pool1d import global_pool_nlc
 from .other_pool import LsePlus2d, LsePlus1d, SimPool2d, SimPool1d
 from .pool2d_same import AvgPool2dSame, create_pool2d
-from .slot_pool import SlotPool, SlotPool2d
 from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
 from .pos_embed_rel import (
     RelPosMlp,

timm/layers/other_pool.py

Lines changed: 3 additions & 7 deletions
@@ -37,7 +37,7 @@ def __init__(
             self,
             r: float = 10.0,
             r_learnable: bool = True,
-            flatten: bool = False,
+            flatten: bool = True,
             device=None,
             dtype=None,
     ):
@@ -118,7 +118,6 @@ def __init__(
             qk_norm: bool = False,
             gamma: Optional[float] = None,
             norm_layer: Optional[Type[nn.Module]] = None,
-            flatten: bool = False,
             device=None,
             dtype=None,
     ):
@@ -139,7 +138,6 @@ def __init__(
         self.head_dim = dim // num_heads
         self.scale = self.head_dim ** -0.5
         self.gamma = gamma
-        self.flatten = flatten
         self.fused_attn = use_fused_attn()

         norm_layer = norm_layer or nn.LayerNorm
@@ -192,10 +190,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         attn = attn.softmax(dim=-1)
         out = attn @ v

-        # (B, num_heads, 1, head_dim) -> (B, C) or (B, 1, C)
-        out = out.transpose(1, 2).reshape(B, 1, C)
-        if self.flatten:
-            out = out.squeeze(1)
+        # (B, num_heads, 1, head_dim) -> (B, C)
+        out = out.transpose(1, 2).reshape(B, C)
         return out

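The forward change in other_pool.py is just the reshape: with a single pooled query token per head, the (B, num_heads, 1, head_dim) attention output can go straight to (B, C), so the flatten flag had nothing left to do. A shape-tracing sketch of that step (standalone tensors, hypothetical sizes):

```python
import torch

B, num_heads, head_dim = 2, 4, 16
C = num_heads * head_dim

out = torch.randn(B, num_heads, 1, head_dim)  # one pooled token per head

# (B, num_heads, 1, head_dim) -> (B, 1, num_heads, head_dim) -> (B, C)
out = out.transpose(1, 2).reshape(B, C)
print(out.shape)  # torch.Size([2, 64])
```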