Skip to content

Commit aa50449

Browse files
Modernize type hints in logarithmic equalization modifier (#2121)
SUMMARY: Part of #1927 - Updated type hints to Python 3.10+ built-in generics - Replaced List[] with list[] - No functional changes TEST PLAN: - Ran `make quality` (ruff format and lint checks) - Verified no functional code changes were introduced --------- Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
1 parent 129c793 commit aa50449

File tree

2 files changed

+14
-16
lines changed
  • src/llmcompressor/modifiers

2 files changed

+14
-16
lines changed

src/llmcompressor/modifiers/logarithmic_equalization/base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import List
2-
31
import torch
42
from torch.nn import Module
53

@@ -52,8 +50,8 @@ class LogarithmicEqualizationModifier(SmoothQuantModifier):
5250
"""
5351

5452
def _calculate_smoothing_scales(
55-
self, balance_layers: List[Module], activation_scales: torch.Tensor
56-
) -> List[float]:
53+
self, balance_layers: list[Module], activation_scales: torch.Tensor
54+
) -> torch.Tensor:
5755
"""
5856
Calculate how much smoothing to apply to each channel based on the dynamic
5957
range of the activations and the following weights.

src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import Callable, Dict, List, Optional, Tuple, Union
2+
from typing import Callable
33

44
import torch
55
from compressed_tensors.utils import align_module_device, match_modules_set
@@ -51,7 +51,7 @@ class SmoothQuantMapping:
5151

5252
smooth_name: str
5353
smooth_layer: Module
54-
balance_layers: List[Module]
54+
balance_layers: list[Module]
5555

5656

5757
class SmoothQuantModifier(Modifier):
@@ -96,15 +96,15 @@ class SmoothQuantModifier(Modifier):
9696
"""
9797

9898
smoothing_strength: float = 0.5
99-
mappings: Optional[List[Union[Tuple, List]]] = None
100-
ignore: Optional[List[str]] = None
101-
num_calibration_steps: Optional[int] = None
102-
calibration_function: Optional[Callable] = None
99+
mappings: list[tuple | list] | None = None
100+
ignore: list[str] | None = None
101+
num_calibration_steps: int | None = None
102+
calibration_function: Callable | None = None
103103

104-
resolved_mappings_: Optional[List[SmoothQuantMapping]] = Field(
104+
resolved_mappings_: list[SmoothQuantMapping] | None = Field(
105105
default=None, repr=False
106106
)
107-
scales_: Optional[Dict] = Field(default=None, repr=False)
107+
scales_: dict | None = Field(default=None, repr=False)
108108

109109
def on_initialize(self, state: State, **kwargs) -> bool:
110110
"""
@@ -178,7 +178,7 @@ def on_finalize(self, state: State, **kwargs) -> bool:
178178
def _infer_mappings_from_model(
179179
self,
180180
model: Module,
181-
) -> List[Tuple]:
181+
) -> list[tuple]:
182182
if self.mappings is not None:
183183
return self.mappings
184184

@@ -188,7 +188,7 @@ def _infer_mappings_from_model(
188188
)
189189

190190
@handle_mapping_resolution_errors
191-
def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
191+
def _resolve_mappings(self, model: Module) -> list[SmoothQuantMapping]:
192192
"""
193193
Transforms the list of activations to smooth and their corresponding weights
194194
into SmoothQuantMapping objects, resolving regular expressions.
@@ -309,8 +309,8 @@ def smooth(module):
309309
del self.scales_[mapping.smooth_name]
310310

311311
def _calculate_smoothing_scales(
312-
self, balance_layers: List[Module], activation_scales: torch.Tensor
313-
) -> List[float]:
312+
self, balance_layers: list[Module], activation_scales: torch.Tensor
313+
) -> torch.Tensor:
314314
"""
315315
Calculate how much smoothing to apply to each channel based on the dynamic
316316
range of the activation and the following weights

0 commit comments

Comments
 (0)