diff --git a/fms_mo/custom_ext_kernels/triton_kernels.py b/fms_mo/custom_ext_kernels/triton_kernels.py index beae5a2e..bcba22ca 100644 --- a/fms_mo/custom_ext_kernels/triton_kernels.py +++ b/fms_mo/custom_ext_kernels/triton_kernels.py @@ -114,6 +114,7 @@ def matmul_kernel( stride_cn, chunk_trun_bits, max_acc_bits, # pylint: disable=unused-argument + clamp_acc_to_dl16, truncate_then_accumulate, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, @@ -159,13 +160,8 @@ def matmul_kernel( # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block # of fp32 values for higher accuracy. accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - ## ------ prepare LSB rounding/truncation masks ------- - # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b - # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000 - # 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080 - trun_mask = tl.cast((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits, tl.uint32) - round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0 - ## --------------------------------------------------------- + ## ------ prepare LSB rounding/truncation masks outside the for loop ------- + round_bit, trun_mask = round_and_trun_mask(chunk_trun_bits, clamp_acc_to_dl16) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the K dimension. @@ -180,8 +176,10 @@ def matmul_kernel( # tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp ## ------ add chunky LSB rounding/masking -------- - if chunk_trun_bits > 0: - accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask) + if clamp_acc_to_dl16 or chunk_trun_bits > 0: + accumulator_inner = round_and_trun( + accumulator_inner, round_bit, trun_mask, clamp_acc_to_dl16 + ) ## --------------------------------------------------------- if truncate_then_accumulate: accumulator += accumulator_inner @@ -226,6 +224,7 @@ def imatmul_kernel( stride_cn, chunk_trun_bits, max_acc_bits, + clamp_acc_to_dl16, # pylint: disable=unused-argument truncate_then_accumulate, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, @@ -324,6 +323,7 @@ def matmul_kernel_DABC( stride_cn, chunk_trun_bits, max_acc_bits, # pylint: disable=unused-argument + clamp_acc_to_dl16, truncate_then_accumulate, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, @@ -377,13 +377,8 @@ def matmul_kernel_DABC( # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block # of fp32 values for higher accuracy, i.e. C should have been cast to fp32 already accumulator = tl.load(c_ptrs, mask=c_mask, other=0.0) - ## ------ prepare LSB rounding/truncation masks ------- - # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b - # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000 - # 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080 - trun_mask = tl.cast((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits, tl.uint32) - round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0 - ## --------------------------------------------------------- + ## ------ prepare LSB rounding/truncation masks outside the for loop ------- + round_bit, trun_mask = round_and_trun_mask(chunk_trun_bits, clamp_acc_to_dl16) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A, B, and C, generate a mask by checking the K dimension. @@ -403,8 +398,10 @@ def matmul_kernel_DABC( # precision as well, hence, could lose some precision! ## ------ add chunky LSB rounding/masking -------- - if chunk_trun_bits > 0: - accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask) + if clamp_acc_to_dl16 or chunk_trun_bits > 0: + accumulator_inner = round_and_trun( + accumulator_inner, round_bit, trun_mask, clamp_acc_to_dl16 + ) ## --------------------------------------------------------- if truncate_then_accumulate: accumulator += accumulator_inner @@ -433,9 +430,39 @@ def leaky_relu(x): @triton.jit -def round_and_trun(x, round_bit, trun_mask): +def round_and_trun_mask(chunk_trun_bits, clamp_acc_to_dl16): + """ + Rounding and LSB truncation masks only need to be generated once. + These mask will be applied on "inner" accumulator, which is alway FP32 (e8m23). We may truncate + up to 23b for mantissa. If DL16/DL8, pay attention to exponent bias. + Examples: 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000 + 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080 + """ + if clamp_acc_to_dl16: + # DL16 is e6m9, hence, truncate 23 - 9 = 14 bits + chunk_trun_bits = 14 + round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0 + trun_mask = ~tl.cast((1 << chunk_trun_bits) - 1, tl.uint32) + return round_bit, trun_mask + + +@triton.jit +def round_and_trun(x, round_bit, trun_mask, clamp_acc_to_dl16): """Round and truncate (usually for accumulator).""" - return libdevice.uint_as_float((libdevice.float_as_uint(x) + round_bit) & trun_mask) + x = libdevice.uint_as_float((libdevice.float_as_uint(x) + round_bit) & trun_mask) + + if clamp_acc_to_dl16: + # clamp to DL16 min/max: + # max = 2^32 * 1.(1111 1111 0)_base2 = 2^32*(2 - 2^-9) = 8581545984.0 + # greater than this will become +inf (or -inf) + # min = 2^-31 * 1.(0000 0000 1)_base2 = 2^-31*(1 + 2^-9)> = 4.665707820095122e-10 + # smaller than this will become 0 + dl16_max = 8581545984.0 + dl16_min = 4.665707820095122e-10 + x = tl.where(x >= dl16_max, float("inf"), x) + x = tl.where(x <= -dl16_max, float("-inf"), x) + x = tl.where(tl.abs(x) < dl16_min, 0, x) + return x def tl_matmul_chunk_truncate( @@ -448,6 +475,7 @@ def tl_matmul_chunk_truncate( max_acc_bits=32, truncate_then_accumulate=True, cast_output_to_input_dtype=None, + clamp_acc_to_dl16=False, ): """Triton matmul for HW behavior simulation. Supports float and int8. i. variable chunk size (i.e., BLOCK_SIZE_K) @@ -461,7 +489,8 @@ def tl_matmul_chunk_truncate( chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16. max_acc_bits (int, optional): num of bits for the accumulator, e.g. if INT24 is used, will clamp each chunk of a*b to [-2**23-1, 2**23]. - (assuming no inf when overflow) + (only used by INT) + clamp_acc_to_dl16(bool): Only used by FP8, whether to clamp local accumulator (FP32) to DL16 truncate_then_accumulate (bool, optional): if True, c = truncate(a*b) + c, otherwise c = truncate(a*b+c) cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually @@ -473,7 +502,7 @@ def tl_matmul_chunk_truncate( NOTE: use empirical way to determine BLOCK sizes, may not be optimal. But need to avoid autotune for - real model inference. otherwise auto-tune will be triggered in every forward call. + real model inference. otherwise auto-tune may be triggered in every forward call. """ # Check constraints. @@ -584,6 +613,7 @@ def grid(META): c.stride(1), chunk_trun_bits=chunk_trun_bits, max_acc_bits=max_acc_bits, + clamp_acc_to_dl16=clamp_acc_to_dl16, truncate_then_accumulate=truncate_then_accumulate, ACTIVATION=activation, **kernel_config, # if using auto-tune, comment this line out. diff --git a/fms_mo/custom_ext_kernels/utils.py b/fms_mo/custom_ext_kernels/utils.py index 5ab78f8b..386b8c14 100644 --- a/fms_mo/custom_ext_kernels/utils.py +++ b/fms_mo/custom_ext_kernels/utils.py @@ -870,14 +870,16 @@ def lower_qmodel_triton( model: torch.nn.Module, use_dyn_max_act=False, max_acc_bits=32, + clamp_acc_to_dl16=False, num_lsb_to_truncate=0, chunk_size=32, + layer_to_exclude=None, ): """ - Examplar GPU lowering function using triton. Only swap Qlinears in transformers, nothing else. + Examplar GPU lowering function using triton. Only swap Linear/Qlinear in transformers. Triton kernel can be used to: 1. test INT8 or FP8 HW performance (kernel is not optimized) - 2. simulate MSB/LSB truncation effect + 2. simulate MSB/LSB truncation effect or special dtype (DL16) accumulation Args: model: nn.Module. should be a fms_mo Qmodel, will do inplace layer swapping, no deepcopy @@ -888,6 +890,8 @@ def lower_qmodel_triton( efficiency at the expense of higher chance of accumulation "overflow". For example, an INT24 accumulator can only hold values ranged from -2^23 to 2^23 -1, as opposed to typical range -2^31 to -2^31 -1. + clamp_acc_to_dl16: clamp local accumulator to DL16 (1-6-9) range. To simulate this special + dtype effect on accumulation. num_lsb_to_truncate: number of bits to truncate from LSB side. For example, given fp32 is s1e8m23, if we choose to truncate 13 mantissa bits from right most side, i.e. LSB, the resulting number will be s1e8m10, which is TF32. @@ -900,25 +904,56 @@ def lower_qmodel_triton( from torch.ao.quantization.utils import _parent_name # Local - from fms_mo.modules.linear import QLinear, QLinearINT8Deploy + from fms_mo.modules.linear import LinearFPxAcc, QLinear, QLinearINT8Deploy + + # Currently QLinearINT8 has more options in dynamic quantization than LinearFP. Here we resolve + # the differences as a patch solution (will unify the codes in future release) + linFP_dyn_code = ( + "per_token" + if use_dyn_max_act in [-1, -2] + else "per_tensor" + if use_dyn_max_act + else False + ) + + if layer_to_exclude is None: + layer_to_exclude = [] + elif isinstance(layer_to_exclude, str): + layer_to_exclude = [ + layer_to_exclude, + ] + elif not isinstance(layer_to_exclude, (list, tuple)): + raise RuntimeError("layer_to_exclude has to be either str, list, or tuple.") for name, m in model.named_modules(): - if not isinstance(m, QLinear): + if not isinstance(m, (QLinear, torch.nn.Linear)) or name in layer_to_exclude: continue parent_name, module_name = _parent_name(name) parent_mod = model.get_submodule(parent_name) - qmod = getattr(parent_mod, module_name) - setattr( - parent_mod, - module_name, - QLinearINT8Deploy.from_fms_mo( - qmod, + + # Only support simulations of 1) QLinear -> INT8, 2) nnLinear->FP8 for now + if isinstance(m, QLinear): + new_lin = QLinearINT8Deploy.from_fms_mo( + m, use_int_kernel="triton", use_dynamic_max_act_Qfunc=use_dyn_max_act, max_acc_bits=max_acc_bits, truncate_lsb=num_lsb_to_truncate, chunk_size=chunk_size, - ), + ) + else: + new_lin = LinearFPxAcc.from_nn( + m, + trun_bits=num_lsb_to_truncate, + chunk_size=chunk_size, + dynamic_fp8=linFP_dyn_code, + clamp_acc_to_dl16=clamp_acc_to_dl16, + ) + + setattr( + parent_mod, + module_name, + new_lin, ) logger.info(f"\nModel lowering with triton kernel is done.\n{model}") diff --git a/fms_mo/dq.py b/fms_mo/dq.py index eef16049..eb49bc30 100644 --- a/fms_mo/dq.py +++ b/fms_mo/dq.py @@ -161,6 +161,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): # config layers to skip, smooth scale config_quantize_smooth_layers(qcfg) + use_dynamo = True + # use dynamo as default unless really needed, False -> fallback to TorchScript tracing if any(x != 32 for x in attn_bits): logger.info("Quantize attention bmms or kvcache, will use dynamo for prep") use_layer_name_pattern_matching = False @@ -168,11 +170,9 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): assert ( qcfg["qlayer_name_pattern"] == [] ), "ensure nothing in qlayer_name_pattern when use dynamo" - use_dynamo = True else: logger.info("Attention bmms will not be quantized.") use_layer_name_pattern_matching = True - use_dynamo = False qcfg["seq_len"] = block_size qcfg["model"] = model_args.model_name_or_path @@ -216,17 +216,18 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): act_scales = get_act_scales(model, dq_dataloader, qcfg) torch.save(act_scales, scale_file) - qmodel_prep( - model, - dq_dataloader, - qcfg, - use_layer_name_pattern_matching=use_layer_name_pattern_matching, - use_dynamo=use_dynamo, - dev=dev, - save_fname="dq", - ) - logger.info(f"Quantized model {model}") - logger.info("==" * 20) + if fms_mo_args.aiu_sim_triton != "fp8": + qmodel_prep( + model, + dq_dataloader, + qcfg, + use_layer_name_pattern_matching=use_layer_name_pattern_matching, + use_dynamo=use_dynamo, + dev=dev, + save_fname="dq", + ) + logger.info(f"Quantized model {model}") + logger.info("==" * 20) if qcfg["smoothq"]: logger.info("Starting to apply smooth scale") @@ -260,12 +261,15 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): tokenizer.save_pretrained(opt_args.output_dir) if fms_mo_args.aiu_sim_triton: + # NOTE plz apply correct HW settings here, defaults are not real HW params lower_qmodel_triton( model, use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False, max_acc_bits=qcfg.get("max_acc_bits", 32), num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0), - chunk_size=qcfg.get("chunk_size", 1024), + chunk_size=qcfg.get("chunk_size", 32), # 1024 + clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8", + # layer_to_exclude=["lm_head",] ) if fms_mo_args.eval_ppl: diff --git a/fms_mo/fx/dynamo_utils.py b/fms_mo/fx/dynamo_utils.py index 62967ec2..f7cd61ea 100644 --- a/fms_mo/fx/dynamo_utils.py +++ b/fms_mo/fx/dynamo_utils.py @@ -1180,14 +1180,20 @@ def cus_backend_model_analyzer( if is_transformers: # NOTE simplified method to determine 1st/last modules for transformers. # will not work if model has multiple parallel heads at the end, e.g. obj det - def call_seq_hook(mod, *_args, **_kwargs): - qcfg["mod_call_seq"].append(lut_weight2modname[mod.weight]) + def call_seq_hook(mod, *_args, **kwargs): + mod_name = kwargs.get("mod_name", lut_weight2modname.get(mod.weight, None)) + if mod_name is None: + raise RuntimeError("cannot determine module name, plz check model.") + + qcfg["mod_call_seq"].append(mod_name) h_hooks = [] qcfg["mod_call_seq"] = [] for n, m in model.named_modules(): if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)): - h_hooks.append(m.register_forward_hook(call_seq_hook)) + h_hooks.append( + m.register_forward_hook(partial(call_seq_hook, mod_name=n)) + ) with torch.no_grad(): run_fwd_once(model, sample_inp) diff --git a/fms_mo/fx/utils.py b/fms_mo/fx/utils.py index 357a877b..45abe184 100644 --- a/fms_mo/fx/utils.py +++ b/fms_mo/fx/utils.py @@ -461,14 +461,14 @@ def model_size_Wb(mod, unit="MB", print_to_file=True, show_details=False): w_mat.numel() * w_mat.element_size() + b_mat.numel() * b_mat.element_size() ) - w_dtype = w_mat.dtype + w_dtype = str(w_mat.dtype) w_shape = w_mat.shape elif isinstance(w, torch.Tensor): mem_use = w.numel() * w.element_size() if hasattr(m, "bias") and m.bias is not None: mem_use += m.bias.numel() * m.bias.element_size() - w_dtype = w.dtype + w_dtype = str(w.dtype) w_shape = w.shape if w_shape: diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py index ddb1d14d..3a39bb30 100644 --- a/fms_mo/modules/linear.py +++ b/fms_mo/modules/linear.py @@ -1899,7 +1899,16 @@ class LinearFuncFPxFwdBwd(torch.autograd.Function): """ @staticmethod - def forward(ctx, x, weight, bias=None, trun_bits=0, chunk_size=16, fp8_dyn=False): + def forward( + ctx, + x, + weight, + bias=None, + trun_bits=0, + chunk_size=16, + fp8_dyn=False, + clamp_acc_to_dl16=False, + ): assert x.dtype in [torch.float, torch.bfloat16, torch.float16] # input can be 2D or 3D, need to reshape before tl_matmul org_dtype = x.dtype @@ -1916,27 +1925,49 @@ def forward(ctx, x, weight, bias=None, trun_bits=0, chunk_size=16, fp8_dyn=False ctx.trun_bits = trun_bits ctx.chunk_size = chunk_size ctx.fp8_dyn = fp8_dyn + ctx.clamp_acc_to_dl16 = clamp_acc_to_dl16 + ctx.fp8_e4m3_max = torch.finfo(torch.float8_e4m3fn).max + ctx.fp8_e5m2_max = torch.finfo(torch.float8_e5m2).max + ctx.dl8_min = 0.0087890625 + x_scale = torch.tensor(1.0, device=x.device, dtype=org_dtype) + w_scale = x_scale.clone() if fp8_dyn: # use Q/dQ simulation for now, meaning still compute in fp16/bf16 # if choose per_token for input, use per_channel for W # (W saved as [out, in], reduce inCh-dim, => reduce_dim=1) - ctx.fp8_e4m3_max = torch.finfo(torch.float8_e4m3fn).max - ctx.fp8_e5m2_max = torch.finfo(torch.float8_e5m2).max reduce_dim = None if fp8_dyn == "per_tensor" else 1 - x_scale = x.abs().amax(dim=reduce_dim) / ctx.fp8_e4m3_max - w_scale = weight.abs().amax(dim=reduce_dim) / ctx.fp8_e4m3_max - - x = (x / x_scale).to(torch.float8_e4m3fn).to(org_dtype) * x_scale - weight = (weight / w_scale).to(torch.float8_e4m3fn).to(org_dtype) * w_scale + x_scale = ( + x.abs().amax(dim=reduce_dim, keepdim=True) / ctx.fp8_e4m3_max + ).clamp(min=1e-5) + w_scale = ( + weight.abs().amax(dim=reduce_dim, keepdim=True) / ctx.fp8_e4m3_max + ).clamp(min=1e-5) + + x = (x / x_scale).to(torch.float8_e4m3fn).to(torch.float32) + weight = (weight / w_scale).to(torch.float8_e4m3fn).to(torch.float32) + if clamp_acc_to_dl16: + # at this point, x and W are clamped to PT's FP8 range (2^-9 to 448). But since DL8 + # doesn't support subnorm like PyTorch, need to flush subnorms to 0 BEFORE descaling + x.masked_fill_(x.abs() < ctx.dl8_min, 0) + weight.masked_fill_(weight.abs() < ctx.dl8_min, 0) # triton kernel assumes 2D inputs and cast the return to input.dtype - output = tl_matmul( - x, - weight.t().to(org_dtype), - chunk_trun_bits=trun_bits, - chunk_size=chunk_size, - ).reshape(target_shape_output) + output = ( + ( + tl_matmul( + x, + weight.t(), + chunk_trun_bits=trun_bits, + chunk_size=chunk_size, + clamp_acc_to_dl16=clamp_acc_to_dl16, + ) + * x_scale + * w_scale.t() + ) + .to(org_dtype) + .reshape(target_shape_output) + ) if bias is not None: output = output + bias.to(org_dtype) @@ -1956,6 +1987,8 @@ def backward(ctx, grad_output): target_shape_grad_input = grad_output.shape[:-1] + (in_dim,) grad_output_2D = grad_output.reshape(-1, out_dim).to(dtype_input) + x_scale = torch.tensor(1.0, device=x.device, dtype=dtype_input) + w_scale = x_scale.clone() if ctx.fp8_dyn: reduce_dim = None if ctx.fp8_dyn == "per_tensor" else 1 x_scale = x.abs().amax(dim=reduce_dim) / ctx.fp8_e5m2_max @@ -1963,30 +1996,45 @@ def backward(ctx, grad_output): # always assume perT in this case grad_out_scale = grad_output_2D.abs().amax(dim=None) / ctx.fp8_e5m2_max - x = (x / x_scale).to(torch.float8_e5m2).to(dtype_input) * x_scale - weight = (weight / w_scale).to(torch.float8_e5m2).to(weight.dtype) * w_scale - grad_output_2D = (grad_output_2D / grad_out_scale).to(torch.float8_e5m2).to( - grad_output.dtype - ) * grad_out_scale + x = (x / x_scale).to(torch.float8_e5m2).to(torch.float) + weight = (weight / w_scale).to(torch.float8_e5m2).to(torch.float) + grad_output_2D = ( + (grad_output_2D / grad_out_scale).to(torch.float8_e5m2).to(torch.float) + ) + if ctx.clamp_acc_to_dl16: + # flush subnorm numbers to 0 as DL8 doesn't support it + x.masked_fill_(x.abs() < ctx.dl8_min, 0) + weight.masked_fill_(weight.abs() < ctx.dl8_min, 0) + grad_output_2D.masked_fill_(grad_output_2D.abs() < ctx.dl8_min, 0) # Compute grad_weight, shape = [out, in] # NOTE: this triton kernel requires A matrix to be contiguous - grad_weight = tl_matmul( - grad_output_2D.transpose(0, 1).contiguous(), - x, - chunk_trun_bits=trun_bits, - chunk_size=chunk_size, - ).to(weight.dtype) - # Compute grad_input in 2D then reshape to target shape, could be 3D or 2D - grad_input = ( + grad_weight = ( tl_matmul( - grad_output_2D, - weight.to(dtype_input), + grad_output_2D.transpose(0, 1).contiguous(), + x, chunk_trun_bits=trun_bits, chunk_size=chunk_size, + clamp_acc_to_dl16=ctx.clamp_acc_to_dl16, + ) + * grad_out_scale.t() + * x_scale + ).to(weight.dtype) + # Compute grad_input in 2D then reshape to target shape, could be 3D or 2D + grad_input = ( + ( + tl_matmul( + grad_output_2D, + weight, + chunk_trun_bits=trun_bits, + chunk_size=chunk_size, + clamp_acc_to_dl16=ctx.clamp_acc_to_dl16, + ) + * grad_out_scale + * w_scale ) - .reshape(target_shape_grad_input) .to(dtype_input) + .reshape(target_shape_grad_input) ) if not ctx.has_bias: @@ -1994,7 +2042,7 @@ def backward(ctx, grad_output): else: grad_bias = grad_output_2D.sum(0).to(ctx.bias_dtype) - return grad_input, grad_weight, grad_bias, None, None, None + return grad_input, grad_weight, grad_bias, None, None, None, None class LinearFPxAcc(torch.nn.Linear): @@ -2016,6 +2064,10 @@ def from_nn(cls, nnlin, trun_bits=0, **kwargs): cls (class): The class to be created. nnlin (torch.nn.Linear): The original torch.nn.Linear module. trun_bits (int): truncate [0 to 22] LSBs from FP32 accumulation. + dynamic_fp8: whether to use dynamic quantization for fp8 activations, available options + are ["per_tensor", "per_token", False] + clamp_acc_to_dl16: clamp local accumulator into DL16 range, to simulate the effect of + this special dtype **kwargs: Additional keyword arguments. Returns: @@ -2030,14 +2082,14 @@ def from_nn(cls, nnlin, trun_bits=0, **kwargs): nnlin.in_features, nnlin.out_features, bias=nnlin.bias is not None, - device=target_device, + device="meta", ) lin24acc.weight = nnlin.weight lin24acc.trun_bits = trun_bits lin24acc.chunk_size = kwargs.get("chunk_size", False) lin24acc.fp8_dyn = kwargs.get("dynamic_fp8", False) - # available options are ["per_tensor", "per_token"] + lin24acc.clamp_acc_to_dl16 = kwargs.get("clamp_acc_to_dl16", False) if nnlin.bias is not None: lin24acc.bias = nnlin.bias @@ -2052,16 +2104,24 @@ def forward(self, inputs): self.trun_bits, self.chunk_size, self.fp8_dyn, + self.clamp_acc_to_dl16, ) def extra_repr(self) -> str: """ Returns an alternative string representation of the object. """ - return ( - f"in={self.in_features}, out={self.out_features}, bias={self.bias is not None}, " - f"trun_bits={self.trun_bits},fp8_dyn={self.fp8_dyn},chunk_size={self.chunk_size}" - ) + repr_str = f"{self.in_features},{self.out_features}" + if self.bias is not None: + repr_str += f",bias={self.bias is not None}" + if self.trun_bits > 0: + repr_str += f",trun_bits={self.trun_bits}" + if self.fp8_dyn: + repr_str += f",fp8_dyn={self.fp8_dyn}" + if self.clamp_acc_to_dl16: + repr_str += ",use_DL16_acc" + repr_str += f",chunk_size={self.chunk_size}" + return repr_str class LinearFuncINT8FwdFP32Bwd(torch.autograd.Function): diff --git a/fms_mo/quant/ptq.py b/fms_mo/quant/ptq.py index 7801284b..de2c5729 100644 --- a/fms_mo/quant/ptq.py +++ b/fms_mo/quant/ptq.py @@ -42,6 +42,7 @@ # Local from fms_mo.modules import QBmm, QLinear from fms_mo.modules.conv import QConv2dPTQv2 +from fms_mo.modules.linear import LinearFPxAcc, QLinearINT8Deploy from fms_mo.quant.quantizers import ( AdaRoundQuantizer, Qdynamic, @@ -481,8 +482,118 @@ def __call__(self, mod, inputs, *args, **_kwargs): assert not self.stop_after_rec -# this hook is meant for ptq_loss_func == 'fisher_diag' and to temp hold the "Q_out" of the module +class HookRecPostQuantInOut(torch.nn.Module): + """Another simplified hook to check post-quantized input/output, e.g. within +-127 for INT8.""" + + def __init__(self, cache_dict={}, mod_name=None): + super().__init__() + self.cache_dict = cache_dict + self.mod_name = mod_name + name_split = mod_name.split(".") + self.lay_idx = int(name_split[3]) + self.lay_key = name_split[6] + + self.cache_dev = "cpu" + # prepare empty dict for later use + self.cache_dict[mod_name] = {} + self.fwd_mapping = { + LinearFPxAcc: self.call_func_for_fpxacc, + QLinear: self.call_func_for_qlinear, + QLinearINT8Deploy: self.call_func_for_qlinear_int, + torch.nn.Linear: self.call_func_for_nnlinear, + } + + def call_func_for_fpxacc(self, mod, inputs, outputs, **_kwargs): + raise NotImplementedError + + def call_func_for_qlinear(self, mod, inputs, outputs, **_kwargs): + lay_idx = self.lay_idx + lay_key = self.lay_key + mod_name = self.mod_name + cache_dict = self.cache_dict + + act_max = inputs[0].abs().amax(dim=[d for d in range(len(inputs[0].shape) - 1)]) + # mod.smoothq_act_scale + w_max = mod.weight.abs().max(dim=0, keepdim=True)[0].clamp(min=1e-5) + is_smq_layer = not torch.all(act_max == 0).item() + # smoothQ scale = smoothq_act_scale**alpha / weight_scale**(1.0 - alpha) + # smoothq_scale = mod.get_smoothq_scale(inputs[0]) + smoothq_scale = getattr(mod, "smq_scale", 1.0) + # "smq_scale" only available in QLin_INT8 + + with torch.no_grad(): + smoothed_inp = inputs[0] / smoothq_scale + smoothed_w = mod.weight * smoothq_scale + + # this is assuming pertokenmax quantizer, NOTE calc quant scale after smoothing + absmax = smoothed_inp.abs().max(dim=-1, keepdim=True)[0] + qa_scale = absmax.clamp(min=1e-5) / 127 + qinput = torch.round(smoothed_inp / qa_scale).clamp(-127, 127) + # should clamp to -128? + if mod.qa_mode == "pertokenmax": + # doesnt implement dequant=False yet, do it manually + cva = mod.quantize_feature.clip_val + qa_scale = cva.clamp(min=1e-5).div(127) + qinput = smoothed_inp.div(qa_scale).round() + else: + mod.quantize_feature.dequantize = False + qinput = mod.quantize_feature(smoothed_inp) + mod.quantize_feature.dequantize = True + + # also record quantized, smoothed W in INT8, calc both maxperCh and SAWBperCh + cvw = mod.quantize_weight.clip_val + scale_w = cvw / 127 + mod.quantize_weight.dequantize = False + qw = mod.quantize_weight(smoothed_w) + mod.quantize_weight.dequantize = True + + # inputs is a tuple, QLinear only has 1 valid input + cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev) + cache_dict[mod_name]["cva"] = cva.to(self.cache_dev) + cache_dict[mod_name]["cvw"] = cvw.to(self.cache_dev) + cache_dict[mod_name]["smoothed_input"] = smoothed_inp.to(self.cache_dev) + cache_dict[mod_name]["smoothed_weight"] = smoothed_w.to(self.cache_dev) + cache_dict[mod_name]["qinput"] = qinput.to(self.cache_dev) + # NOTE in INT8, *scales if need dQ + cache_dict[mod_name]["qweight"] = qw.to(self.cache_dev) + # torch.round(smoothed_w.T/scale_w).clamp(-127, 127).to(self.cache_dev) + # cache_dict[mod_name]["qoutput"] = outputs.to(self.cache_dev) + + def call_func_for_qlinear_int(self, mod, inputs, outputs, **_kwargs): + smoothq_scale = getattr(mod, "smq_scale", 1.0) + mod_name = self.mod_name + cache_dict = self.cache_dict + with torch.no_grad(): + if mod.useDynMaxQfunc in [-1, -2]: + qinput = mod.qa_dynamic_max_qfunc(inputs[0]) + elif mod.use_fake_zero_shift: + qinput = mod.qa_dyn_max_fake_zero_shift(inputs[0]) + elif mod.usePTnativeQfunc: + qinput = mod.qa_raw_qfunc(inputs[0]) + else: + qinput = mod.qa_fmo_mo_qfunc(inputs[0]) + + # inputs is a tuple, QLinear only has 1 valid input + cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev) + cache_dict[mod_name]["cva"] = mod.cvs[0].to(self.cache_dev) + cache_dict[mod_name]["cvw"] = mod.cvs[2].to(self.cache_dev) + cache_dict[mod_name]["qinput"] = qinput.to(self.cache_dev) + cache_dict[mod_name]["qweight"] = mod.weight.to(self.cache_dev) + + def call_func_for_nnlinear(self, mod, inputs, outputs, **_kwargs): + mod_name = self.mod_name + cache_dict = self.cache_dict + cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev) + cache_dict[mod_name]["weight"] = mod.weight.to(self.cache_dev) + + def __call__(self, mod, inputs, outputs, **_kwargs): + self.fwd_mapping[type(mod)](mod, inputs, outputs, **_kwargs) + + class PTQHookRecQOut(nn.Module): + """This hook is for ptq_loss_func == 'fisher_diag' and will temporarily hold the "Q_out" of the + module""" + def __init__(self, qcfg): super().__init__() self.qcfg = qcfg diff --git a/fms_mo/training_args.py b/fms_mo/training_args.py index 1076b438..661f72bd 100644 --- a/fms_mo/training_args.py +++ b/fms_mo/training_args.py @@ -192,8 +192,15 @@ class FMSMOArguments(TypeChecker): default=2048, metadata={"help": "input sequence length after tokenization"} ) eval_ppl: bool = field(default=False) - aiu_sim_triton: bool = field( - default=False, metadata={"help": ("AIU simulation with triton kernel")} + aiu_sim_triton: Optional[str] = field( + default=None, + metadata={ + "help": ( + "AIU simulation with triton kernel. ['int8', 'fp8', None]\n" + "'int8' mode will trigger qmodel_prep() and swap QLinears" + "'fp8' mode will directly replace existing nn.Linears" + ) + }, ) recompute_narrow_weights: bool = field( default=False, diff --git a/fms_mo/utils/dq_utils.py b/fms_mo/utils/dq_utils.py index d5622864..2eb51caf 100644 --- a/fms_mo/utils/dq_utils.py +++ b/fms_mo/utils/dq_utils.py @@ -13,14 +13,20 @@ # limitations under the License. """Utility functions for Direct Quantization" (DQ).""" +# Standard +import logging + +logger = logging.getLogger(__name__) + def config_quantize_smooth_layers(qcfg: dict): """Update qcfg with model-dependent config parameters: - qlayer_name_pattern: identifier of transformer layers containing linear layers - to quantize (if any, tracing is bypassed) + to quantize (if any, tracing is bypassed) - qskip_layer_name: full name of linear layers that will not be quantized - smoothq_scale_layers: identifier of linear layers to apply smoothquant on - - smoothq_act_scale_path: path to save/load smoothquant activation scales + - smoothq_act_scale_path: path to save/load smoothquant activation scales, should be kept + Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt"), no need to specify here. Selected model is determined by comparing all architecture identifiers against `model` and `model_type` fields in qcfg. @@ -98,23 +104,10 @@ def config_quantize_smooth_layers(qcfg: dict): [31, 7], ] ] - qcfg["smoothq_act_scale_path"] = "./act_scales/Mixtral-8x7B-v0.1.pt" elif any(model in qcfg["model"] for model in bigcode_architecture): qcfg["qlayer_name_pattern"] = ["transformer.h"] qcfg["smoothq_scale_layers"] = ["c_attn", "c_fc"] # NOTE: supported bigcode models do not need layer skip for large magnitude - if "granite-3b-base-v2" in qcfg["model"]: - qcfg["smoothq_act_scale_path"] = "./act_scales/granite_3b_base_v2_500_nw.pt" - if "granite-13b-base-v2" in qcfg["model"]: - qcfg["smoothq_act_scale_path"] = "./act_scales/granite_13b_base_v2.pt" - if "granite-20b-code-base" in qcfg["model"]: - qcfg["smoothq_act_scale_path"] = "./act_scales/graniteCodeHF_20b_base12.pt" - if "granite-20b-code-instruct" in qcfg["model"]: - qcfg["smoothq_act_scale_path"] = "./act_scales/graniteCodeHF_20b_base12.pt" - if "granite-34b-code-base" in qcfg["model"]: - qcfg["smoothq_act_scale_path"] = "./act_scales/graniteCodeHF_34b_base12.pt" - if "granite-34b-code-instruct" in qcfg["model"]: - qcfg["smoothq_act_scale_path"] = "./act_scales/graniteCodeHF_34b_base12.pt" elif "roberta" in qcfg["model"]: qcfg["act_scale_path"] = "./act_scales" qcfg["smoothq_scale_layers"] = [ @@ -126,4 +119,7 @@ def config_quantize_smooth_layers(qcfg: dict): qcfg["qskip_layer_name"] = [] qcfg["qlayer_name_pattern"] = ["roberta.encoder"] else: - raise ValueError("The model architecture is not supported for DQ.") + logger.info( + "The model architecture is not supported for DQ. No architecture-specific settings is" + "applied. All Linear layers will be quantized, which may not yield the optimal results." + )