Skip to content

Commit ffddf59

Browse files
committed
Running formatting with command from CONTRIBUTING.md
- Skipping the cspnet.py file with formatting due to large diff
1 parent 1228519 commit ffddf59

File tree

1 file changed

+68
-60
lines changed

1 file changed

+68
-60
lines changed

timm/models/_hub.py

Lines changed: 68 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
try:
1919
import safetensors.torch
20+
2021
_has_safetensors = True
2122
except ImportError:
2223
_has_safetensors = False
@@ -32,6 +33,7 @@
3233
try:
3334
from huggingface_hub import HfApi, hf_hub_download, model_info
3435
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
36+
3537
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
3638
_has_hf_hub = True
3739
except ImportError:
@@ -40,8 +42,16 @@
4042

4143
_logger = logging.getLogger(__name__)
4244

43-
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
44-
'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
45+
__all__ = [
46+
'get_cache_dir',
47+
'download_cached_file',
48+
'has_hf_hub',
49+
'hf_split',
50+
'load_model_config_from_hf',
51+
'load_state_dict_from_hf',
52+
'save_for_hf',
53+
'push_to_hf_hub',
54+
]
4555

4656
# Default name for a weights file hosted on the Huggingface Hub.
4757
HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
@@ -66,10 +76,10 @@ def get_cache_dir(child_dir: str = ''):
6676

6777

6878
def download_cached_file(
69-
url: Union[str, List[str], Tuple[str, str]],
70-
check_hash: bool = True,
71-
progress: bool = False,
72-
cache_dir: Optional[Union[str, Path]] = None,
79+
url: Union[str, List[str], Tuple[str, str]],
80+
check_hash: bool = True,
81+
progress: bool = False,
82+
cache_dir: Optional[Union[str, Path]] = None,
7383
):
7484
if isinstance(url, (list, tuple)):
7585
url, filename = url
@@ -92,9 +102,9 @@ def download_cached_file(
92102

93103

94104
def check_cached_file(
95-
url: Union[str, List[str], Tuple[str, str]],
96-
check_hash: bool = True,
97-
cache_dir: Optional[Union[str, Path]] = None,
105+
url: Union[str, List[str], Tuple[str, str]],
106+
check_hash: bool = True,
107+
cache_dir: Optional[Union[str, Path]] = None,
98108
):
99109
if isinstance(url, (list, tuple)):
100110
url, filename = url
@@ -111,7 +121,7 @@ def check_cached_file(
111121
if hash_prefix:
112122
with open(cached_file, 'rb') as f:
113123
hd = hashlib.sha256(f.read()).hexdigest()
114-
if hd[:len(hash_prefix)] != hash_prefix:
124+
if hd[: len(hash_prefix)] != hash_prefix:
115125
return False
116126
return True
117127
return False
@@ -121,7 +131,8 @@ def has_hf_hub(necessary: bool = False):
121131
if not _has_hf_hub and necessary:
122132
# if no HF Hub module installed, and it is necessary to continue, raise error
123133
raise RuntimeError(
124-
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
134+
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.'
135+
)
125136
return _has_hf_hub
126137

127138

@@ -141,9 +152,9 @@ def load_cfg_from_json(json_file: Union[str, Path]):
141152

142153

143154
def download_from_hf(
144-
model_id: str,
145-
filename: str,
146-
cache_dir: Optional[Union[str, Path]] = None,
155+
model_id: str,
156+
filename: str,
157+
cache_dir: Optional[Union[str, Path]] = None,
147158
):
148159
hf_model_id, hf_revision = hf_split(model_id)
149160
return hf_hub_download(
@@ -155,8 +166,8 @@ def download_from_hf(
155166

156167

157168
def _parse_model_cfg(
158-
cfg: Dict[str, Any],
159-
extra_fields: Dict[str, Any],
169+
cfg: Dict[str, Any],
170+
extra_fields: Dict[str, Any],
160171
) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
161172
""""""
162173
# legacy "single‑dict" → split
@@ -167,7 +178,7 @@ def _parse_model_cfg(
167178
"num_features": pretrained_cfg.pop("num_features", None),
168179
"pretrained_cfg": pretrained_cfg,
169180
}
170-
if "labels" in pretrained_cfg: # rename ‑‑> label_names
181+
if "labels" in pretrained_cfg: # rename ‑‑> label_names
171182
pretrained_cfg["label_names"] = pretrained_cfg.pop("labels")
172183

173184
pretrained_cfg = cfg["pretrained_cfg"]
@@ -187,8 +198,8 @@ def _parse_model_cfg(
187198

188199

189200
def load_model_config_from_hf(
190-
model_id: str,
191-
cache_dir: Optional[Union[str, Path]] = None,
201+
model_id: str,
202+
cache_dir: Optional[Union[str, Path]] = None,
192203
):
193204
"""Original HF‑Hub loader (unchanged download, shared parsing)."""
194205
assert has_hf_hub(True)
@@ -198,7 +209,7 @@ def load_model_config_from_hf(
198209

199210

200211
def load_model_config_from_path(
201-
model_path: Union[str, Path],
212+
model_path: Union[str, Path],
202213
):
203214
"""Load from ``<model_path>/config.json`` on the local filesystem."""
204215
model_path = Path(model_path)
@@ -211,10 +222,10 @@ def load_model_config_from_path(
211222

212223

213224
def load_state_dict_from_hf(
214-
model_id: str,
215-
filename: str = HF_WEIGHTS_NAME,
216-
weights_only: bool = False,
217-
cache_dir: Optional[Union[str, Path]] = None,
225+
model_id: str,
226+
filename: str = HF_WEIGHTS_NAME,
227+
weights_only: bool = False,
228+
cache_dir: Optional[Union[str, Path]] = None,
218229
):
219230
assert has_hf_hub(True)
220231
hf_model_id, hf_revision = hf_split(model_id)
@@ -231,7 +242,8 @@ def load_state_dict_from_hf(
231242
)
232243
_logger.info(
233244
f"[{model_id}] Safe alternative available for '{filename}' "
234-
f"(as '{safe_filename}'). Loading weights using safetensors.")
245+
f"(as '{safe_filename}'). Loading weights using safetensors."
246+
)
235247
return safetensors.torch.load_file(cached_safe_file, device="cpu")
236248
except EntryNotFoundError:
237249
pass
@@ -263,9 +275,10 @@ def load_state_dict_from_hf(
263275
)
264276
_EXT_PRIORITY = ('.safetensors', '.pth', '.pth.tar', '.bin')
265277

278+
266279
def load_state_dict_from_path(
267-
path: str,
268-
weights_only: bool = False,
280+
path: str,
281+
weights_only: bool = False,
269282
):
270283
found_file = None
271284
for fname in _PREFERRED_FILES:
@@ -280,10 +293,7 @@ def load_state_dict_from_path(
280293
files = sorted(path.glob(f"*{ext}"))
281294
if files:
282295
if len(files) > 1:
283-
logging.warning(
284-
f"Multiple {ext} checkpoints in {path}: {names}. "
285-
f"Using '{files[0].name}'."
286-
)
296+
logging.warning(f"Multiple {ext} checkpoints in {path}: {names}. " f"Using '{files[0].name}'.")
287297
found_file = files[0]
288298

289299
if not found_file:
@@ -297,10 +307,10 @@ def load_state_dict_from_path(
297307

298308

299309
def load_custom_from_hf(
300-
model_id: str,
301-
filename: str,
302-
model: torch.nn.Module,
303-
cache_dir: Optional[Union[str, Path]] = None,
310+
model_id: str,
311+
filename: str,
312+
model: torch.nn.Module,
313+
cache_dir: Optional[Union[str, Path]] = None,
304314
):
305315
assert has_hf_hub(True)
306316
hf_model_id, hf_revision = hf_split(model_id)
@@ -314,10 +324,7 @@ def load_custom_from_hf(
314324

315325

316326
def save_config_for_hf(
317-
model: torch.nn.Module,
318-
config_path: str,
319-
model_config: Optional[dict] = None,
320-
model_args: Optional[dict] = None
327+
model: torch.nn.Module, config_path: str, model_config: Optional[dict] = None, model_args: Optional[dict] = None
321328
):
322329
model_config = model_config or {}
323330
hf_config = {}
@@ -336,7 +343,8 @@ def save_config_for_hf(
336343
if 'labels' in model_config:
337344
_logger.warning(
338345
"'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
339-
" Renaming provided 'labels' field to 'label_names'.")
346+
" Renaming provided 'labels' field to 'label_names'."
347+
)
340348
model_config.setdefault('label_names', model_config.pop('labels'))
341349

342350
label_names = model_config.pop('label_names', None)
@@ -363,11 +371,11 @@ def save_config_for_hf(
363371

364372

365373
def save_for_hf(
366-
model: torch.nn.Module,
367-
save_directory: str,
368-
model_config: Optional[dict] = None,
369-
model_args: Optional[dict] = None,
370-
safe_serialization: Union[bool, Literal["both"]] = False,
374+
model: torch.nn.Module,
375+
save_directory: str,
376+
model_config: Optional[dict] = None,
377+
model_args: Optional[dict] = None,
378+
safe_serialization: Union[bool, Literal["both"]] = False,
371379
):
372380
assert has_hf_hub(True)
373381
save_directory = Path(save_directory)
@@ -391,18 +399,18 @@ def save_for_hf(
391399

392400

393401
def push_to_hf_hub(
394-
model: torch.nn.Module,
395-
repo_id: str,
396-
commit_message: str = 'Add model',
397-
token: Optional[str] = None,
398-
revision: Optional[str] = None,
399-
private: bool = False,
400-
create_pr: bool = False,
401-
model_config: Optional[dict] = None,
402-
model_card: Optional[dict] = None,
403-
model_args: Optional[dict] = None,
404-
task_name: str = 'image-classification',
405-
safe_serialization: Union[bool, Literal["both"]] = 'both',
402+
model: torch.nn.Module,
403+
repo_id: str,
404+
commit_message: str = 'Add model',
405+
token: Optional[str] = None,
406+
revision: Optional[str] = None,
407+
private: bool = False,
408+
create_pr: bool = False,
409+
model_config: Optional[dict] = None,
410+
model_card: Optional[dict] = None,
411+
model_args: Optional[dict] = None,
412+
task_name: str = 'image-classification',
413+
safe_serialization: Union[bool, Literal["both"]] = 'both',
406414
):
407415
"""
408416
Arguments:
@@ -452,9 +460,9 @@ def push_to_hf_hub(
452460

453461

454462
def generate_readme(
455-
model_card: dict,
456-
model_name: str,
457-
task_name: str = 'image-classification',
463+
model_card: dict,
464+
model_name: str,
465+
task_name: str = 'image-classification',
458466
):
459467
tags = model_card.get('tags', None) or [task_name, 'timm', 'transformers']
460468
readme_text = "---\n"

0 commit comments

Comments
 (0)