diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 59678d5a..6c6a43cb 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -47,6 +47,10 @@ jobs:
           commands: |
             echo "::add-matcher::.github/workflows/matchers/pylint.json"
             tox -e lint
+        - name: "mypy"
+          commands: |
+            echo "::add-matcher::.github/workflows/matchers/mypy.json"
+            tox -e mypy
 
     steps:
       - name: "Harden Runner"
diff --git a/.github/workflows/matchers/mypy.json b/.github/workflows/matchers/mypy.json
new file mode 100644
index 00000000..f048fce5
--- /dev/null
+++ b/.github/workflows/matchers/mypy.json
@@ -0,0 +1,16 @@
+{
+  "problemMatcher": [
+    {
+      "owner": "mypy",
+      "pattern": [
+        {
+          "regexp": "^(.+):(\\d+):\\s(error|warning):\\s(.+)$",
+          "file": 1,
+          "line": 2,
+          "severity": 3,
+          "message": 4
+        }
+      ]
+    }
+  ]
+}
diff --git a/fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py b/fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py
index 6efdca80..001ccb7b 100644
--- a/fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py
+++ b/fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py
@@ -14,7 +14,7 @@
 """Implement FMS adapter for INT8xINT8 checkpoints"""
 
 # Standard
-from typing import Mapping
+from typing import Mapping, MutableMapping
 
 # Third Party
 from fms.utils import serialization
@@ -47,7 +47,7 @@ def _int8_qparams_aiu(
 
 
 def _add_defaults_and_concat(
-    new_sd: dict[str, torch.Tensor],
+    new_sd: MutableMapping[str, torch.Tensor],
     modules_seen: set[str],
 ) -> None:
     """
diff --git a/fms_mo/calib.py b/fms_mo/calib.py
index cff4c67a..9a802894 100644
--- a/fms_mo/calib.py
+++ b/fms_mo/calib.py
@@ -482,7 +482,8 @@ def qmodel_calib(
         return model
 
     DPorDDPdevices = None
-    if "qmodel_prep" not in sys._getframe().f_back.f_code.co_name:
+    f_back = sys._getframe().f_back
+    if f_back and "qmodel_prep" not in f_back.f_code.co_name:
         model.to(currDev)
     qcfg["wasDPmodel"] = qcfg.get("wasDPmodel", isinstance(model, nn.DataParallel))
     qcfg["wasDDPmodel"] = qcfg.get(
diff --git a/fms_mo/custom_ext_kernels/utils.py b/fms_mo/custom_ext_kernels/utils.py
index b3c60c4d..21bb52a2 100644
--- a/fms_mo/custom_ext_kernels/utils.py
+++ b/fms_mo/custom_ext_kernels/utils.py
@@ -74,10 +74,10 @@
     # Third Party
     import torch.library as lib
 
-    reg_op = partial(lib.custom_op, mutates_args=())
+    reg_op = partial(lib.custom_op, mutates_args=())  # type: ignore[attr-defined]
     reg_op_func = lib.define  # NOTE this is func, not decorator
-    kernel_impl = lib.register_kernel
-    reg_fake = lib.register_fake
+    kernel_impl = lib.register_kernel  # type: ignore[attr-defined]
+    reg_fake = lib.register_fake  # type: ignore[attr-defined]
 else:
     raise RuntimeError("Custom Op registration only works for >PT2.1")
 
diff --git a/fms_mo/quant/ptq.py b/fms_mo/quant/ptq.py
index 2e192a25..c7d830b8 100644
--- a/fms_mo/quant/ptq.py
+++ b/fms_mo/quant/ptq.py
@@ -2631,8 +2631,10 @@ def reset_bn(module: nn.BatchNorm2d):
     Function not currently used.
""" if module.track_running_stats: - module.running_mean.zero_() - module.running_var.fill_(1 - module.eps) + if running_mean := module.running_mean: + running_mean.zero_() + if running_var := module.running_var: + running_var.fill_(1 - module.eps) # we do not reset numer of tracked batches here if module.affine: nn.init.ones_(module.weight) @@ -2651,7 +2653,7 @@ def reset_bn(module: nn.BatchNorm2d): bn_affine = True # FrozenBN doesn't have .affine property except: BNofInteret = (nn.BatchNorm2d, nn.BatchNorm1d) - AbsorbLayers = (nn.Conv2d, nn.Linear) + AbsorbLayers = (nn.Conv2d, nn.Linear) # type: ignore[assignment] def search_fold_and_remove_bn(model, mod_folded): diff --git a/fms_mo/quant/quantizers.py b/fms_mo/quant/quantizers.py index c97dbfa8..c509ea01 100644 --- a/fms_mo/quant/quantizers.py +++ b/fms_mo/quant/quantizers.py @@ -23,6 +23,7 @@ """ # pylint: disable=too-many-return-statements +# mypy: disable-error-code="assignment" # Standard from collections.abc import Mapping @@ -3895,7 +3896,8 @@ def forward(self, x: torch.Tensor): self.delta = torch.nn.Parameter(delta) else: delta, zero_point = self.init_quantization_scale(x, self.channel_wise) - self.delta.fill_(delta) + if self_data := self.delta: + self_data.fill_(delta) self.zero_point.fill_(zero_point) self.inited = True @@ -3906,7 +3908,8 @@ def forward(self, x: torch.Tensor): return x_dequant def init_quantization_scale(self, x: torch.Tensor, channel_wise: bool = False): - delta, zero_point = None, None + # delta, zero_point = 1.0, 0 + # init seems unnecessary, comment out to avoid None induced type chk err if channel_wise: x_clone = x.clone().detach() n_channels = x_clone.shape[0] @@ -3935,7 +3938,7 @@ def init_quantization_scale(self, x: torch.Tensor, channel_wise: bool = False): x_min = x_min * (self.n_bits + 2) / 8 x_max = x_max * (self.n_bits + 2) / 8 - x_absmax = max(abs(x_min), x_max) + x_absmax = max(abs(x_min), x_max) # type: ignore [call-overload] if self.sym: x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax @@ -3960,7 +3963,7 @@ def init_quantization_scale(self, x: torch.Tensor, channel_wise: bool = False): if score < best_score: best_score = score delta = (new_max - new_min) / (2**self.n_bits - 1) - zero_point = (-new_min / delta).round() + zero_point = (-new_min / delta).round() # type: ignore[union-attr] else: raise NotImplementedError @@ -4035,8 +4038,8 @@ def __init__( self.reset_ReSig_param(multimodal) self.beta = 2 / 3 - self.Wshape = None - self.reshape2 = None + self.Wshape: list[int] = list() + self.reshape2: list[Any] = list() def forward(self, x): if self.useSAWB: @@ -4583,7 +4586,7 @@ def transformers_prepare_input( if isinstance(data, Mapping): return type(data)( {k: transformers_prepare_input(v, dev=dev) for k, v in data.items()} - ) + ) # type: ignore[call-arg] if isinstance(data, (tuple, list)): return type(data)(transformers_prepare_input(v, dev=dev) for v in data) if isinstance(data, torch.Tensor): @@ -5389,7 +5392,7 @@ def __init__( if "e4m3" in q_mode: self.float8_dtype = torch.float8_e4m3fn elif "e5m2" in q_mode: - self.float8_dtype = torch.float8_e5m2G + self.float8_dtype = torch.float8_e5m2 else: raise ValueError("FP8 only supports e4m3 and e5m2") self.emulate = emulate @@ -5451,7 +5454,7 @@ def custom_fp8_quantizer( mantissa_bits: int = 3, use_subnormal: bool = False, scale_to_max: bool = False, -) -> torch.Tensor: +): """Convert tensor tensor to FP8 format, remanining in decimal form (no binary conversion) and using some clever manipulation to round each tensor values to the closest 
     representable FP8 value.
diff --git a/fms_mo/utils/qconfig_utils.py b/fms_mo/utils/qconfig_utils.py
index caafec16..81fc1faa 100644
--- a/fms_mo/utils/qconfig_utils.py
+++ b/fms_mo/utils/qconfig_utils.py
@@ -15,7 +15,7 @@
 
 # Standard
 from pathlib import Path
-from typing import Any
+from typing import Any, Dict
 import json
 import logging
 import os
@@ -149,7 +149,7 @@ def qconfig_init(recipe: str = None, args: Any = None):
         otherwise use constantLR as default
 
     """
-    qcfg = {}
+    qcfg: Dict[str, Any] = {}
     # 1. create a dict with default values
     qcfg["mapping"] = {
         nn.Conv2d: {"from": nn.Conv2d, "to": QConv2d, "otherwise": QConv2d},
diff --git a/fms_mo/utils/torchscript_utils.py b/fms_mo/utils/torchscript_utils.py
index 39025b8d..0be4083c 100644
--- a/fms_mo/utils/torchscript_utils.py
+++ b/fms_mo/utils/torchscript_utils.py
@@ -55,7 +55,8 @@ def parse_operation(op_str: str):
     operands = op_str[
         last_open_parenthesis_index + 1 : last_close_parenthesis_index
     ].split(",")
-    operands = [operand.strip() for operand in operands] if operands != [""] else None
+    # pylint: disable=line-too-long
+    operands = [operand.strip() for operand in operands] if operands != [""] else None  # type: ignore[assignment]
 
     return operator, operands
 
@@ -178,9 +179,14 @@ def __init__(self, node_input, dictionary_of_nodes: dict):
             )
             operator, operands = parse_operation(op_str)
             if "aten::_conv" in op_str:
-                self.ch_in = list(native_torchscript_node.inputs())[0].type().sizes()
-                # NOTE: Needed for finding shortcut convolutions later
-                self.ch_out = list(native_torchscript_node.outputs())[0].type().sizes()
+                if native_torchscript_node:
+                    self.ch_in = (
+                        list(native_torchscript_node.inputs())[0].type().sizes()
+                    )
+                    # NOTE: Needed for finding shortcut convolutions later
+                    self.ch_out = (
+                        list(native_torchscript_node.outputs())[0].type().sizes()
+                    )
         else:
             node_def = node_input_repr
             op_str, operator, operands = None, None, None
@@ -200,31 +206,34 @@ def __init__(self, node_input, dictionary_of_nodes: dict):
             working_str = node_input_repr[start_index:end_index]
             start_index = end_index + 2
 
-            node_instance.name, node_instance.obj = working_str.split(" : ")
-            node_instance.name = node_instance.name.strip()
+            # pylint: disable=line-too-long
+            node_instance.name, node_instance.obj = working_str.split(" : ")  # type: ignore[attr-defined]
+            node_instance.name = node_instance.name.strip()  # type: ignore[attr-defined]
             if native_torchscript_outputs:
-                if node_instance.name not in native_torchscript_outputs:
+                # pylint: disable=line-too-long
+                if node_instance.name not in native_torchscript_outputs:  # type: ignore[attr-defined]
+                    # pylint: disable=line-too-long
                     logger.error(
-                        f"Node def {node_instance.name} not in nativeTSoutputs "
+                        f"Node def {node_instance.name} not in nativeTSoutputs "  # type: ignore[attr-defined]
                         f"{native_torchscript_outputs}"
                     )
-            node_instance.Op = op_str
+            node_instance.Op = op_str  # type: ignore[attr-defined]
             if node_def_in_one_line > 1:
-                node_instance.unpackIdx = node_index
+                node_instance.unpackIdx = node_index  # type: ignore[attr-defined]
             if line_number:
-                node_instance.lineno = line_number
-            node_instance.operator = operator
+                node_instance.lineno = line_number  # type: ignore[attr-defined]
+            node_instance.operator = operator  # type: ignore[attr-defined]
             # This is the name of parents, not the pointer to the parent nodes
-            node_instance.parents = operands
-            node_instance.parents_ptr = []
-            node_instance.scope = scope_repr
-            node_instance.modname = module_name
-            node_instance.children = []
-            node_instance.children_ptr = []
-            node_instance.TSparents = native_torchscript_parents
-            node_instance.TSoutputs = native_torchscript_outputs
+            node_instance.parents = operands  # type: ignore[attr-defined]
+            node_instance.parents_ptr = []  # type: ignore[attr-defined]
+            node_instance.scope = scope_repr  # type: ignore[attr-defined]
+            node_instance.modname = module_name  # type: ignore[attr-defined]
+            node_instance.children = []  # type: ignore[attr-defined]
+            node_instance.children_ptr = []  # type: ignore[attr-defined]
+            node_instance.TSparents = native_torchscript_parents  # type: ignore[attr-defined]
+            node_instance.TSoutputs = native_torchscript_outputs  # type: ignore[attr-defined]
             # graph.dictionary_of_nodes will keep a record of all the nodes
-            dictionary_of_nodes[node_instance.name] = node_instance
+            dictionary_of_nodes[node_instance.name] = node_instance  # type: ignore[attr-defined]
 
     def __repr__(self):
         return f"{self.name} "
diff --git a/pyproject.toml b/pyproject.toml
index a52bc449..6ad0a81f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -132,7 +132,7 @@ known-local-folder=["fms_mo","tests"]
 [tool.mypy]
 mypy_path = [""]
 packages = ["fms_mo", "tests"]
-disable_error_code = []
+disable_error_code = ["import-not-found", "import-untyped", "no-any-return"]
 # TODO: tighten MyPy checks by enabling these checks over time.
 check_untyped_defs = false
 disallow_incomplete_defs = false
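
Note on the new problem matcher: the regexp in .github/workflows/matchers/mypy.json splits a standard mypy diagnostic into file, line, severity, and message so GitHub can annotate the offending source line. A minimal sketch of that mapping, using an illustrative (not real) diagnostic line:

import re

# Same pattern as .github/workflows/matchers/mypy.json; groups 1-4 feed the
# "file", "line", "severity", and "message" fields of the problem matcher.
MYPY_PATTERN = re.compile(r"^(.+):(\d+):\s(error|warning):\s(.+)$")

# Hypothetical mypy output line, for illustration only.
sample = 'fms_mo/calib.py:485: error: Item "None" has no attribute "f_code"  [union-attr]'

match = MYPY_PATTERN.match(sample)
if match:
    path, line, severity, message = match.groups()
    print(path, line, severity, message, sep=" | ")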
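
Note on the i8i8 adapter signature: annotating new_sd as MutableMapping[str, torch.Tensor] instead of dict[str, torch.Tensor] lets callers pass any mutable mapping of tensors; with the stricter dict annotation, mypy rejects arguments that are only known to be a MutableMapping. A small sketch of that behavior, with illustrative names rather than the adapter's real API:

from typing import MutableMapping

import torch


def add_default_entries(new_sd: MutableMapping[str, torch.Tensor]) -> None:
    # In-place mutation only needs the MutableMapping interface, not dict itself.
    new_sd.setdefault("extra.zero_shift", torch.zeros(1))


state: MutableMapping[str, torch.Tensor] = {"w_int": torch.zeros(4, 4)}
add_default_entries(state)  # OK; a dict[str, Tensor] parameter would make mypy flag this call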
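
Note on the Optional handling in fms_mo/calib.py: sys._getframe().f_back is typed as FrameType | None, so mypy flags attribute access on it until the frame is narrowed. A minimal sketch of the narrowing pattern, standard library only:

import sys
from types import FrameType
from typing import Optional


def caller_name() -> str:
    # f_back is None when there is no calling frame, so guard before touching f_code.
    f_back: Optional[FrameType] = sys._getframe().f_back
    if f_back is None:
        return "<no caller>"
    return f_back.f_code.co_name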