1717
1818try :
1919 import safetensors .torch
20-
2120 _has_safetensors = True
2221except ImportError :
2322 _has_safetensors = False
3231
3332try :
3433 from huggingface_hub import (
35- create_repo ,
36- get_hf_file_metadata ,
37- hf_hub_download ,
38- hf_hub_url ,
39- model_info ,
40- repo_type_and_id_from_hf_id ,
41- upload_folder ,
42- )
34+ create_repo , get_hf_file_metadata ,
35+ hf_hub_download , hf_hub_url , model_info ,
36+ repo_type_and_id_from_hf_id , upload_folder )
4337 from huggingface_hub .utils import EntryNotFoundError , RepositoryNotFoundError
44-
4538 hf_hub_download = partial (hf_hub_download , library_name = "timm" , library_version = __version__ )
4639 _has_hf_hub = True
4740except ImportError :
5043
5144_logger = logging .getLogger (__name__ )
5245
53- __all__ = [
54- 'get_cache_dir' ,
55- 'download_cached_file' ,
56- 'has_hf_hub' ,
57- 'hf_split' ,
58- 'load_model_config_from_hf' ,
59- 'load_state_dict_from_hf' ,
60- 'save_for_hf' ,
61- 'push_to_hf_hub' ,
62- ]
46+ __all__ = ['get_cache_dir' , 'download_cached_file' , 'has_hf_hub' , 'hf_split' , 'load_model_config_from_hf' ,
47+ 'load_state_dict_from_hf' , 'save_for_hf' , 'push_to_hf_hub' ]
6348
6449# Default name for a weights file hosted on the Huggingface Hub.
6550HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
@@ -84,10 +69,10 @@ def get_cache_dir(child_dir: str = ''):
8469
8570
8671def download_cached_file (
87- url : Union [str , List [str ], Tuple [str , str ]],
88- check_hash : bool = True ,
89- progress : bool = False ,
90- cache_dir : Optional [Union [str , Path ]] = None ,
72+ url : Union [str , List [str ], Tuple [str , str ]],
73+ check_hash : bool = True ,
74+ progress : bool = False ,
75+ cache_dir : Optional [Union [str , Path ]] = None ,
9176):
9277 if isinstance (url , (list , tuple )):
9378 url , filename = url
@@ -110,9 +95,9 @@ def download_cached_file(
11095
11196
11297def check_cached_file (
113- url : Union [str , List [str ], Tuple [str , str ]],
114- check_hash : bool = True ,
115- cache_dir : Optional [Union [str , Path ]] = None ,
98+ url : Union [str , List [str ], Tuple [str , str ]],
99+ check_hash : bool = True ,
100+ cache_dir : Optional [Union [str , Path ]] = None ,
116101):
117102 if isinstance (url , (list , tuple )):
118103 url , filename = url
@@ -129,7 +114,7 @@ def check_cached_file(
129114 if hash_prefix :
130115 with open (cached_file , 'rb' ) as f :
131116 hd = hashlib .sha256 (f .read ()).hexdigest ()
132- if hd [: len (hash_prefix )] != hash_prefix :
117+ if hd [:len (hash_prefix )] != hash_prefix :
133118 return False
134119 return True
135120 return False
@@ -139,8 +124,7 @@ def has_hf_hub(necessary: bool = False):
139124 if not _has_hf_hub and necessary :
140125 # if no HF Hub module installed, and it is necessary to continue, raise error
141126 raise RuntimeError (
142- 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.'
143- )
127+ 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.' )
144128 return _has_hf_hub
145129
146130
@@ -160,9 +144,9 @@ def load_cfg_from_json(json_file: Union[str, Path]):
160144
161145
162146def download_from_hf (
163- model_id : str ,
164- filename : str ,
165- cache_dir : Optional [Union [str , Path ]] = None ,
147+ model_id : str ,
148+ filename : str ,
149+ cache_dir : Optional [Union [str , Path ]] = None ,
166150):
167151 hf_model_id , hf_revision = hf_split (model_id )
168152 return hf_hub_download (
@@ -174,8 +158,8 @@ def download_from_hf(
174158
175159
176160def _parse_model_cfg (
177- cfg : Dict [str , Any ],
178- extra_fields : Dict [str , Any ],
161+ cfg : Dict [str , Any ],
162+ extra_fields : Dict [str , Any ],
179163) -> Tuple [Dict [str , Any ], str , Dict [str , Any ]]:
180164 """"""
181165 # legacy "single‑dict" → split
@@ -186,7 +170,7 @@ def _parse_model_cfg(
186170 "num_features" : pretrained_cfg .pop ("num_features" , None ),
187171 "pretrained_cfg" : pretrained_cfg ,
188172 }
189- if "labels" in pretrained_cfg : # rename ‑‑> label_names
173+ if "labels" in pretrained_cfg : # rename ‑‑> label_names
190174 pretrained_cfg ["label_names" ] = pretrained_cfg .pop ("labels" )
191175
192176 pretrained_cfg = cfg ["pretrained_cfg" ]
@@ -206,8 +190,8 @@ def _parse_model_cfg(
206190
207191
208192def load_model_config_from_hf (
209- model_id : str ,
210- cache_dir : Optional [Union [str , Path ]] = None ,
193+ model_id : str ,
194+ cache_dir : Optional [Union [str , Path ]] = None ,
211195):
212196 """Original HF‑Hub loader (unchanged download, shared parsing)."""
213197 assert has_hf_hub (True )
@@ -217,7 +201,7 @@ def load_model_config_from_hf(
217201
218202
219203def load_model_config_from_path (
220- model_path : Union [str , Path ],
204+ model_path : Union [str , Path ],
221205):
222206 """Load from ``<model_path>/config.json`` on the local filesystem."""
223207 model_path = Path (model_path )
@@ -230,10 +214,10 @@ def load_model_config_from_path(
230214
231215
232216def load_state_dict_from_hf (
233- model_id : str ,
234- filename : str = HF_WEIGHTS_NAME ,
235- weights_only : bool = False ,
236- cache_dir : Optional [Union [str , Path ]] = None ,
217+ model_id : str ,
218+ filename : str = HF_WEIGHTS_NAME ,
219+ weights_only : bool = False ,
220+ cache_dir : Optional [Union [str , Path ]] = None ,
237221):
238222 assert has_hf_hub (True )
239223 hf_model_id , hf_revision = hf_split (model_id )
@@ -250,8 +234,7 @@ def load_state_dict_from_hf(
250234 )
251235 _logger .info (
252236 f"[{ model_id } ] Safe alternative available for '{ filename } ' "
253- f"(as '{ safe_filename } '). Loading weights using safetensors."
254- )
237+ f"(as '{ safe_filename } '). Loading weights using safetensors." )
255238 return safetensors .torch .load_file (cached_safe_file , device = "cpu" )
256239 except EntryNotFoundError :
257240 pass
@@ -283,10 +266,9 @@ def load_state_dict_from_hf(
283266)
284267_EXT_PRIORITY = ('.safetensors' , '.pth' , '.pth.tar' , '.bin' )
285268
286-
287269def load_state_dict_from_path (
288- path : str ,
289- weights_only : bool = False ,
270+ path : str ,
271+ weights_only : bool = False ,
290272):
291273 found_file = None
292274 for fname in _PREFERRED_FILES :
@@ -301,7 +283,10 @@ def load_state_dict_from_path(
301283 files = sorted (path .glob (f"*{ ext } " ))
302284 if files :
303285 if len (files ) > 1 :
304- logging .warning (f"Multiple { ext } checkpoints in { path } : { names } . " f"Using '{ files [0 ].name } '." )
286+ logging .warning (
287+ f"Multiple { ext } checkpoints in { path } : { names } . "
288+ f"Using '{ files [0 ].name } '."
289+ )
305290 found_file = files [0 ]
306291
307292 if not found_file :
@@ -315,10 +300,10 @@ def load_state_dict_from_path(
315300
316301
317302def load_custom_from_hf (
318- model_id : str ,
319- filename : str ,
320- model : torch .nn .Module ,
321- cache_dir : Optional [Union [str , Path ]] = None ,
303+ model_id : str ,
304+ filename : str ,
305+ model : torch .nn .Module ,
306+ cache_dir : Optional [Union [str , Path ]] = None ,
322307):
323308 assert has_hf_hub (True )
324309 hf_model_id , hf_revision = hf_split (model_id )
@@ -332,7 +317,10 @@ def load_custom_from_hf(
332317
333318
334319def save_config_for_hf (
335- model : torch .nn .Module , config_path : str , model_config : Optional [dict ] = None , model_args : Optional [dict ] = None
320+ model : torch .nn .Module ,
321+ config_path : str ,
322+ model_config : Optional [dict ] = None ,
323+ model_args : Optional [dict ] = None
336324):
337325 model_config = model_config or {}
338326 hf_config = {}
@@ -351,8 +339,7 @@ def save_config_for_hf(
351339 if 'labels' in model_config :
352340 _logger .warning (
353341 "'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
354- " Renaming provided 'labels' field to 'label_names'."
355- )
342+ " Renaming provided 'labels' field to 'label_names'." )
356343 model_config .setdefault ('label_names' , model_config .pop ('labels' ))
357344
358345 label_names = model_config .pop ('label_names' , None )
@@ -379,11 +366,11 @@ def save_config_for_hf(
379366
380367
381368def save_for_hf (
382- model : torch .nn .Module ,
383- save_directory : str ,
384- model_config : Optional [dict ] = None ,
385- model_args : Optional [dict ] = None ,
386- safe_serialization : Union [bool , Literal ["both" ]] = False ,
369+ model : torch .nn .Module ,
370+ save_directory : str ,
371+ model_config : Optional [dict ] = None ,
372+ model_args : Optional [dict ] = None ,
373+ safe_serialization : Union [bool , Literal ["both" ]] = False ,
387374):
388375 assert has_hf_hub (True )
389376 save_directory = Path (save_directory )
@@ -407,18 +394,18 @@ def save_for_hf(
407394
408395
409396def push_to_hf_hub (
410- model : torch .nn .Module ,
411- repo_id : str ,
412- commit_message : str = 'Add model' ,
413- token : Optional [str ] = None ,
414- revision : Optional [str ] = None ,
415- private : bool = False ,
416- create_pr : bool = False ,
417- model_config : Optional [dict ] = None ,
418- model_card : Optional [dict ] = None ,
419- model_args : Optional [dict ] = None ,
420- task_name : str = 'image-classification' ,
421- safe_serialization : Union [bool , Literal ["both" ]] = 'both' ,
397+ model : torch .nn .Module ,
398+ repo_id : str ,
399+ commit_message : str = 'Add model' ,
400+ token : Optional [str ] = None ,
401+ revision : Optional [str ] = None ,
402+ private : bool = False ,
403+ create_pr : bool = False ,
404+ model_config : Optional [dict ] = None ,
405+ model_card : Optional [dict ] = None ,
406+ model_args : Optional [dict ] = None ,
407+ task_name : str = 'image-classification' ,
408+ safe_serialization : Union [bool , Literal ["both" ]] = 'both' ,
422409):
423410 """
424411 Arguments:
@@ -472,9 +459,9 @@ def push_to_hf_hub(
472459
473460
474461def generate_readme (
475- model_card : dict ,
476- model_name : str ,
477- task_name : str = 'image-classification' ,
462+ model_card : dict ,
463+ model_name : str ,
464+ task_name : str = 'image-classification' ,
478465):
479466 tags = model_card .get ('tags' , None ) or [task_name , 'timm' , 'transformers' ]
480467 readme_text = "---\n "
0 commit comments