From a60a4b87e879a4e41b6bf4e5f27020d27f1491be Mon Sep 17 00:00:00 2001
From: cliu-us
Date: Tue, 8 Jul 2025 14:59:54 +0000
Subject: [PATCH 1/3] add a new hook for checking post-quant in/out

Signed-off-by: cliu-us
---
 fms_mo/quant/ptq.py | 117 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 116 insertions(+), 1 deletion(-)

diff --git a/fms_mo/quant/ptq.py b/fms_mo/quant/ptq.py
index 7801284b..de2c5729 100644
--- a/fms_mo/quant/ptq.py
+++ b/fms_mo/quant/ptq.py
@@ -42,6 +42,7 @@
 # Local
 from fms_mo.modules import QBmm, QLinear
 from fms_mo.modules.conv import QConv2dPTQv2
+from fms_mo.modules.linear import LinearFPxAcc, QLinearINT8Deploy
 from fms_mo.quant.quantizers import (
     AdaRoundQuantizer,
     Qdynamic,
@@ -481,8 +482,122 @@ def __call__(self, mod, inputs, *args, **_kwargs):
         assert not self.stop_after_rec
 
 
-# this hook is meant for ptq_loss_func == 'fisher_diag' and to temp hold the "Q_out" of the module
+class HookRecPostQuantInOut(torch.nn.Module):
+    """A simplified hook for checking post-quantized input/output, e.g. within +/-127 for INT8."""
+
+    def __init__(self, cache_dict=None, mod_name=None):
+        super().__init__()
+        self.cache_dict = cache_dict if cache_dict is not None else {}
+        self.mod_name = mod_name
+        name_split = mod_name.split(".")
+        self.lay_idx = int(name_split[3])
+        self.lay_key = name_split[6]
+
+        self.cache_dev = "cpu"
+        # prepare empty dict for later use
+        self.cache_dict[mod_name] = {}
+        self.fwd_mapping = {
+            LinearFPxAcc: self.call_func_for_fpxacc,
+            QLinear: self.call_func_for_qlinear,
+            QLinearINT8Deploy: self.call_func_for_qlinear_int,
+            torch.nn.Linear: self.call_func_for_nnlinear,
+        }
+
+    def call_func_for_fpxacc(self, mod, inputs, outputs, **_kwargs):
+        raise NotImplementedError
+
+    def call_func_for_qlinear(self, mod, inputs, outputs, **_kwargs):
+        lay_idx = self.lay_idx
+        lay_key = self.lay_key
+        mod_name = self.mod_name
+        cache_dict = self.cache_dict
+
+        act_max = inputs[0].abs().amax(dim=[d for d in range(len(inputs[0].shape) - 1)])
+        # mod.smoothq_act_scale
+        w_max = mod.weight.abs().max(dim=0, keepdim=True)[0].clamp(min=1e-5)
+        is_smq_layer = not torch.all(act_max == 0).item()
+        # smoothQ scale = smoothq_act_scale**alpha / weight_scale**(1.0 - alpha)
+        # smoothq_scale = mod.get_smoothq_scale(inputs[0])
+        smoothq_scale = getattr(mod, "smq_scale", 1.0)
+        # "smq_scale" is only available in QLinearINT8Deploy
+
+        with torch.no_grad():
+            smoothed_inp = inputs[0] / smoothq_scale
+            smoothed_w = mod.weight * smoothq_scale
+
+            # this assumes a pertokenmax quantizer; NOTE quant scale is calculated after smoothing
+            absmax = smoothed_inp.abs().max(dim=-1, keepdim=True)[0]
+            qa_scale = absmax.clamp(min=1e-5) / 127
+            qinput = torch.round(smoothed_inp / qa_scale).clamp(-127, 127)
+            # should clamp to -128?
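+            # note: with scale = absmax/127, symmetric schemes usually clamp to
+            # [-127, 127] so that -x and +x round to the same magnitude; -128 would
+            # use the full int8 range but makes the grid slightly asymmetric, so
+            # either choice works if applied consistently at dequant time.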
+            cva = mod.quantize_feature.clip_val
+            if mod.qa_mode == "pertokenmax":
+                # dequantize=False is not implemented for pertokenmax yet, so do it manually
+                qa_scale = cva.clamp(min=1e-5).div(127)
+                qinput = smoothed_inp.div(qa_scale).round()
+            else:
+                mod.quantize_feature.dequantize = False
+                qinput = mod.quantize_feature(smoothed_inp)
+                mod.quantize_feature.dequantize = True
+
+            # also record quantized, smoothed W in INT8, calc both maxperCh and SAWBperCh
+            cvw = mod.quantize_weight.clip_val
+            scale_w = cvw / 127
+            mod.quantize_weight.dequantize = False
+            qw = mod.quantize_weight(smoothed_w)
+            mod.quantize_weight.dequantize = True
+
+            # inputs is a tuple; QLinear only has 1 valid input
+            cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev)
+            cache_dict[mod_name]["cva"] = cva.to(self.cache_dev)
+            cache_dict[mod_name]["cvw"] = cvw.to(self.cache_dev)
+            cache_dict[mod_name]["smoothed_input"] = smoothed_inp.to(self.cache_dev)
+            cache_dict[mod_name]["smoothed_weight"] = smoothed_w.to(self.cache_dev)
+            cache_dict[mod_name]["qinput"] = qinput.to(self.cache_dev)
+            # NOTE stored in INT8; multiply by scales if dequantization is needed
+            cache_dict[mod_name]["qweight"] = qw.to(self.cache_dev)
+            # torch.round(smoothed_w.T/scale_w).clamp(-127, 127).to(self.cache_dev)
+            # cache_dict[mod_name]["qoutput"] = outputs.to(self.cache_dev)
+
+    def call_func_for_qlinear_int(self, mod, inputs, outputs, **_kwargs):
+        smoothq_scale = getattr(mod, "smq_scale", 1.0)
+        mod_name = self.mod_name
+        cache_dict = self.cache_dict
+        with torch.no_grad():
+            if mod.useDynMaxQfunc in [-1, -2]:
+                qinput = mod.qa_dynamic_max_qfunc(inputs[0])
+            elif mod.use_fake_zero_shift:
+                qinput = mod.qa_dyn_max_fake_zero_shift(inputs[0])
+            elif mod.usePTnativeQfunc:
+                qinput = mod.qa_raw_qfunc(inputs[0])
+            else:
+                qinput = mod.qa_fmo_mo_qfunc(inputs[0])
+
+            # inputs is a tuple; QLinear only has 1 valid input
+            cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev)
+            cache_dict[mod_name]["cva"] = mod.cvs[0].to(self.cache_dev)
+            cache_dict[mod_name]["cvw"] = mod.cvs[2].to(self.cache_dev)
+            cache_dict[mod_name]["qinput"] = qinput.to(self.cache_dev)
+            cache_dict[mod_name]["qweight"] = mod.weight.to(self.cache_dev)
+
+    def call_func_for_nnlinear(self, mod, inputs, outputs, **_kwargs):
+        mod_name = self.mod_name
+        cache_dict = self.cache_dict
+        cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev)
+        cache_dict[mod_name]["weight"] = mod.weight.to(self.cache_dev)
+
+    def __call__(self, mod, inputs, outputs, **_kwargs):
+        self.fwd_mapping[type(mod)](mod, inputs, outputs, **_kwargs)
+
+
 class PTQHookRecQOut(nn.Module):
+    """This hook is for ptq_loss_func == 'fisher_diag' and will temporarily hold the "Q_out" of the
+    module"""
+
     def __init__(self, qcfg):
         super().__init__()
         self.qcfg = qcfg
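For context, a minimal usage sketch for the new hook (hypothetical, not part of the patch): `HookRecPostQuantInOut.__call__` takes `(mod, inputs, outputs)`, i.e. it is a standard forward hook. The sketch assumes `model` and `sample_inputs` exist and that module names have the dotted depth `__init__` parses (`name_split[3]` must be the layer index).

```python
# Hypothetical usage sketch (not part of the patch). `model` and
# `sample_inputs` are placeholders assumed to be defined by the caller.
import torch
from fms_mo.modules import QLinear
from fms_mo.modules.linear import QLinearINT8Deploy
from fms_mo.quant.ptq import HookRecPostQuantInOut

cache, handles = {}, []
for name, mod in model.named_modules():
    # __call__ dispatches on the exact type, so only register on classes
    # present in fwd_mapping
    if type(mod) in (torch.nn.Linear, QLinear, QLinearINT8Deploy):
        hook = HookRecPostQuantInOut(cache_dict=cache, mod_name=name)
        handles.append(mod.register_forward_hook(hook))

with torch.no_grad():
    model(**sample_inputs)  # one calibration batch populates the cache

for h in handles:
    h.remove()

# each cache[name] entry now holds post-quant tensors on CPU, e.g. "qinput"
```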
From 04e4cb12e87cd6f88fc66fb43647496ef846220c Mon Sep 17 00:00:00 2001
From: cliu-us
Date: Thu, 10 Jul 2025 15:36:36 +0000
Subject: [PATCH 2/3] qkvsync bug fix: graph breaks can induce qkv sibling
 list errors

Only partial names will be found after a graph break, which causes problems.

Signed-off-by: cliu-us
---
 fms_mo/fx/dynamo_utils.py | 36 +++++++++++++++++++++++++++++++++++-
 1 file changed, 35 insertions(+), 1 deletion(-)

diff --git a/fms_mo/fx/dynamo_utils.py b/fms_mo/fx/dynamo_utils.py
index 62967ec2..23e01d23 100644
--- a/fms_mo/fx/dynamo_utils.py
+++ b/fms_mo/fx/dynamo_utils.py
@@ -1010,6 +1010,8 @@ def cus_backend_model_analyzer(
     }
     prefix = None
     if qcfg["N_backend_called"] > 1:  # subgraph found, see Note 2
+        # TODO this approach only works for FX IR (call_module nodes are not functionalized);
+        # it needs an update for Aten IR cases
         for n in gm_fx.graph.nodes:
             if n.op == "call_module":
                 mod = gm_fx.get_submodule(n.target)
@@ -1220,6 +1222,38 @@ def call_seq_hook(mod, *_args, **_kwargs):
 
     # ------ model analysis is finished, but there are a few remaining things to be done
     # a) qkvsync dict update from "module names" to "module instances"
+    # NOTE when a graph break happens, qkvsync() may only find partial QKV names. For example,
+    # instead of ["model.layers.0.self_attn.q_proj", ..., "model.layers.1.self_attn.q_proj", ...],
+    # it may report ["self_attn.q_proj", "self_attn.k_proj", ...]
+    # Therefore, qcfg["qkvsync_my_1st_sibling"] will be much shorter than expected, and its keys
+    # won't exist in the full list of Linear names (all_linears below).
+    all_linears = set(
+        [n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
+    )
+
+    if any(k not in all_linears for k in qcfg["qkvsync_my_1st_sibling"]):
+        # qcfg["qkvsync_my_1st_sibling"] is a dict like {q: q, k: q, v: q, ...}; here we need a
+        # simpler dict like {q: [q, k, v], gate: [gate, up]}
+        lut_all_siblings = {}
+        for me, sib_1st in qcfg["qkvsync_my_1st_sibling"].items():
+            if sib_1st not in lut_all_siblings:
+                lut_all_siblings[sib_1st] = [sib_1st]
+            elif me not in lut_all_siblings[sib_1st]:
+                lut_all_siblings[sib_1st].append(me)
+
+        full_sib_list = {}
+        for me in lut_all_siblings:
+            partial_matches = [lin for lin in all_linears if me in lin]
+            all_sibs = lut_all_siblings[me]
+            # here lin is full_name, me and all_sibs are partial
+            for lin in partial_matches:
+                prefix = lin[: lin.index(me)]
+                for sib in all_sibs:
+                    full_sib_list[prefix + sib] = prefix + me
+                    all_linears.remove(prefix + sib)
+        # all_linears will still have down_proj, out_proj, lm_head, and maybe others
+        qcfg["qkvsync_my_1st_sibling"] = full_sib_list
+
     updated_dict = {
         model.get_submodule(mod): model.get_submodule(sib)
         for mod, sib in qcfg["qkvsync_my_1st_sibling"].items()
@@ -1303,7 +1337,7 @@ def qbmm_auto_check(_mod, *_args, **_kwargs):
     if qcfg["N_backend_called"] > 1:
         logger.warning(
             f"Found {qcfg['N_backend_called']} graph breaks during Dynamo tracing!! \n"
-            f"First/Last layer, which usually needs to stay unquantized, cannot be identified"
+            f"First/Last layer, which usually needs to stay unquantized, may not be identified"
            f" correctly now. Please double-check layers being skipped:\n"
             f"{qcfg['qskip_layer_name']}\n NOTE: Users can control layer selection by adding layer"
             f"names to:\n"
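The repair logic above is easiest to see on a toy example (a standalone sketch; the module names are made up and only approximate the real qcfg contents):

```python
# Toy walk-through of the sibling-list repair: partial names left behind by a
# graph break are expanded to full names by prefix matching.
all_linears = {
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.1.self_attn.q_proj",
    "model.layers.1.self_attn.k_proj",
    "model.layers.1.self_attn.v_proj",
    "model.layers.0.mlp.down_proj",
}
# what a graph break may leave behind: partial names, {me: my_1st_sibling}
qkvsync_my_1st_sibling = {
    "self_attn.q_proj": "self_attn.q_proj",
    "self_attn.k_proj": "self_attn.q_proj",
    "self_attn.v_proj": "self_attn.q_proj",
}

# invert into {1st_sibling: [all siblings]}
lut_all_siblings = {}
for me, sib_1st in qkvsync_my_1st_sibling.items():
    if sib_1st not in lut_all_siblings:
        lut_all_siblings[sib_1st] = [sib_1st]
    elif me not in lut_all_siblings[sib_1st]:
        lut_all_siblings[sib_1st].append(me)

# expand each partial name to every full name that contains it
full_sib_list = {}
for me, all_sibs in lut_all_siblings.items():
    for lin in [n for n in all_linears if me in n]:
        prefix = lin[: lin.index(me)]
        for sib in all_sibs:
            full_sib_list[prefix + sib] = prefix + me

assert full_sib_list["model.layers.1.self_attn.k_proj"] == "model.layers.1.self_attn.q_proj"
assert "model.layers.0.mlp.down_proj" not in full_sib_list
```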
From d6fd5538aa66dd6fb4d26662f2a67a80cc7fd0e9 Mon Sep 17 00:00:00 2001
From: cliu-us
Date: Thu, 10 Jul 2025 16:46:36 +0000
Subject: [PATCH 3/3] linting fix

Signed-off-by: cliu-us
---
 fms_mo/fx/dynamo_utils.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/fms_mo/fx/dynamo_utils.py b/fms_mo/fx/dynamo_utils.py
index 23e01d23..247eb72b 100644
--- a/fms_mo/fx/dynamo_utils.py
+++ b/fms_mo/fx/dynamo_utils.py
@@ -1228,7 +1228,7 @@ def call_seq_hook(mod, *_args, **_kwargs):
     # Therefore, qcfg["qkvsync_my_1st_sibling"] will be much shorter than expected, and its keys
     # won't exist in the full list of Linear names (all_linears below).
     all_linears = set(
-        [n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
+        n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)
     )
 
     if any(k not in all_linears for k in qcfg["qkvsync_my_1st_sibling"]):
@@ -1242,9 +1242,8 @@ def call_seq_hook(mod, *_args, **_kwargs):
             lut_all_siblings[sib_1st].append(me)
 
         full_sib_list = {}
-        for me in lut_all_siblings:
+        for me, all_sibs in lut_all_siblings.items():
             partial_matches = [lin for lin in all_linears if me in lin]
-            all_sibs = lut_all_siblings[me]
             # here lin is full_name, me and all_sibs are partial
             for lin in partial_matches:
                 prefix = lin[: lin.index(me)]
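With all three patches applied, one plausible way to consume the recorded cache is a simple range check, in line with the hook's stated purpose (hypothetical consumer code continuing the usage sketch above; key names follow PATCH 1/3):

```python
# Quick sanity check over the cache populated by HookRecPostQuantInOut:
# post-quant activations and weights should stay within the symmetric INT8
# range. `cache` is the dict passed to the hooks during calibration.
for name, rec in cache.items():
    for key in ("qinput", "qweight"):
        if key in rec:
            t = rec[key]
            assert t.min() >= -127 and t.max() <= 127, (
                f"{name}.{key} falls outside the symmetric INT8 range"
            )
```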