
Commit ee751ef

Add dd factory kwargs to nfnet and resnetv2
1 parent b94c221

File tree

timm/models/nfnet.py
timm/models/resnetv2.py

2 files changed: +83 −41 lines changed
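For context: the "dd factory kwargs" in the commit title are PyTorch's standard device/dtype factory keyword arguments. Each constructor collects them once into a small dict (dd = {'device': device, 'dtype': dtype}) and forwards them to every parameter-creating call, so weights are materialized directly on the target device and in the target dtype instead of being allocated on CPU in float32 and moved afterwards. A minimal standalone sketch of the pattern (TinyBlock and its arguments are illustrative, not part of timm):

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    """Illustrative only: mirrors the dd-dict pattern this commit applies."""

    def __init__(self, in_chs: int, out_chs: int, device=None, dtype=None):
        # Collect the factory kwargs once, then thread them through every
        # submodule and tensor that allocates parameters.
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.conv = nn.Conv2d(in_chs, out_chs, 1, **dd)
        self.gain = nn.Parameter(torch.tensor(0., **dd))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x) * self.gain

Passing device='meta' builds shape-only parameters for deferred initialization; passing dtype=torch.bfloat16 skips the usual float32-then-cast round trip.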

timm/models/nfnet.py

Lines changed: 27 additions & 12 deletions
@@ -115,6 +115,8 @@ def __init__(
             dilation: int = 1,
             first_dilation: Optional[int] = None,
             conv_layer: Callable = ScaledStdConv2d,
+            device=None,
+            dtype=None,
     ):
         """Initialize DownsampleAvg.

@@ -133,7 +135,7 @@ def __init__(
             self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
         else:
             self.pool = nn.Identity()
-        self.conv = conv_layer(in_chs, out_chs, 1, stride=1)
+        self.conv = conv_layer(in_chs, out_chs, 1, stride=1, device=device, dtype=dtype)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass.
@@ -172,6 +174,8 @@ def __init__(
             act_layer: Optional[Callable] = None,
             conv_layer: Callable = ScaledStdConv2d,
             drop_path_rate: float = 0.,
+            device=None,
+            dtype=None,
     ):
         """Initialize NormFreeBlock.

@@ -195,6 +199,7 @@ def __init__(
             conv_layer: Convolution layer type.
             drop_path_rate: Stochastic depth drop rate.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         first_dilation = first_dilation or dilation
         out_chs = out_chs or in_chs
@@ -215,32 +220,33 @@ def __init__(
                 dilation=dilation,
                 first_dilation=first_dilation,
                 conv_layer=conv_layer,
+                **dd,
             )
         else:
             self.downsample = None

         self.act1 = act_layer()
-        self.conv1 = conv_layer(in_chs, mid_chs, 1)
+        self.conv1 = conv_layer(in_chs, mid_chs, 1, **dd)
         self.act2 = act_layer(inplace=True)
-        self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
+        self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups, **dd)
         if extra_conv:
             self.act2b = act_layer(inplace=True)
-            self.conv2b = conv_layer(mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups)
+            self.conv2b = conv_layer(mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups, **dd)
         else:
             self.act2b = None
             self.conv2b = None
         if reg and attn_layer is not None:
-            self.attn = attn_layer(mid_chs)  # RegNet blocks apply attn btw conv2 & 3
+            self.attn = attn_layer(mid_chs, **dd)  # RegNet blocks apply attn btw conv2 & 3
         else:
             self.attn = None
         self.act3 = act_layer()
-        self.conv3 = conv_layer(mid_chs, out_chs, 1, gain_init=1. if skipinit else 0.)
+        self.conv3 = conv_layer(mid_chs, out_chs, 1, gain_init=1. if skipinit else 0., **dd)
         if not reg and attn_layer is not None:
-            self.attn_last = attn_layer(out_chs)  # ResNet blocks apply attn after conv3
+            self.attn_last = attn_layer(out_chs, **dd)  # ResNet blocks apply attn after conv3
         else:
             self.attn_last = None
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
-        self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None
+        self.skipinit_gain = nn.Parameter(torch.tensor(0., **dd)) if skipinit else None

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass.
@@ -283,6 +289,8 @@ def create_stem(
         conv_layer: Optional[Callable] = None,
         act_layer: Optional[Callable] = None,
         preact_feature: bool = True,
+        device=None,
+        dtype=None,
 ) -> Tuple[nn.Sequential, int, Dict[str, Any]]:
     """Create stem module for NFNet models.

@@ -297,6 +305,7 @@ def create_stem(
     Returns:
         Tuple of (stem_module, stem_stride, stem_feature_info).
     """
+    dd = {'device': device, 'dtype': dtype}
     stem_stride = 2
     stem_feature = dict(num_chs=out_chs, reduction=2, module='stem.conv')
     stem = OrderedDict()
@@ -318,16 +327,16 @@ def create_stem(
             stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv2')
         last_idx = len(stem_chs) - 1
         for i, (c, s) in enumerate(zip(stem_chs, strides)):
-            stem[f'conv{i + 1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s)
+            stem[f'conv{i + 1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s, **dd)
             if i != last_idx:
                 stem[f'act{i + 2}'] = act_layer(inplace=True)
             in_chs = c
     elif '3x3' in stem_type:
         # 3x3 stem conv as in RegNet
-        stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2)
+        stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2, **dd)
     else:
         # 7x7 stem conv as in ResNet
-        stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
+        stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2, **dd)

     if 'pool' in stem_type:
         stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1)
@@ -387,6 +396,8 @@ def __init__(
             output_stride: int = 32,
             drop_rate: float = 0.,
             drop_path_rate: float = 0.,
+            device=None,
+            dtype=None,
             **kwargs: Any,
     ):
         """
@@ -401,6 +412,7 @@ def __init__(
             **kwargs: Extra kwargs overlayed onto cfg.
         """
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.grad_checkpointing = False
@@ -423,6 +435,7 @@ def __init__(
             cfg.stem_type,
             conv_layer=conv_layer,
             act_layer=act_layer,
+            **dd,
         )

         self.feature_info = [stem_feat]
@@ -462,6 +475,7 @@ def __init__(
                     act_layer=act_layer,
                     conv_layer=conv_layer,
                     drop_path_rate=drop_path_rates[stage_idx][block_idx],
+                    **dd,
                 )]
                 if block_idx == 0:
                     expected_var = 1.  # expected var is reset after first block of each stage
@@ -475,7 +489,7 @@ def __init__(
         if cfg.num_features:
             # The paper NFRegNet models have an EfficientNet-like final head convolution.
             self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
-            self.final_conv = conv_layer(prev_chs, self.num_features, 1)
+            self.final_conv = conv_layer(prev_chs, self.num_features, 1, **dd)
             self.feature_info[-1] = dict(num_chs=self.num_features, reduction=net_stride, module=f'final_conv')
         else:
             self.num_features = prev_chs
@@ -488,6 +502,7 @@ def __init__(
             num_classes,
             pool_type=global_pool,
             drop_rate=self.drop_rate,
+            **dd,
         )

         for n, m in self.named_modules():
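With the kwargs threaded through stem, blocks, final conv, and head, callers can request placement at construction time. An end-to-end sketch, assuming timm.create_model forwards extra kwargs through to the NormFreeNet constructor (its usual pass-through behavior) and that the resnetv2.py half of this commit mirrors the nfnet.py changes:

import torch
import timm

# Parameters are created directly in bfloat16, not float32-then-cast.
model = timm.create_model('dm_nfnet_f0', pretrained=False, dtype=torch.bfloat16)
print(next(model.parameters()).dtype)  # torch.bfloat16

# device='meta' allocates shape-only parameters: cheap to build, useful for
# inspecting the architecture or for deferred/sharded initialization.
meta_model = timm.create_model('dm_nfnet_f0', pretrained=False, device='meta')
print(next(meta_model.parameters()).is_meta)  # True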
