Commit 09a13d8

Revert "Running formatting with command from CONTRIBUTING.md"
This reverts commit ed00d06, reducing the diff so the pull request contains only the functional change.
1 parent ed00d06 · commit 09a13d8

timm/models/_hub.py

Lines changed: 63 additions & 76 deletions
@@ -17,7 +17,6 @@
 
 try:
     import safetensors.torch
-
     _has_safetensors = True
 except ImportError:
     _has_safetensors = False
@@ -32,16 +31,10 @@
 
 try:
     from huggingface_hub import (
-        create_repo,
-        get_hf_file_metadata,
-        hf_hub_download,
-        hf_hub_url,
-        model_info,
-        repo_type_and_id_from_hf_id,
-        upload_folder,
-    )
+        create_repo, get_hf_file_metadata,
+        hf_hub_download, hf_hub_url, model_info,
+        repo_type_and_id_from_hf_id, upload_folder)
     from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
-
     hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
     _has_hf_hub = True
 except ImportError:
@@ -50,16 +43,8 @@
 
 _logger = logging.getLogger(__name__)
 
-__all__ = [
-    'get_cache_dir',
-    'download_cached_file',
-    'has_hf_hub',
-    'hf_split',
-    'load_model_config_from_hf',
-    'load_state_dict_from_hf',
-    'save_for_hf',
-    'push_to_hf_hub',
-]
+__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
+           'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
 
 # Default name for a weights file hosted on the Huggingface Hub.
 HF_WEIGHTS_NAME = "pytorch_model.bin"  # default pytorch pkl
@@ -84,10 +69,10 @@ def get_cache_dir(child_dir: str = ''):
 
 
 def download_cached_file(
-    url: Union[str, List[str], Tuple[str, str]],
-    check_hash: bool = True,
-    progress: bool = False,
-    cache_dir: Optional[Union[str, Path]] = None,
+        url: Union[str, List[str], Tuple[str, str]],
+        check_hash: bool = True,
+        progress: bool = False,
+        cache_dir: Optional[Union[str, Path]] = None,
 ):
     if isinstance(url, (list, tuple)):
         url, filename = url
@@ -110,9 +95,9 @@ def download_cached_file(
 
 
 def check_cached_file(
-    url: Union[str, List[str], Tuple[str, str]],
-    check_hash: bool = True,
-    cache_dir: Optional[Union[str, Path]] = None,
+        url: Union[str, List[str], Tuple[str, str]],
+        check_hash: bool = True,
+        cache_dir: Optional[Union[str, Path]] = None,
 ):
     if isinstance(url, (list, tuple)):
         url, filename = url
@@ -129,7 +114,7 @@ def check_cached_file(
             if hash_prefix:
                 with open(cached_file, 'rb') as f:
                     hd = hashlib.sha256(f.read()).hexdigest()
-                if hd[: len(hash_prefix)] != hash_prefix:
+                if hd[:len(hash_prefix)] != hash_prefix:
                     return False
         return True
     return False
@@ -139,8 +124,7 @@ def has_hf_hub(necessary: bool = False):
     if not _has_hf_hub and necessary:
         # if no HF Hub module installed, and it is necessary to continue, raise error
         raise RuntimeError(
-            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.'
-        )
+            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
     return _has_hf_hub
 
 
@@ -160,9 +144,9 @@ def load_cfg_from_json(json_file: Union[str, Path]):
 
 
 def download_from_hf(
-    model_id: str,
-    filename: str,
-    cache_dir: Optional[Union[str, Path]] = None,
+        model_id: str,
+        filename: str,
+        cache_dir: Optional[Union[str, Path]] = None,
 ):
     hf_model_id, hf_revision = hf_split(model_id)
     return hf_hub_download(
@@ -174,8 +158,8 @@ def download_from_hf(
 
 
 def _parse_model_cfg(
-    cfg: Dict[str, Any],
-    extra_fields: Dict[str, Any],
+        cfg: Dict[str, Any],
+        extra_fields: Dict[str, Any],
 ) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
     """"""
     # legacy "single‑dict" → split
@@ -186,7 +170,7 @@ def _parse_model_cfg(
             "num_features": pretrained_cfg.pop("num_features", None),
             "pretrained_cfg": pretrained_cfg,
         }
-        if "labels" in pretrained_cfg: # rename ‑‑> label_names
+        if "labels" in pretrained_cfg:  # rename ‑‑> label_names
             pretrained_cfg["label_names"] = pretrained_cfg.pop("labels")
 
     pretrained_cfg = cfg["pretrained_cfg"]
@@ -206,8 +190,8 @@ def _parse_model_cfg(
 
 
 def load_model_config_from_hf(
-    model_id: str,
-    cache_dir: Optional[Union[str, Path]] = None,
+        model_id: str,
+        cache_dir: Optional[Union[str, Path]] = None,
 ):
     """Original HF‑Hub loader (unchanged download, shared parsing)."""
     assert has_hf_hub(True)
@@ -217,7 +201,7 @@ def load_model_config_from_hf(
 
 
 def load_model_config_from_path(
-    model_path: Union[str, Path],
+        model_path: Union[str, Path],
 ):
     """Load from ``<model_path>/config.json`` on the local filesystem."""
     model_path = Path(model_path)
@@ -230,10 +214,10 @@ def load_model_config_from_path(
 
 
 def load_state_dict_from_hf(
-    model_id: str,
-    filename: str = HF_WEIGHTS_NAME,
-    weights_only: bool = False,
-    cache_dir: Optional[Union[str, Path]] = None,
+        model_id: str,
+        filename: str = HF_WEIGHTS_NAME,
+        weights_only: bool = False,
+        cache_dir: Optional[Union[str, Path]] = None,
 ):
     assert has_hf_hub(True)
     hf_model_id, hf_revision = hf_split(model_id)
@@ -250,8 +234,7 @@ def load_state_dict_from_hf(
             )
             _logger.info(
                 f"[{model_id}] Safe alternative available for '{filename}' "
-                f"(as '{safe_filename}'). Loading weights using safetensors."
-            )
+                f"(as '{safe_filename}'). Loading weights using safetensors.")
             return safetensors.torch.load_file(cached_safe_file, device="cpu")
         except EntryNotFoundError:
             pass
@@ -283,10 +266,9 @@ def load_state_dict_from_hf(
     )
 _EXT_PRIORITY = ('.safetensors', '.pth', '.pth.tar', '.bin')
 
-
 def load_state_dict_from_path(
-    path: str,
-    weights_only: bool = False,
+        path: str,
+        weights_only: bool = False,
 ):
     found_file = None
     for fname in _PREFERRED_FILES:
@@ -301,7 +283,10 @@ def load_state_dict_from_path(
         files = sorted(path.glob(f"*{ext}"))
         if files:
             if len(files) > 1:
-                logging.warning(f"Multiple {ext} checkpoints in {path}: {names}. " f"Using '{files[0].name}'.")
+                logging.warning(
+                    f"Multiple {ext} checkpoints in {path}: {names}. "
+                    f"Using '{files[0].name}'."
+                )
             found_file = files[0]
 
     if not found_file:
@@ -315,10 +300,10 @@ def load_state_dict_from_path(
 
 
 def load_custom_from_hf(
-    model_id: str,
-    filename: str,
-    model: torch.nn.Module,
-    cache_dir: Optional[Union[str, Path]] = None,
+        model_id: str,
+        filename: str,
+        model: torch.nn.Module,
+        cache_dir: Optional[Union[str, Path]] = None,
 ):
     assert has_hf_hub(True)
     hf_model_id, hf_revision = hf_split(model_id)
@@ -332,7 +317,10 @@ def load_custom_from_hf(
 
 
 def save_config_for_hf(
-    model: torch.nn.Module, config_path: str, model_config: Optional[dict] = None, model_args: Optional[dict] = None
+        model: torch.nn.Module,
+        config_path: str,
+        model_config: Optional[dict] = None,
+        model_args: Optional[dict] = None
 ):
     model_config = model_config or {}
     hf_config = {}
@@ -351,8 +339,7 @@ def save_config_for_hf(
     if 'labels' in model_config:
         _logger.warning(
             "'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
-            " Renaming provided 'labels' field to 'label_names'."
-        )
+            " Renaming provided 'labels' field to 'label_names'.")
         model_config.setdefault('label_names', model_config.pop('labels'))
 
     label_names = model_config.pop('label_names', None)
@@ -379,11 +366,11 @@ def save_config_for_hf(
 
 
 def save_for_hf(
-    model: torch.nn.Module,
-    save_directory: str,
-    model_config: Optional[dict] = None,
-    model_args: Optional[dict] = None,
-    safe_serialization: Union[bool, Literal["both"]] = False,
+        model: torch.nn.Module,
+        save_directory: str,
+        model_config: Optional[dict] = None,
+        model_args: Optional[dict] = None,
+        safe_serialization: Union[bool, Literal["both"]] = False,
 ):
     assert has_hf_hub(True)
     save_directory = Path(save_directory)
@@ -407,18 +394,18 @@ def save_for_hf(
 
 
 def push_to_hf_hub(
-    model: torch.nn.Module,
-    repo_id: str,
-    commit_message: str = 'Add model',
-    token: Optional[str] = None,
-    revision: Optional[str] = None,
-    private: bool = False,
-    create_pr: bool = False,
-    model_config: Optional[dict] = None,
-    model_card: Optional[dict] = None,
-    model_args: Optional[dict] = None,
-    task_name: str = 'image-classification',
-    safe_serialization: Union[bool, Literal["both"]] = 'both',
+        model: torch.nn.Module,
+        repo_id: str,
+        commit_message: str = 'Add model',
+        token: Optional[str] = None,
+        revision: Optional[str] = None,
+        private: bool = False,
+        create_pr: bool = False,
+        model_config: Optional[dict] = None,
+        model_card: Optional[dict] = None,
+        model_args: Optional[dict] = None,
+        task_name: str = 'image-classification',
+        safe_serialization: Union[bool, Literal["both"]] = 'both',
 ):
     """
     Arguments:
@@ -472,9 +459,9 @@ def push_to_hf_hub(
 
 
 def generate_readme(
-    model_card: dict,
-    model_name: str,
-    task_name: str = 'image-classification',
+        model_card: dict,
+        model_name: str,
+        task_name: str = 'image-classification',
 ):
     tags = model_card.get('tags', None) or [task_name, 'timm', 'transformers']
     readme_text = "---\n"
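
The changes above are whitespace and layout only; the hub helpers behave the same before and after the revert. For context, a minimal sketch of how these helpers are typically reached through timm's public API (the model id below is an assumed example, not part of this commit): an `hf_hub:` prefix routes `timm.create_model` through `load_model_config_from_hf` and `load_state_dict_from_hf`, which silently prefer a `.safetensors` checkpoint when the repository publishes one.

```python
# Illustrative sketch only; the model id is an assumption for the example.
import timm

# "hf_hub:" tells timm to resolve the config and weights via huggingface_hub
# using the helpers in timm/models/_hub.py shown in this diff.
model = timm.create_model("hf_hub:timm/resnet50.a1_in1k", pretrained=True)
model.eval()
```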
