Commit 3f25fd1

Modernize transformers module with type hints and generic types (#2034)
SUMMARY: This is part of #1927. Modernize type annotations using the `|` operator and built-in generics in the transformers module, as part of the codebase modernization effort.

TEST PLAN:
```
make style
make quality
make tests
```

Notes: Happy to address any comments! Thank you!

---------

Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
1 parent aa50449 commit 3f25fd1
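At a glance, the change swaps `typing` aliases for their modern spellings: built-in generics (PEP 585) and `|` unions (PEP 604). A minimal before/after sketch with illustrative names, not code from this repo:

```python
# Old style: aliases imported from typing
from typing import Dict, List, Optional

def first_old(names: List[str], cache: Optional[Dict[str, int]] = None) -> Optional[str]:
    return names[0] if names else None

# New style: built-in generics (PEP 585, Python >= 3.9)
# and | unions (PEP 604, Python >= 3.10)
def first_new(names: list[str], cache: dict[str, int] | None = None) -> str | None:
    return names[0] if names else None
```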

File tree

9 files changed: +49 -51 lines changed


src/llmcompressor/transformers/compression/compressed_tensors_utils.py

Lines changed: 4 additions & 5 deletions
```diff
@@ -1,7 +1,6 @@
 import os
 import weakref
 from functools import wraps
-from typing import Optional
 
 import torch
 from accelerate.accelerator import get_state_dict_offloaded_model
@@ -50,8 +49,8 @@ def save_pretrained_compressed(save_pretrained_method):
     @wraps(original_save_pretrained)
     def save_pretrained_wrapper(
         save_directory: str,
-        sparsity_config: Optional[SparsityCompressionConfig] = None,
-        quantization_format: Optional[str] = None,
+        sparsity_config: SparsityCompressionConfig | None = None,
+        quantization_format: str | None = None,
         save_compressed: bool = True,
         safe_serialization: bool = True,
         skip_sparsity_compression_stats: bool = True,
@@ -116,8 +115,8 @@ def save_pretrained_wrapper(
 
 def get_model_compressor(
     model: torch.nn.Module,
-    sparsity_config: Optional[SparsityCompressionConfig] = None,
-    quantization_format: Optional[str] = None,
+    sparsity_config: SparsityCompressionConfig | None = None,
+    quantization_format: str | None = None,
     save_compressed: bool = True,
     skip_sparsity_compression_stats: bool = True,
     disable_sparse_compression: bool = False,
```
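One practical caveat: `X | None` in a `def` signature is evaluated when the module imports, so a file without `from __future__ import annotations` (like this one) needs Python 3.10+. A minimal sketch with a stand-in class, not the real `SparsityCompressionConfig`:

```python
class SparsityCompressionConfig:  # stand-in for the compressed_tensors class
    pass

# The union is evaluated at definition time, so this signature itself
# requires Python >= 3.10 unless annotations are postponed.
def get_model_compressor_sketch(
    sparsity_config: SparsityCompressionConfig | None = None,
    quantization_format: str | None = None,
) -> None:
    # Runtime behavior is unchanged: callers still pass None or an instance.
    print(type(sparsity_config).__name__, quantization_format)

get_model_compressor_sketch()  # prints: NoneType None
```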

src/llmcompressor/transformers/compression/helpers.py

Lines changed: 7 additions & 8 deletions
```diff
@@ -1,5 +1,4 @@
 from collections import defaultdict
-from typing import Dict, List, Optional, Tuple
 
 import torch
 from accelerate.accelerator import get_state_dict_offloaded_model
@@ -51,8 +50,8 @@ def tensor_follows_mask_structure(tensor: torch.Tensor, mask: str = "2:4") -> bool:
 
 
 def infer_sparsity_structure_from_modifiers(
-    modifiers: List[Modifier],  # noqa E501
-) -> Optional[str]:
+    modifiers: list[Modifier],  # noqa E501
+) -> str | None:
     """
     Determines the sparsity structure, if any exists, given the list of modifiers.
@@ -65,7 +64,7 @@ def infer_sparsity_structure_from_modifiers(
     return None
 
 
-def infer_sparsity_structure_from_model(model: torch.nn.Module) -> Optional[str]:
+def infer_sparsity_structure_from_model(model: torch.nn.Module) -> str | None:
     """
     Determines the sparsity structure, if any exists, given the model
@@ -104,7 +103,7 @@ def infer_sparse_targets_and_ignores(
     model: torch.nn.Module,
     sparsity_structure: str,
     sparsity_threshold: float,
-) -> Tuple[List[str], List[str]]:
+) -> tuple[list[str], list[str]]:
     """
     Infers the target and ignore layers in the given model
     to be used for sparsity compression
@@ -151,7 +150,7 @@ def is_sparse_compression_target(
 
 def _get_sparse_targets_ignore_dicts(
     module: torch.nn.Module, sparsity_structure: str, sparsity_threshold: float
-) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
+) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
     """
     Get sparse targets and ignore dictionaries
@@ -176,8 +175,8 @@ def _get_sparse_targets_ignore_dicts(
 
 
 def _reduce_targets_and_ignores_into_lists(
-    exhaustive_targets: Dict[str, List[str]], exhaustive_ignore: Dict[str, List[str]]
-) -> Tuple[List[str], List[str]]:
+    exhaustive_targets: dict[str, list[str]], exhaustive_ignore: dict[str, list[str]]
+) -> tuple[list[str], list[str]]:
     """
     Reduces the targets and ignores dictionaries into lists
```
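For readers skimming the signatures, the `dict[str, list[str]] -> tuple[list[str], list[str]]` shape is the whole story of the reduction helper. A simplified illustration of that pattern, not the repo's actual implementation:

```python
def reduce_into_lists(
    exhaustive_targets: dict[str, list[str]],
    exhaustive_ignore: dict[str, list[str]],
) -> tuple[list[str], list[str]]:
    # Flatten per-group dicts into flat, deduplicated, ordered lists.
    targets = sorted({n for names in exhaustive_targets.values() for n in names})
    ignore = sorted({n for names in exhaustive_ignore.values() for n in names})
    return targets, ignore

targets, ignore = reduce_into_lists(
    {"Linear": ["model.layers.0.mlp", "model.layers.1.mlp"]},
    {"Linear": ["lm_head"]},
)
print(targets, ignore)  # ['model.layers.0.mlp', 'model.layers.1.mlp'] ['lm_head']
```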

src/llmcompressor/transformers/compression/sparsity_metadata_config.py

Lines changed: 11 additions & 11 deletions
```diff
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional
+from __future__ import annotations
 
 from compressed_tensors import CompressionFormat, SparsityCompressionConfig
 from compressed_tensors.config import SparsityStructure
@@ -30,7 +30,7 @@ class SparsityConfigMetadata:
 
     @staticmethod
     def infer_global_sparsity(
-        model: Module, state_dict: Optional[Dict[str, Tensor]] = None
+        model: Module, state_dict: dict[str, Tensor] | None = None
     ) -> float:
         """
         Calculates the global percentage of sparse zero weights in the model
@@ -47,12 +47,12 @@ def infer_global_sparsity(
 
     @staticmethod
     def infer_sparsity_structure(
-        model: Optional[Module] = None, check_only_modifiers: Optional[bool] = False
+        model: Module | None = None, check_only_modifiers: bool | None = False
     ) -> str:
         """
         Determines what sparsity structure, if any, was applied.
 
-        First, there is an attempt to dedue the sparsity structure
+        First, there is an attempt to deduce the sparsity structure
         from the currently active sparse session.
 
         If that fails, the sparsity structure is inferred from the
@@ -83,12 +83,12 @@ def infer_sparsity_structure(
     @staticmethod
     def from_pretrained(
         model: Module,
-        state_dict: Optional[Dict[str, Tensor]] = None,
+        state_dict: dict[str, Tensor] | None = None,
         compress: bool = False,
-        quantization_format: Optional[CompressionFormat] = None,
+        quantization_format: CompressionFormat | None = None,
         disable_sparse_compression: bool = False,
-        sparsity_structure: Optional[str] = None,
-    ) -> Optional["SparsityCompressionConfig"]:
+        sparsity_structure: str | None = None,
+    ) -> SparsityCompressionConfig | None:
         """
         Determines compression type and informational parameters for a given model
@@ -155,7 +155,7 @@ def from_pretrained(
     def fill_config_details(
         config: SparsityCompressionConfig,
         model: Module,
-        state_dict: Optional[Dict[str, Tensor]] = None,
+        state_dict: dict[str, Tensor] | None = None,
     ):
         """
         Fills in informational sparsity parameters from a given model
@@ -173,7 +173,7 @@ def fill_config_details(
     @staticmethod
     def is_sparse24_bitmask_supported(
         model: Module,
-        sparsity_structure: Optional[str] = None,
+        sparsity_structure: str | None = None,
     ) -> bool:
         """
         Determines if sparse 24 bitmask sparse compressor is supported for a given model
@@ -202,7 +202,7 @@ def is_sparse24_bitmask_supported(
 
         # when model is quantized, and has 2:4 sparsity
 
-        supported_scheme_types: List[str] = [
+        supported_scheme_types: list[str] = [
             QuantizationType.INT.value,
             QuantizationType.FLOAT.value,
         ]
```
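This file takes the other route: instead of requiring Python 3.10 at import time, it adds `from __future__ import annotations` (PEP 563), which stores annotations as strings. That also makes the quoted forward reference `Optional["SparsityCompressionConfig"]` unnecessary. A self-contained sketch of both effects, using a stand-in class:

```python
from __future__ import annotations  # must come before all other statements

class SparsityCompressionConfig:  # stand-in, not the compressed_tensors class
    pass

# With postponed evaluation, the union below is never evaluated at runtime,
# so it parses even on interpreters older than 3.10, and no quotes are
# needed to reference a class lazily.
def from_pretrained_sketch(
    sparsity_structure: str | None = None,
) -> SparsityCompressionConfig | None:
    return None

print(from_pretrained_sketch.__annotations__["return"])
# prints the *string* 'SparsityCompressionConfig | None'
```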

src/llmcompressor/transformers/data/base.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -10,7 +10,7 @@
 import inspect
 from functools import cached_property
 from inspect import _ParameterKind as Kind
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Callable
 
 from compressed_tensors.registry import RegistryMixin
 from datasets import Dataset, IterableDataset
@@ -203,7 +203,7 @@ def load_dataset(self):
         )
 
     @cached_property
-    def preprocess(self) -> Union[Callable[[LazyRow], Any], None]:
+    def preprocess(self) -> Callable[[LazyRow], Any] | None:
         """
         The function must return keys which correspond to processor/tokenizer kwargs,
         optionally including PROMPT_KEY
@@ -226,7 +226,7 @@ def preprocess(self) -> Union[Callable[[LazyRow], Any], None]:
         return self.dataset_template
 
     @property
-    def dataset_template(self) -> Union[Callable[[Any], Any], None]:
+    def dataset_template(self) -> Callable[[Any], Any] | None:
         return None
 
     def rename_columns(self, dataset: DatasetType) -> DatasetType:
@@ -255,7 +255,7 @@ def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType:
             list(set(column_names) - set(tokenizer_args) - set([self.PROMPT_KEY]))
         )
 
-    def tokenize(self, data: LazyRow) -> Dict[str, Any]:
+    def tokenize(self, data: LazyRow) -> dict[str, Any]:
         # separate prompt
         prompt = data.pop(self.PROMPT_KEY, None)
 
@@ -283,7 +283,7 @@ def tokenize(self, data: LazyRow) -> Dict[str, Any]:
 
         return data
 
-    def group_text(self, data: LazyRow) -> Dict[str, Any]:
+    def group_text(self, data: LazyRow) -> dict[str, Any]:
         concatenated_data = {k: sum(data[k], []) for k in data.keys()}
         total_length = len(concatenated_data[list(data.keys())[0]])
         total_length = (total_length // self.max_seq_length) * self.max_seq_length
@@ -318,10 +318,10 @@ def add_labels(self, data: LazyRow) -> LazyRow:
 
     def map(
         self,
-        dataset: Union[Dataset, IterableDataset],
+        dataset: Dataset | IterableDataset,
        function: Callable[[Any], Any],
         **kwargs,
-    ) -> Union[Dataset, IterableDataset]:
+    ) -> Dataset | IterableDataset:
         """
         Wrapper function around Dataset.map and IterableDataset.map.
@@ -343,7 +343,7 @@ def map(
         return dataset
 
 
-def get_columns(dataset: DatasetType) -> List[str]:
+def get_columns(dataset: DatasetType) -> list[str]:
     column_names = dataset.column_names
     if isinstance(column_names, dict):
        column_names = sum(column_names.values(), [])
```
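Note that the import keeps `Callable` from `typing`: subscripted typing constructs compose fine with `|` unions, and `collections.abc.Callable` is the PEP 585 home if the import is modernized later. A minimal sketch of the `Callable[..., ...] | None` property pattern, with hypothetical names rather than the dataset class itself:

```python
from typing import Any, Callable

def make_preprocess(
    template: str | None,
) -> Callable[[dict[str, Any]], dict[str, Any]] | None:
    # Returning None mirrors dataset_template's default of "no preprocessing".
    if template is None:
        return None
    return lambda row: {**row, "text": template.format(**row)}

fn = make_preprocess("Q: {question}")
assert fn is not None
assert fn({"question": "hi"})["text"] == "Q: hi"
```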

src/llmcompressor/transformers/data/data_helpers.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import Any, Dict, Optional
+from typing import Any
 
 from datasets import Dataset, load_dataset
 
@@ -15,8 +15,8 @@
 
 def get_raw_dataset(
     dataset_args,
-    cache_dir: Optional[str] = None,
-    streaming: Optional[bool] = False,
+    cache_dir: str | None = None,
+    streaming: bool | None = False,
     **kwargs,
 ) -> Dataset:
     """
@@ -37,7 +37,7 @@ def get_raw_dataset(
     return raw_datasets
 
 
-def get_custom_datasets_from_path(path: str, ext: str = "json") -> Dict[str, str]:
+def get_custom_datasets_from_path(path: str, ext: str = "json") -> dict[str, str]:
     """
     Get a dictionary of custom datasets from a directory path. Support HF's load_dataset
     for local folder datasets https://huggingface.co/docs/datasets/loading
@@ -105,7 +105,7 @@ def get_custom_datasets_from_path(path: str, ext: str = "json") -> Dict[str, str]:
     return transform_dataset_keys(data_files)
 
 
-def transform_dataset_keys(data_files: Dict[str, Any]):
+def transform_dataset_keys(data_files: dict[str, Any]):
     """
     Transform dict keys to `train`, `val` or `test` for the given input dict
     if matches exist with the existing keys. Note that there can only be one
```
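The `dict[str, Any]` key-normalization contract is easy to picture. A hypothetical, simplified take on the idea; the real matching rules live in `transform_dataset_keys`:

```python
from typing import Any

def normalize_split_keys(data_files: dict[str, Any]) -> dict[str, Any]:
    # Map arbitrary user keys onto the canonical train/val/test split names.
    canonical = {"train": "train", "validation": "val", "val": "val", "test": "test"}
    out: dict[str, Any] = {}
    for key, path in data_files.items():
        for needle, split in canonical.items():
            if needle in key.lower():
                out[split] = path
                break
    return out

print(normalize_split_keys({"my_training_set": "train.json", "eval_data": "eval.json"}))
# {'train': 'train.json', 'val': 'eval.json'}
```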

src/llmcompressor/transformers/data/peoples_speech.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,5 +1,5 @@
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any
 
 from datasets.formatting.formatting import LazyRow
 from loguru import logger
@@ -68,7 +68,7 @@ def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType:
         else:
             return super().filter_tokenizer_args(dataset)
 
-    def tokenize(self, data: LazyRow) -> Dict[str, Any]:
+    def tokenize(self, data: LazyRow) -> dict[str, Any]:
         if self.processor_type == "WhisperProcessor":
             inputs = self.processor(
                 audio=data["audio"],
```

src/llmcompressor/transformers/tracing/debug.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -1,4 +1,4 @@
-from typing import List, Type, Union, Optional, Dict, Tuple, Any
+from typing import Type, Tuple, Any
 
 import argparse
 from contextlib import nullcontext
@@ -33,13 +33,13 @@ def parse_args():
 def trace(
     model_id: str,
     model_class: Type[PreTrainedModel],
-    sequential_targets: Optional[Union[List[str], str]] = None,
-    ignore: Union[List[str], str] = DatasetArguments().tracing_ignore,
+    sequential_targets: list[str] | str | None = None,
+    ignore: list[str] | str = DatasetArguments().tracing_ignore,
     modality: str = "text",
     trust_remote_code: bool = True,
     skip_weights: bool = True,
-    device_map: Union[str, Dict] = "cpu",
-) -> Tuple[PreTrainedModel, List[Subgraph], Dict[str, torch.Tensor]]:
+    device_map: str | dict = "cpu",
+) -> Tuple[PreTrainedModel, list[Subgraph], dict[str, torch.Tensor]]:
     """
     Debug traceability by tracing a pre-trained model into subgraphs
 
@@ -110,7 +110,7 @@ def trace(
     return model, subgraphs, sample
 
 
-def get_dataset_kwargs(modality: str, ignore: List[str]) -> Dict[str, str]:
+def get_dataset_kwargs(modality: str, ignore: list[str]) -> dict[str, str]:
     dataset_kwargs = {
         "text": {
             "dataset": "ultrachat-200k",
@@ -139,7 +139,7 @@ def get_dataset_kwargs(modality: str, ignore: List[str]) -> Dict[str, str]:
     return dataset_kwargs[modality] | common_kwargs
 
 
-def collate_sample(sample: Dict[str, Any], device: str) -> Dict[str, torch.Tensor]:
+def collate_sample(sample: dict[str, Any], device: str) -> dict[str, torch.Tensor]:
     for name, value in sample.items():
         if name in ("input_ids", "attention_mask") and torch.tensor(value).ndim == 1:
             sample[name] = torch.tensor([value], device=device)
```
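Here `Type` and `Tuple` stay imported from `typing`; on Python 3.9+ the lowercase built-ins subscript the same way, so the fully modernized signature would look like the sketch below (an illustration, not what this commit ships). Separately, `dataset_kwargs[modality] | common_kwargs` in this file already relies on the PEP 584 dict-merge operator, which itself needs Python 3.9+.

```python
import torch

class Subgraph:  # stand-in for llmcompressor's Subgraph
    pass

def trace_sketch(
    model_class: type,  # would be type[PreTrainedModel] in the full rewrite
) -> tuple[type, list[Subgraph], dict[str, torch.Tensor]]:
    # Same return shape as trace(): model class, subgraphs, sample inputs.
    return model_class, [Subgraph()], {"input_ids": torch.tensor([[1, 2, 3]])}
```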

src/llmcompressor/transformers/utils/helpers.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -5,7 +5,7 @@
 
 import os
 from pathlib import Path
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING
 
 import requests
 from huggingface_hub import (
@@ -47,7 +47,7 @@ def is_model_ct_quantized_from_path(path: str) -> bool:
     return False
 
 
-def infer_recipe_from_model_path(model_path: Union[str, Path]) -> Optional[str]:
+def infer_recipe_from_model_path(model_path: str | Path) -> str | None:
     """
     Infer the recipe from the model_path.
 
@@ -100,7 +100,7 @@ def infer_recipe_from_model_path(model_path: Union[str, Path]) -> Optional[str]:
 
 def recipe_from_huggingface_model_id(
     hf_stub: str, recipe_file_name: str = RECIPE_FILE_NAME
-) -> Optional[str]:
+) -> str | None:
     """
     Attempts to download the recipe from the Hugging Face model ID.
```
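`str | Path` unions also work at runtime with `isinstance` on Python 3.10+, which is handy for input validation. A hedged sketch of the path-probing idea with a hypothetical body; the `RECIPE_FILE_NAME` value below is an assumption, the real constant lives in this module:

```python
from pathlib import Path

RECIPE_FILE_NAME = "recipe.yaml"  # assumed value for illustration only

def recipe_path_for(model_path: str | Path) -> str | None:
    # PEP 604 unions are real objects: isinstance accepts them on 3.10+.
    if not isinstance(model_path, str | Path):
        raise TypeError(f"expected str or Path, got {type(model_path).__name__}")
    candidate = Path(model_path) / RECIPE_FILE_NAME
    return str(candidate) if candidate.is_file() else None

print(recipe_path_for("/tmp/does-not-exist"))  # None
```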

src/llmcompressor/transformers/utils/preprocessing_functions.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -7,7 +7,7 @@
 popular training datasets.
 """
 
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
 
 from compressed_tensors.registry import RegistryMixin
 
@@ -20,7 +20,7 @@ class PreprocessingFunctionRegistry(RegistryMixin):
 
 
 @PreprocessingFunctionRegistry.register()
-def custom_evolved_codealpaca_dataset(self: "TextGenerationDataset", data: Dict):
+def custom_evolved_codealpaca_dataset(self: "TextGenerationDataset", data: dict):
     PROMPT_DICT = """[Instruction]:\n{instruction}\n\n[Response]:"""
     data["prompt"] = PROMPT_DICT.format_map(data)
     data["text"] = data["prompt"] + data["output"]
```
