
 try:
     import safetensors.torch
+
     _has_safetensors = True
 except ImportError:
     _has_safetensors = False
 try:
     from huggingface_hub import HfApi, hf_hub_download, model_info
     from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
+
     hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
     _has_hf_hub = True
 except ImportError:
     hf_hub_download = None
     _has_hf_hub = False

 _logger = logging.getLogger(__name__)

-__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
-           'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
+__all__ = [
+    'get_cache_dir',
+    'download_cached_file',
+    'has_hf_hub',
+    'hf_split',
+    'load_model_config_from_hf',
+    'load_state_dict_from_hf',
+    'save_for_hf',
+    'push_to_hf_hub',
+]

 # Default name for a weights file hosted on the Huggingface Hub.
 HF_WEIGHTS_NAME = "pytorch_model.bin"  # default pytorch pkl
@@ -66,10 +76,10 @@ def get_cache_dir(child_dir: str = ''):


 def download_cached_file(
-        url: Union[str, List[str], Tuple[str, str]],
-        check_hash: bool = True,
-        progress: bool = False,
-        cache_dir: Optional[Union[str, Path]] = None,
+    url: Union[str, List[str], Tuple[str, str]],
+    check_hash: bool = True,
+    progress: bool = False,
+    cache_dir: Optional[Union[str, Path]] = None,
 ):
     if isinstance(url, (list, tuple)):
         url, filename = url
@@ -92,9 +102,9 @@ def download_cached_file(


 def check_cached_file(
-        url: Union[str, List[str], Tuple[str, str]],
-        check_hash: bool = True,
-        cache_dir: Optional[Union[str, Path]] = None,
+    url: Union[str, List[str], Tuple[str, str]],
+    check_hash: bool = True,
+    cache_dir: Optional[Union[str, Path]] = None,
 ):
     if isinstance(url, (list, tuple)):
         url, filename = url
@@ -111,7 +121,7 @@ def check_cached_file(
             if hash_prefix:
                 with open(cached_file, 'rb') as f:
                     hd = hashlib.sha256(f.read()).hexdigest()
-                if hd[:len(hash_prefix)] != hash_prefix:
+                if hd[: len(hash_prefix)] != hash_prefix:
                     return False
         return True
     return False
@@ -121,7 +131,8 @@ def has_hf_hub(necessary: bool = False):
     if not _has_hf_hub and necessary:
         # if no HF Hub module installed, and it is necessary to continue, raise error
         raise RuntimeError(
-            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
+            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.'
+        )
     return _has_hf_hub


@@ -141,9 +152,9 @@ def load_cfg_from_json(json_file: Union[str, Path]):


 def download_from_hf(
-        model_id: str,
-        filename: str,
-        cache_dir: Optional[Union[str, Path]] = None,
+    model_id: str,
+    filename: str,
+    cache_dir: Optional[Union[str, Path]] = None,
 ):
     hf_model_id, hf_revision = hf_split(model_id)
     return hf_hub_download(
@@ -155,8 +166,8 @@ def download_from_hf(


 def _parse_model_cfg(
-        cfg: Dict[str, Any],
-        extra_fields: Dict[str, Any],
+    cfg: Dict[str, Any],
+    extra_fields: Dict[str, Any],
 ) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
     """"""
     # legacy "single-dict" -> split
@@ -167,7 +178,7 @@ def _parse_model_cfg(
             "num_features": pretrained_cfg.pop("num_features", None),
             "pretrained_cfg": pretrained_cfg,
         }
-        if "labels" in pretrained_cfg: # rename --> label_names
+        if "labels" in pretrained_cfg:  # rename --> label_names
             pretrained_cfg["label_names"] = pretrained_cfg.pop("labels")

     pretrained_cfg = cfg["pretrained_cfg"]
@@ -187,8 +198,8 @@ def _parse_model_cfg(


 def load_model_config_from_hf(
-        model_id: str,
-        cache_dir: Optional[Union[str, Path]] = None,
+    model_id: str,
+    cache_dir: Optional[Union[str, Path]] = None,
 ):
     """Original HF-Hub loader (unchanged download, shared parsing)."""
     assert has_hf_hub(True)
@@ -198,7 +209,7 @@ def load_model_config_from_hf(


 def load_model_config_from_path(
-        model_path: Union[str, Path],
+    model_path: Union[str, Path],
 ):
     """Load from ``<model_path>/config.json`` on the local filesystem."""
     model_path = Path(model_path)
@@ -211,10 +222,10 @@ def load_model_config_from_path(


 def load_state_dict_from_hf(
-        model_id: str,
-        filename: str = HF_WEIGHTS_NAME,
-        weights_only: bool = False,
-        cache_dir: Optional[Union[str, Path]] = None,
+    model_id: str,
+    filename: str = HF_WEIGHTS_NAME,
+    weights_only: bool = False,
+    cache_dir: Optional[Union[str, Path]] = None,
 ):
     assert has_hf_hub(True)
     hf_model_id, hf_revision = hf_split(model_id)
@@ -231,7 +242,8 @@ def load_state_dict_from_hf(
                 )
                 _logger.info(
                     f"[{model_id}] Safe alternative available for '{filename}' "
-                    f"(as '{safe_filename}'). Loading weights using safetensors.")
+                    f"(as '{safe_filename}'). Loading weights using safetensors."
+                )
                 return safetensors.torch.load_file(cached_safe_file, device="cpu")
             except EntryNotFoundError:
                 pass
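The hunk above keeps the safetensors-first loading behaviour: a .safetensors sibling is tried before the pickle checkpoint. A rough standalone sketch of that pattern (illustrative only, not part of this commit; the fixed "model.safetensors" name and the helper name below are assumptions):

# Sketch of the safetensors-first fallback; names here are illustrative.
import safetensors.torch
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

def load_weights_preferring_safetensors(repo_id: str, filename: str = "pytorch_model.bin"):
    try:
        # Try a safetensors file first (hypothetical fixed name for brevity).
        safe_file = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
        return safetensors.torch.load_file(safe_file, device="cpu")
    except EntryNotFoundError:
        # Fall back to the pickle checkpoint named by `filename`.
        cached_file = hf_hub_download(repo_id=repo_id, filename=filename)
        return torch.load(cached_file, map_location="cpu", weights_only=False)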
@@ -263,9 +275,10 @@ def load_state_dict_from_hf(
 )
 _EXT_PRIORITY = ('.safetensors', '.pth', '.pth.tar', '.bin')

+
 def load_state_dict_from_path(
-        path: str,
-        weights_only: bool = False,
+    path: str,
+    weights_only: bool = False,
 ):
     found_file = None
     for fname in _PREFERRED_FILES:
@@ -280,10 +293,7 @@ def load_state_dict_from_path(
             files = sorted(path.glob(f"*{ext}"))
             if files:
                 if len(files) > 1:
-                    logging.warning(
-                        f"Multiple {ext} checkpoints in {path}: {names}. "
-                        f"Using '{files[0].name}'."
-                    )
+                    logging.warning(f"Multiple {ext} checkpoints in {path}: {names}. " f"Using '{files[0].name}'.")
                 found_file = files[0]

     if not found_file:
@@ -297,10 +307,10 @@ def load_state_dict_from_path(


 def load_custom_from_hf(
-        model_id: str,
-        filename: str,
-        model: torch.nn.Module,
-        cache_dir: Optional[Union[str, Path]] = None,
+    model_id: str,
+    filename: str,
+    model: torch.nn.Module,
+    cache_dir: Optional[Union[str, Path]] = None,
 ):
     assert has_hf_hub(True)
     hf_model_id, hf_revision = hf_split(model_id)
@@ -314,10 +324,7 @@ def load_custom_from_hf(


 def save_config_for_hf(
-        model: torch.nn.Module,
-        config_path: str,
-        model_config: Optional[dict] = None,
-        model_args: Optional[dict] = None
+    model: torch.nn.Module, config_path: str, model_config: Optional[dict] = None, model_args: Optional[dict] = None
 ):
     model_config = model_config or {}
     hf_config = {}
@@ -336,7 +343,8 @@ def save_config_for_hf(
     if 'labels' in model_config:
         _logger.warning(
             "'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
-            " Renaming provided 'labels' field to 'label_names'.")
+            " Renaming provided 'labels' field to 'label_names'."
+        )
         model_config.setdefault('label_names', model_config.pop('labels'))

     label_names = model_config.pop('label_names', None)
@@ -363,11 +371,11 @@ def save_config_for_hf(


 def save_for_hf(
-        model: torch.nn.Module,
-        save_directory: str,
-        model_config: Optional[dict] = None,
-        model_args: Optional[dict] = None,
-        safe_serialization: Union[bool, Literal["both"]] = False,
+    model: torch.nn.Module,
+    save_directory: str,
+    model_config: Optional[dict] = None,
+    model_args: Optional[dict] = None,
+    safe_serialization: Union[bool, Literal["both"]] = False,
 ):
     assert has_hf_hub(True)
     save_directory = Path(save_directory)
@@ -391,18 +399,18 @@ def save_for_hf(


 def push_to_hf_hub(
-        model: torch.nn.Module,
-        repo_id: str,
-        commit_message: str = 'Add model',
-        token: Optional[str] = None,
-        revision: Optional[str] = None,
-        private: bool = False,
-        create_pr: bool = False,
-        model_config: Optional[dict] = None,
-        model_card: Optional[dict] = None,
-        model_args: Optional[dict] = None,
-        task_name: str = 'image-classification',
-        safe_serialization: Union[bool, Literal["both"]] = 'both',
+    model: torch.nn.Module,
+    repo_id: str,
+    commit_message: str = 'Add model',
+    token: Optional[str] = None,
+    revision: Optional[str] = None,
+    private: bool = False,
+    create_pr: bool = False,
+    model_config: Optional[dict] = None,
+    model_card: Optional[dict] = None,
+    model_args: Optional[dict] = None,
+    task_name: str = 'image-classification',
+    safe_serialization: Union[bool, Literal["both"]] = 'both',
 ):
     """
     Arguments:
@@ -452,9 +460,9 @@ def push_to_hf_hub(


 def generate_readme(
-        model_card: dict,
-        model_name: str,
-        task_name: str = 'image-classification',
+    model_card: dict,
+    model_name: str,
+    task_name: str = 'image-classification',
 ):
     tags = model_card.get('tags', None) or [task_name, 'timm', 'transformers']
     readme_text = "---\n"
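Illustrative usage of the reformatted helpers (not part of the commit; the module path and Hub repo ids below are assumptions):

# Sketch only: assumes the diffed file is timm's hub helper module (timm.models._hub)
# and that the example repo ids exist on the Hugging Face Hub.
import timm
from timm.models._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub

# Config parsing is shared with the local-path loader via _parse_model_cfg and
# returns (pretrained_cfg, model_name, model_args).
pretrained_cfg, model_name, model_args = load_model_config_from_hf('timm/resnet50.a1_in1k')

# Weight loading prefers a .safetensors file when one exists next to the .bin.
state_dict = load_state_dict_from_hf('timm/resnet50.a1_in1k')

# Pushing defaults to safe_serialization='both', writing both .bin and .safetensors.
model = timm.create_model('resnet50', pretrained=True)
push_to_hf_hub(model, 'my-user/my-resnet50', model_card={'tags': ['image-classification']})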