@@ -115,6 +115,8 @@ def __init__(
             dilation: int = 1,
             first_dilation: Optional[int] = None,
             conv_layer: Callable = ScaledStdConv2d,
+            device=None,
+            dtype=None,
     ):
         """Initialize DownsampleAvg.
 
@@ -133,7 +135,7 @@ def __init__(
             self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
         else:
             self.pool = nn.Identity()
-        self.conv = conv_layer(in_chs, out_chs, 1, stride=1)
+        self.conv = conv_layer(in_chs, out_chs, 1, stride=1, device=device, dtype=dtype)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass.
@@ -172,6 +174,8 @@ def __init__(
             act_layer: Optional[Callable] = None,
             conv_layer: Callable = ScaledStdConv2d,
             drop_path_rate: float = 0.,
+            device=None,
+            dtype=None,
     ):
         """Initialize NormFreeBlock.
 
@@ -195,6 +199,7 @@ def __init__(
             conv_layer: Convolution layer type.
             drop_path_rate: Stochastic depth drop rate.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         first_dilation = first_dilation or dilation
         out_chs = out_chs or in_chs
@@ -215,32 +220,33 @@ def __init__(
                 dilation=dilation,
                 first_dilation=first_dilation,
                 conv_layer=conv_layer,
+                **dd,
             )
         else:
             self.downsample = None
 
         self.act1 = act_layer()
-        self.conv1 = conv_layer(in_chs, mid_chs, 1)
+        self.conv1 = conv_layer(in_chs, mid_chs, 1, **dd)
         self.act2 = act_layer(inplace=True)
-        self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
+        self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups, **dd)
         if extra_conv:
             self.act2b = act_layer(inplace=True)
-            self.conv2b = conv_layer(mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups)
+            self.conv2b = conv_layer(mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups, **dd)
         else:
             self.act2b = None
             self.conv2b = None
         if reg and attn_layer is not None:
-            self.attn = attn_layer(mid_chs)  # RegNet blocks apply attn btw conv2 & 3
+            self.attn = attn_layer(mid_chs, **dd)  # RegNet blocks apply attn btw conv2 & 3
         else:
             self.attn = None
         self.act3 = act_layer()
-        self.conv3 = conv_layer(mid_chs, out_chs, 1, gain_init=1. if skipinit else 0.)
+        self.conv3 = conv_layer(mid_chs, out_chs, 1, gain_init=1. if skipinit else 0., **dd)
         if not reg and attn_layer is not None:
-            self.attn_last = attn_layer(out_chs)  # ResNet blocks apply attn after conv3
+            self.attn_last = attn_layer(out_chs, **dd)  # ResNet blocks apply attn after conv3
         else:
             self.attn_last = None
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
-        self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None
+        self.skipinit_gain = nn.Parameter(torch.tensor(0., **dd)) if skipinit else None
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass.
@@ -283,6 +289,8 @@ def create_stem(
         conv_layer: Optional[Callable] = None,
         act_layer: Optional[Callable] = None,
         preact_feature: bool = True,
+        device=None,
+        dtype=None,
 ) -> Tuple[nn.Sequential, int, Dict[str, Any]]:
     """Create stem module for NFNet models.
 
@@ -297,6 +305,7 @@ def create_stem(
     Returns:
         Tuple of (stem_module, stem_stride, stem_feature_info).
     """
+    dd = {'device': device, 'dtype': dtype}
    stem_stride = 2
    stem_feature = dict(num_chs=out_chs, reduction=2, module='stem.conv')
    stem = OrderedDict()
@@ -318,16 +327,16 @@ def create_stem(
             stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv2')
         last_idx = len(stem_chs) - 1
         for i, (c, s) in enumerate(zip(stem_chs, strides)):
-            stem[f'conv{i + 1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s)
+            stem[f'conv{i + 1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s, **dd)
             if i != last_idx:
                 stem[f'act{i + 2}'] = act_layer(inplace=True)
             in_chs = c
     elif '3x3' in stem_type:
         # 3x3 stem conv as in RegNet
-        stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2)
+        stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2, **dd)
     else:
         # 7x7 stem conv as in ResNet
-        stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
+        stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2, **dd)
 
     if 'pool' in stem_type:
         stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1)
@@ -387,6 +396,8 @@ def __init__(
             output_stride: int = 32,
             drop_rate: float = 0.,
             drop_path_rate: float = 0.,
+            device=None,
+            dtype=None,
             **kwargs: Any,
     ):
         """
@@ -401,6 +412,7 @@ def __init__(
             **kwargs: Extra kwargs overlayed onto cfg.
         """
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.grad_checkpointing = False
@@ -423,6 +435,7 @@ def __init__(
             cfg.stem_type,
             conv_layer=conv_layer,
             act_layer=act_layer,
+            **dd,
         )
 
         self.feature_info = [stem_feat]
@@ -462,6 +475,7 @@ def __init__(
                 act_layer=act_layer,
                 conv_layer=conv_layer,
                 drop_path_rate=drop_path_rates[stage_idx][block_idx],
+                **dd,
             )]
             if block_idx == 0:
                 expected_var = 1.  # expected var is reset after first block of each stage
@@ -475,7 +489,7 @@ def __init__(
         if cfg.num_features:
             # The paper NFRegNet models have an EfficientNet-like final head convolution.
             self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
-            self.final_conv = conv_layer(prev_chs, self.num_features, 1)
+            self.final_conv = conv_layer(prev_chs, self.num_features, 1, **dd)
             self.feature_info[-1] = dict(num_chs=self.num_features, reduction=net_stride, module=f'final_conv')
         else:
             self.num_features = prev_chs
@@ -488,6 +502,7 @@ def __init__(
             num_classes,
             pool_type=global_pool,
             drop_rate=self.drop_rate,
+            **dd,
         )
 
         for n, m in self.named_modules():
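Taken together, these hunks thread PyTorch's standard `device`/`dtype` factory kwargs through every parameter-creating layer, bundled once as `dd = {'device': device, 'dtype': dtype}` and splatted into each child constructor. Below is a minimal sketch of the same pattern; `TinyBlock` is a hypothetical module for illustration, not part of timm. The payoff is that parameters materialize directly on the target device, e.g. the `meta` device for deferred or sharded initialization, instead of being allocated on CPU and moved afterwards.

```python
import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Hypothetical module showing the device/dtype threading pattern."""

    def __init__(self, in_chs: int, out_chs: int, device=None, dtype=None):
        # Bundle the factory kwargs once, pass to every parameter-creating child.
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.conv = nn.Conv2d(in_chs, out_chs, 1, **dd)
        self.gain = nn.Parameter(torch.tensor(0., **dd))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x) * self.gain


# Parameters are created directly on the requested device; with 'meta' no
# real storage is allocated, which is useful for lazy or sharded init.
block = TinyBlock(8, 16, device='meta')
print(block.conv.weight.device)  # meta
```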