35 changes: 34 additions & 1 deletion fms_mo/fx/dynamo_utils.py
@@ -1010,6 +1010,8 @@ def cus_backend_model_analyzer(
    }
    prefix = None
    if qcfg["N_backend_called"] > 1:  # subgraph found, see Note 2
        # TODO: this approach only works for FX IR (call_module nodes are not functionalized);
        #       an update is needed for Aten IR cases
        for n in gm_fx.graph.nodes:
            if n.op == "call_module":
                mod = gm_fx.get_submodule(n.target)
@@ -1220,6 +1222,37 @@ def call_seq_hook(mod, *_args, **_kwargs):
    # ------ model analysis is finished, but there are a few remaining things to be done

    # a) qkvsync dict: update from "module names" to "module instances"
    # NOTE when a graph break happens, qkvsync() may only find partial QKV names. For example,
    # instead of ["model.layers.0.self_attn.q_proj", ..., "model.layers.1.self_attn.q_proj", ...]
    # it may report ["self_attn.q_proj", "self_attn.k_proj", ...]. Therefore,
    # qcfg["qkvsync_my_1st_sibling"] will be much shorter, and its keys won't exist in the
    # full-name list (like all_linears below).
    all_linears = set(
        n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)
    )
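    # e.g. all_linears is a set of full names such as
    #   {"model.layers.0.self_attn.q_proj", "model.layers.0.self_attn.k_proj", ...}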

    if any(k not in all_linears for k in qcfg["qkvsync_my_1st_sibling"]):
        # qcfg["qkvsync_my_1st_sibling"] is like {q: q, k: q, v: q, ...}; here we need a simpler
        # dict like {q: [q, k, v], gate: [gate, up]}
        lut_all_siblings = {}
        for me, sib_1st in qcfg["qkvsync_my_1st_sibling"].items():
            if sib_1st not in lut_all_siblings:
                lut_all_siblings[sib_1st] = [sib_1st]
            if me not in lut_all_siblings[sib_1st]:
                lut_all_siblings[sib_1st].append(me)

        full_sib_list = {}
        for me, all_sibs in lut_all_siblings.items():
            # here lin is a full name; me and all_sibs are partial names
            partial_matches = [lin for lin in all_linears if me in lin]
            for lin in partial_matches:
                prefix = lin[: lin.index(me)]
                for sib in all_sibs:
                    full_sib_list[prefix + sib] = prefix + me
                    all_linears.discard(prefix + sib)
        # all_linears will still contain down_proj, out_proj, lm_head, and possibly others
        qcfg["qkvsync_my_1st_sibling"] = full_sib_list
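        # e.g. (names from the NOTE above): a partial dict such as
        #   {"self_attn.q_proj": "self_attn.q_proj", "self_attn.k_proj": "self_attn.q_proj"}
        # is expanded against all_linears into full names like
        #   {"model.layers.0.self_attn.q_proj": "model.layers.0.self_attn.q_proj",
        #    "model.layers.0.self_attn.k_proj": "model.layers.0.self_attn.q_proj", ...}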

    updated_dict = {
        model.get_submodule(mod): model.get_submodule(sib)
        for mod, sib in qcfg["qkvsync_my_1st_sibling"].items()
@@ -1303,7 +1336,7 @@ def qbmm_auto_check(_mod, *_args, **_kwargs):
    if qcfg["N_backend_called"] > 1:
        logger.warning(
            f"Found {qcfg['N_backend_called']} graph breaks during Dynamo tracing!!\n"
            f"First/Last layer, which usually needs to stay unquantized, cannot be identified"
            f"First/Last layer, which usually needs to stay unquantized, may not be identified"
            f" correctly now. Please double-check the layers being skipped:\n"
            f"{qcfg['qskip_layer_name']}\n NOTE: Users can control layer selection by adding layer"
            f" names to:\n"
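A minimal sketch of the manual override mentioned in the warning above (hypothetical layer
names; assumes qcfg["qskip_layer_name"] is a list of module names, as its usage above
suggests):

    # keep first/last layers unquantized when graph breaks defeat auto-detection
    qcfg["qskip_layer_name"] += ["model.embed_tokens", "lm_head"]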
113 changes: 112 additions & 1 deletion fms_mo/quant/ptq.py
@@ -42,6 +42,7 @@
# Local
from fms_mo.modules import QBmm, QLinear
from fms_mo.modules.conv import QConv2dPTQv2
from fms_mo.modules.linear import LinearFPxAcc, QLinearINT8Deploy
from fms_mo.quant.quantizers import (
    AdaRoundQuantizer,
    Qdynamic,
@@ -481,8 +482,118 @@ def __call__(self, mod, inputs, *args, **_kwargs):
        assert not self.stop_after_rec


# this hook is meant for ptq_loss_func == 'fisher_diag' and temporarily holds the "Q_out" of the module
class HookRecPostQuantInOut(torch.nn.Module):
    """Another simplified hook to check post-quantized input/output, e.g. within +-127 for INT8."""

    def __init__(self, cache_dict=None, mod_name=None):
        super().__init__()
        # the same dict may be shared across hooks; create a fresh one if none is given
        self.cache_dict = cache_dict if cache_dict is not None else {}
        self.mod_name = mod_name
        # NOTE hard-coded for a specific module-naming scheme (index 3 = layer idx, index 6 = key)
        name_split = mod_name.split(".")
        self.lay_idx = int(name_split[3])
        self.lay_key = name_split[6]

        self.cache_dev = "cpu"
        # prepare an empty dict for later use
        self.cache_dict[mod_name] = {}
        self.fwd_mapping = {
            LinearFPxAcc: self.call_func_for_fpxacc,
            QLinear: self.call_func_for_qlinear,
            QLinearINT8Deploy: self.call_func_for_qlinear_int,
            torch.nn.Linear: self.call_func_for_nnlinear,
        }

    def call_func_for_fpxacc(self, mod, inputs, outputs, **_kwargs):
        raise NotImplementedError

    def call_func_for_qlinear(self, mod, inputs, outputs, **_kwargs):
        lay_idx = self.lay_idx
        lay_key = self.lay_key
        mod_name = self.mod_name
        cache_dict = self.cache_dict

        # activation/weight stats, cf. mod.smoothq_act_scale
        act_max = inputs[0].abs().amax(dim=[d for d in range(len(inputs[0].shape) - 1)])
        w_max = mod.weight.abs().max(dim=0, keepdim=True)[0].clamp(min=1e-5)
        is_smq_layer = not torch.all(act_max == 0).item()
        # smoothQ scale = smoothq_act_scale**alpha / weight_scale**(1.0 - alpha)
        # smoothq_scale = mod.get_smoothq_scale(inputs[0])
        # "smq_scale" is only available on QLinearINT8Deploy; default to 1.0 (no smoothing)
        smoothq_scale = getattr(mod, "smq_scale", 1.0)

        with torch.no_grad():
            smoothed_inp = inputs[0] / smoothq_scale
            smoothed_w = mod.weight * smoothq_scale

            # reference computation assuming a pertokenmax quantizer (overwritten below),
            # NOTE calc quant scale after smoothing
            absmax = smoothed_inp.abs().max(dim=-1, keepdim=True)[0]
            qa_scale = absmax.clamp(min=1e-5) / 127
            qinput = torch.round(smoothed_inp / qa_scale).clamp(-127, 127)  # clamp to -128?
            cva = mod.quantize_feature.clip_val
            if mod.qa_mode == "pertokenmax":
                # dequantize=False isn't implemented for pertokenmax yet, so quantize manually
                qa_scale = cva.clamp(min=1e-5).div(127)
                qinput = smoothed_inp.div(qa_scale).round()
            else:
                mod.quantize_feature.dequantize = False
                qinput = mod.quantize_feature(smoothed_inp)
                mod.quantize_feature.dequantize = True

            # also record the quantized, smoothed W in INT8; calc both maxperCh and SAWBperCh
            cvw = mod.quantize_weight.clip_val
            scale_w = cvw / 127
            mod.quantize_weight.dequantize = False
            qw = mod.quantize_weight(smoothed_w)
            mod.quantize_weight.dequantize = True

        # inputs is a tuple; QLinear only has 1 valid input
        cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev)
        cache_dict[mod_name]["cva"] = cva.to(self.cache_dev)
        cache_dict[mod_name]["cvw"] = cvw.to(self.cache_dev)
        cache_dict[mod_name]["smoothed_input"] = smoothed_inp.to(self.cache_dev)
        cache_dict[mod_name]["smoothed_weight"] = smoothed_w.to(self.cache_dev)
        # NOTE qinput/qweight are INT8 levels; multiply by the scales if dequantization is needed
        cache_dict[mod_name]["qinput"] = qinput.to(self.cache_dev)
        cache_dict[mod_name]["qweight"] = qw.to(self.cache_dev)
        # torch.round(smoothed_w.T / scale_w).clamp(-127, 127).to(self.cache_dev)
        # cache_dict[mod_name]["qoutput"] = outputs.to(self.cache_dev)

    def call_func_for_qlinear_int(self, mod, inputs, outputs, **_kwargs):
        smoothq_scale = getattr(mod, "smq_scale", 1.0)
        mod_name = self.mod_name
        cache_dict = self.cache_dict
        with torch.no_grad():
            # pick the quantization function matching the module's deploy settings
            if mod.useDynMaxQfunc in [-1, -2]:
                qinput = mod.qa_dynamic_max_qfunc(inputs[0])
            elif mod.use_fake_zero_shift:
                qinput = mod.qa_dyn_max_fake_zero_shift(inputs[0])
            elif mod.usePTnativeQfunc:
                qinput = mod.qa_raw_qfunc(inputs[0])
            else:
                qinput = mod.qa_fmo_mo_qfunc(inputs[0])

        # inputs is a tuple; QLinear only has 1 valid input
        cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev)
        cache_dict[mod_name]["cva"] = mod.cvs[0].to(self.cache_dev)
        cache_dict[mod_name]["cvw"] = mod.cvs[2].to(self.cache_dev)
        cache_dict[mod_name]["qinput"] = qinput.to(self.cache_dev)
        cache_dict[mod_name]["qweight"] = mod.weight.to(self.cache_dev)

    def call_func_for_nnlinear(self, mod, inputs, outputs, **_kwargs):
        mod_name = self.mod_name
        cache_dict = self.cache_dict
        cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev)
        cache_dict[mod_name]["weight"] = mod.weight.to(self.cache_dev)

    def __call__(self, mod, inputs, outputs, **_kwargs):
        self.fwd_mapping[type(mod)](mod, inputs, outputs, **_kwargs)

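A minimal usage sketch of HookRecPostQuantInOut (hypothetical: `model` and `calib_batch` are
assumptions, module names must fit the hard-coded name_split[3]/[6] scheme above, and only
types present in fwd_mapping should be hooked):

    # register the hook on layers of interest, run one calibration batch, then clean up
    cached = {}
    handles = [
        mod.register_forward_hook(HookRecPostQuantInOut(cached, mod_name=name))
        for name, mod in model.named_modules()
        if type(mod) in (QLinear, QLinearINT8Deploy, torch.nn.Linear)
    ]
    model(**calib_batch)  # populates cached[name]["qinput"], cached[name]["qweight"], etc.
    for h in handles:
        h.remove()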

class PTQHookRecQOut(nn.Module):
    """This hook is for ptq_loss_func == 'fisher_diag' and will temporarily hold the "Q_out" of
    the module"""

    def __init__(self, qcfg):
        super().__init__()
        self.qcfg = qcfg