@@ -148,14 +148,15 @@ def test_lse_plus_2d_basic(self):
148148 x = torch .randn (2 , 64 , 7 , 7 , device = torch_device )
149149 pool = LsePlus2d ().to (torch_device )
150150 out = pool (x )
151- assert out .shape == (2 , 64 , 1 , 1 )
151+ # Default is flatten=True
152+ assert out .shape == (2 , 64 )
152153
153- def test_lse_plus_2d_flatten (self ):
154+ def test_lse_plus_2d_no_flatten (self ):
154155 from timm .layers import LsePlus2d
155156 x = torch .randn (2 , 64 , 7 , 7 , device = torch_device )
156- pool = LsePlus2d (flatten = True ).to (torch_device )
157+ pool = LsePlus2d (flatten = False ).to (torch_device )
157158 out = pool (x )
158- assert out .shape == (2 , 64 )
159+ assert out .shape == (2 , 64 , 1 , 1 )
159160
160161 def test_lse_plus_1d_basic (self ):
161162 from timm .layers import LsePlus1d
@@ -169,15 +170,15 @@ def test_lse_high_r_approximates_max(self):
169170 x = torch .randn (2 , 64 , 7 , 7 , device = torch_device )
170171 pool = LsePlus2d (r = 100.0 , r_learnable = False ).to (torch_device )
171172 out = pool (x )
172- out_max = x .amax (dim = (2 , 3 ), keepdim = True )
173+ out_max = x .amax (dim = (2 , 3 ))
173174 assert torch .allclose (out , out_max , atol = 0.1 )
174175
175176 def test_lse_low_r_approximates_avg (self ):
176177 from timm .layers import LsePlus2d
177178 x = torch .randn (2 , 64 , 7 , 7 , device = torch_device )
178179 pool = LsePlus2d (r = 0.01 , r_learnable = False ).to (torch_device )
179180 out = pool (x )
180- out_avg = x .mean (dim = (2 , 3 ), keepdim = True )
181+ out_avg = x .mean (dim = (2 , 3 ))
181182 assert torch .allclose (out , out_avg , atol = 0.1 )
182183
183184 def test_lse_learnable_r_gradient (self ):
@@ -200,13 +201,6 @@ def test_simpool_2d_basic(self):
200201 x = torch .randn (2 , 64 , 7 , 7 , device = torch_device )
201202 pool = SimPool2d (dim = 64 ).to (torch_device )
202203 out = pool (x )
203- assert out .shape == (2 , 1 , 64 )
204-
205- def test_simpool_2d_flatten (self ):
206- from timm .layers import SimPool2d
207- x = torch .randn (2 , 64 , 7 , 7 , device = torch_device )
208- pool = SimPool2d (dim = 64 , flatten = True ).to (torch_device )
209- out = pool (x )
210204 assert out .shape == (2 , 64 )
211205
212206 def test_simpool_1d_basic (self ):
@@ -220,89 +214,25 @@ def test_simpool_multi_head(self):
220214 from timm .layers import SimPool2d
221215 x = torch .randn (2 , 64 , 7 , 7 , device = torch_device )
222216 for num_heads in [1 , 2 , 4 , 8 ]:
223- pool = SimPool2d (dim = 64 , num_heads = num_heads , flatten = True ).to (torch_device )
217+ pool = SimPool2d (dim = 64 , num_heads = num_heads ).to (torch_device )
224218 out = pool (x )
225219 assert out .shape == (2 , 64 )
226220
227221 def test_simpool_with_gamma (self ):
228222 from timm .layers import SimPool2d
229223 x = torch .randn (2 , 64 , 7 , 7 , device = torch_device )
230- pool = SimPool2d (dim = 64 , gamma = 2.0 , flatten = True ).to (torch_device )
224+ pool = SimPool2d (dim = 64 , gamma = 2.0 ).to (torch_device )
231225 out = pool (x )
232226 assert out .shape == (2 , 64 )
233227 assert not torch .isnan (out ).any ()
234228
235229 def test_simpool_qk_norm (self ):
236230 from timm .layers import SimPool2d
237231 x = torch .randn (2 , 64 , 7 , 7 , device = torch_device )
238- pool = SimPool2d (dim = 64 , qk_norm = True , flatten = True ).to (torch_device )
239- out = pool (x )
240- assert out .shape == (2 , 64 )
241-
242-
243- # Slot Pool Tests
244-
245- class TestSlotPool :
246- """Test Slot Attention pooling layers."""
247-
248- def test_slot_pool_basic (self ):
249- from timm .layers import SlotPool
250- x = torch .randn (2 , 49 , 64 , device = torch_device )
251- pool = SlotPool (dim = 64 ).to (torch_device )
252- out = pool (x )
253- assert out .shape == (2 , 64 )
254-
255- def test_slot_pool_2d_basic (self ):
256- from timm .layers import SlotPool2d
257- x = torch .randn (2 , 64 , 7 , 7 , device = torch_device )
258- pool = SlotPool2d (dim = 64 ).to (torch_device )
232+ pool = SimPool2d (dim = 64 , qk_norm = True ).to (torch_device )
259233 out = pool (x )
260234 assert out .shape == (2 , 64 )
261235
262- def test_slot_pool_multi_slot (self ):
263- from timm .layers import SlotPool
264- x = torch .randn (2 , 49 , 64 , device = torch_device )
265- for num_slots in [1 , 2 , 4 , 8 ]:
266- pool = SlotPool (dim = 64 , num_slots = num_slots ).to (torch_device )
267- out = pool (x )
268- assert out .shape == (2 , 64 )
269-
270- def test_slot_pool_iterations (self ):
271- from timm .layers import SlotPool
272- x = torch .randn (2 , 49 , 64 , device = torch_device )
273- for iters in [1 , 2 , 3 , 5 ]:
274- pool = SlotPool (dim = 64 , iters = iters ).to (torch_device )
275- out = pool (x )
276- assert out .shape == (2 , 64 )
277-
278- def test_slot_pool_pool_types (self ):
279- from timm .layers import SlotPool
280- x = torch .randn (2 , 49 , 64 , device = torch_device )
281- for pool_type in ['max' , 'avg' , 'first' ]:
282- pool = SlotPool (dim = 64 , num_slots = 4 , pool_type = pool_type ).to (torch_device )
283- out = pool (x )
284- assert out .shape == (2 , 64 )
285-
286- def test_slot_pool_stochastic_train_mode (self ):
287- from timm .layers import SlotPool
288- x = torch .randn (2 , 49 , 64 , device = torch_device )
289- pool = SlotPool (dim = 64 , stochastic_init = True ).to (torch_device )
290- pool .train ()
291- out1 = pool (x )
292- out2 = pool (x )
293- # Should differ in train mode with stochastic init
294- assert not torch .allclose (out1 , out2 )
295-
296- def test_slot_pool_stochastic_eval_mode (self ):
297- from timm .layers import SlotPool
298- x = torch .randn (2 , 49 , 64 , device = torch_device )
299- pool = SlotPool (dim = 64 , stochastic_init = True ).to (torch_device )
300- pool .eval ()
301- out1 = pool (x )
302- out2 = pool (x )
303- # Should be deterministic in eval mode
304- assert torch .allclose (out1 , out2 )
305-
306236
307237# Common Tests (Gradient, JIT, dtype)
308238
@@ -314,8 +244,6 @@ class TestPoolingCommon:
314244 ('LsePlus1d' , {}, (2 , 49 , 64 )),
315245 ('SimPool2d' , {'dim' : 64 }, (2 , 64 , 7 , 7 )),
316246 ('SimPool1d' , {'dim' : 64 }, (2 , 49 , 64 )),
317- ('SlotPool' , {'dim' : 64 }, (2 , 49 , 64 )),
318- ('SlotPool2d' , {'dim' : 64 }, (2 , 64 , 7 , 7 )),
319247 ('SelectAdaptivePool2d' , {'pool_type' : 'avg' , 'flatten' : True }, (2 , 64 , 7 , 7 )),
320248 ('AttentionPoolLatent' , {'in_features' : 64 , 'num_heads' : 4 }, (2 , 49 , 64 )),
321249 ('AttentionPool2d' , {'in_features' : 64 , 'feat_size' : 7 }, (2 , 64 , 7 , 7 )),
@@ -336,8 +264,6 @@ def test_gradient_flow(self, pool_cls, kwargs, input_shape):
336264 ('LsePlus1d' , {}, (2 , 49 , 64 )),
337265 ('SimPool2d' , {'dim' : 64 }, (2 , 64 , 7 , 7 )),
338266 ('SimPool1d' , {'dim' : 64 }, (2 , 49 , 64 )),
339- ('SlotPool' , {'dim' : 64 , 'iters' : 2 }, (2 , 49 , 64 )),
340- ('SlotPool2d' , {'dim' : 64 , 'iters' : 2 }, (2 , 64 , 7 , 7 )),
341267 ('AttentionPool2d' , {'in_features' : 64 , 'feat_size' : 7 }, (2 , 64 , 7 , 7 )),
342268 ('RotAttentionPool2d' , {'in_features' : 64 , 'ref_feat_size' : 7 }, (2 , 64 , 7 , 7 )),
343269 ])
@@ -352,12 +278,10 @@ def test_torchscript(self, pool_cls, kwargs, input_shape):
352278 assert torch .allclose (out_orig , out_script , atol = 1e-5 )
353279
354280 @pytest .mark .parametrize ('pool_cls,kwargs,input_shape' , [
355- ('LsePlus2d' , {'flatten' : True }, (2 , 64 , 7 , 7 )),
281+ ('LsePlus2d' , {}, (2 , 64 , 7 , 7 )),
356282 ('LsePlus1d' , {}, (2 , 49 , 64 )),
357- ('SimPool2d' , {'dim' : 64 , 'flatten' : True }, (2 , 64 , 7 , 7 )),
283+ ('SimPool2d' , {'dim' : 64 }, (2 , 64 , 7 , 7 )),
358284 ('SimPool1d' , {'dim' : 64 }, (2 , 49 , 64 )),
359- ('SlotPool' , {'dim' : 64 }, (2 , 49 , 64 )),
360- ('SlotPool2d' , {'dim' : 64 }, (2 , 64 , 7 , 7 )),
361285 ('AttentionPool2d' , {'in_features' : 64 , 'feat_size' : 7 }, (2 , 64 , 7 , 7 )),
362286 ('RotAttentionPool2d' , {'in_features' : 64 , 'ref_feat_size' : 7 }, (2 , 64 , 7 , 7 )),
363287 ])
@@ -372,9 +296,8 @@ def test_eval_deterministic(self, pool_cls, kwargs, input_shape):
372296 assert torch .allclose (out1 , out2 )
373297
374298 @pytest .mark .parametrize ('pool_cls,kwargs,input_shape' , [
375- ('LsePlus2d' , {'flatten' : True }, (2 , 64 , 7 , 7 )),
376- ('SimPool2d' , {'dim' : 64 , 'flatten' : True }, (2 , 64 , 7 , 7 )),
377- ('SlotPool2d' , {'dim' : 64 }, (2 , 64 , 7 , 7 )),
299+ ('LsePlus2d' , {}, (2 , 64 , 7 , 7 )),
300+ ('SimPool2d' , {'dim' : 64 }, (2 , 64 , 7 , 7 )),
378301 ('RotAttentionPool2d' , {'in_features' : 64 , 'ref_feat_size' : 7 }, (2 , 64 , 7 , 7 )),
379302 ])
380303 def test_different_spatial_sizes (self , pool_cls , kwargs , input_shape ):
0 commit comments