From 005d54bdb83db8f6a61bea294c1307e40ad91227 Mon Sep 17 00:00:00 2001 From: cliu-us Date: Thu, 26 Jun 2025 17:57:40 +0000 Subject: [PATCH 01/11] add DL16 option for LinearFPx (FP8 aiu sim) Signed-off-by: cliu-us --- fms_mo/custom_ext_kernels/triton_kernels.py | 36 +++++++++++++++- fms_mo/custom_ext_kernels/utils.py | 47 ++++++++++++++++----- fms_mo/modules/linear.py | 24 +++++++++-- 3 files changed, 91 insertions(+), 16 deletions(-) diff --git a/fms_mo/custom_ext_kernels/triton_kernels.py b/fms_mo/custom_ext_kernels/triton_kernels.py index beae5a2e..0aa08421 100644 --- a/fms_mo/custom_ext_kernels/triton_kernels.py +++ b/fms_mo/custom_ext_kernels/triton_kernels.py @@ -114,6 +114,7 @@ def matmul_kernel( stride_cn, chunk_trun_bits, max_acc_bits, # pylint: disable=unused-argument + clamp_acc_to_dl16, truncate_then_accumulate, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, @@ -182,6 +183,8 @@ def matmul_kernel( ## ------ add chunky LSB rounding/masking -------- if chunk_trun_bits > 0: accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask) + if clamp_acc_to_dl16: + accumulator = fp32_clamp_to_dl16(accumulator) ## --------------------------------------------------------- if truncate_then_accumulate: accumulator += accumulator_inner @@ -226,6 +229,7 @@ def imatmul_kernel( stride_cn, chunk_trun_bits, max_acc_bits, + clamp_acc_to_dl16, # pylint: disable=unused-argument truncate_then_accumulate, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, @@ -324,6 +328,7 @@ def matmul_kernel_DABC( stride_cn, chunk_trun_bits, max_acc_bits, # pylint: disable=unused-argument + clamp_acc_to_dl16, truncate_then_accumulate, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, @@ -405,6 +410,8 @@ def matmul_kernel_DABC( ## ------ add chunky LSB rounding/masking -------- if chunk_trun_bits > 0: accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask) + if clamp_acc_to_dl16: + accumulator = fp32_clamp_to_dl16(accumulator) ## --------------------------------------------------------- if truncate_then_accumulate: accumulator += accumulator_inner @@ -438,6 +445,28 @@ def round_and_trun(x, round_bit, trun_mask): return libdevice.uint_as_float((libdevice.float_as_uint(x) + round_bit) & trun_mask) +@triton.jit +def fp32_clamp_to_dl16(x): + """clamp FP32 (1-8-23) TENSOR x to DL16 (1-6-9) range.""" + # 1. rounding, add round bit to full uint representation + x = libdevice.float_as_uint(x) + round_bit = 1 << (23 - 9 - 1) + x = libdevice.uint_as_float(x + round_bit) + + # 2. clamp to min/max: + # max = 2^32 * 1.(1111 1111 0)_base2 => 2^32*1.(1111 1111 1) will become inf + # (32 + 127) << 23 | (0xFF8 << (23 - 12)) in FP32 is 8581545984.0 + # min = 2^-31 * 1.(0000 0000 1)_base2 => set to 0 for those smaller than this + # (-31 + 127) << 23 | (1 << (23 - 9)) in FP32 is 4.665707820095122e-10 + dl16_max = 8581545984.0 + dl16_min = 4.665707820095122e-10 + x = tl.where(x >= dl16_max, float("inf"), x) + x = tl.where(x <= -dl16_max, float("-inf"), x) + x = tl.where(tl.abs(x) < dl16_min, 0, x) + + return x + + def tl_matmul_chunk_truncate( a, b, @@ -448,6 +477,7 @@ def tl_matmul_chunk_truncate( max_acc_bits=32, truncate_then_accumulate=True, cast_output_to_input_dtype=None, + clamp_acc_to_dl16=False, ): """Triton matmul for HW behavior simulation. Supports float and int8. i. variable chunk size (i.e., BLOCK_SIZE_K) @@ -461,7 +491,8 @@ def tl_matmul_chunk_truncate( chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16. max_acc_bits (int, optional): num of bits for the accumulator, e.g. if INT24 is used, will clamp each chunk of a*b to [-2**23-1, 2**23]. - (assuming no inf when overflow) + (only used by INT) + clamp_acc_to_dl16(bool): Only used by FP8, whether to clamp local accumulator (FP32) to DL16 truncate_then_accumulate (bool, optional): if True, c = truncate(a*b) + c, otherwise c = truncate(a*b+c) cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually @@ -473,7 +504,7 @@ def tl_matmul_chunk_truncate( NOTE: use empirical way to determine BLOCK sizes, may not be optimal. But need to avoid autotune for - real model inference. otherwise auto-tune will be triggered in every forward call. + real model inference. otherwise auto-tune may be triggered in every forward call. """ # Check constraints. @@ -584,6 +615,7 @@ def grid(META): c.stride(1), chunk_trun_bits=chunk_trun_bits, max_acc_bits=max_acc_bits, + clamp_acc_to_dl16=clamp_acc_to_dl16, truncate_then_accumulate=truncate_then_accumulate, ACTIVATION=activation, **kernel_config, # if using auto-tune, comment this line out. diff --git a/fms_mo/custom_ext_kernels/utils.py b/fms_mo/custom_ext_kernels/utils.py index 0d5fffb0..0ac95c6a 100644 --- a/fms_mo/custom_ext_kernels/utils.py +++ b/fms_mo/custom_ext_kernels/utils.py @@ -870,14 +870,15 @@ def lower_qmodel_triton( model: torch.nn.Module, use_dyn_max_act=False, max_acc_bits=32, + clamp_acc_to_dl16=False, num_lsb_to_truncate=0, chunk_size=32, ): """ - Examplar GPU lowering function using triton. Only swap Qlinears in transformers, nothing else. + Examplar GPU lowering function using triton. Only swap Linear/Qlinear in transformers. Triton kernel can be used to: 1. test INT8 or FP8 HW performance (kernel is not optimized) - 2. simulate MSB/LSB truncation effect + 2. simulate MSB/LSB truncation effect or special dtype (DL16) accumulation Args: model: nn.Module. should be a fms_mo Qmodel, will do inplace layer swapping, no deepcopy @@ -888,6 +889,8 @@ def lower_qmodel_triton( efficiency at the expense of higher chance of accumulation "overflow". For example, an INT24 accumulator can only hold values ranged from -2^23 to 2^23 -1, as opposed to typical range -2^31 to -2^31 -1. + clamp_acc_to_dl16: clamp local accumulator to DL16 (1-6-9) range. To simulate this special + dtype effect on accumulation. num_lsb_to_truncate: number of bits to truncate from LSB side. For example, given fp32 is s1e8m23, if we choose to truncate 13 mantissa bits from right most side, i.e. LSB, the resulting number will be s1e8m10, which is TF32. @@ -900,25 +903,47 @@ def lower_qmodel_triton( from torch.ao.quantization.utils import _parent_name # Local - from fms_mo.modules.linear import QLinear, QLinearINT8Deploy + from fms_mo.modules.linear import LinearFPxAcc, QLinear, QLinearINT8Deploy + + # Currently QLinearINT8 has more options in dynamic quantization than LinearFP. Here we resolve + # the differences as a patch solution (will unify the codes in future release) + linFP_dyn_code = ( + "per_token" + if use_dyn_max_act in [-1, -2] + else "per_tensor" + if use_dyn_max_act + else False + ) for name, m in model.named_modules(): - if not isinstance(m, QLinear): + if not isinstance(m, (QLinear, torch.nn.Linear)): continue parent_name, module_name = _parent_name(name) parent_mod = model.get_submodule(parent_name) - qmod = getattr(parent_mod, module_name) - setattr( - parent_mod, - module_name, - QLinearINT8Deploy.from_fms_mo( - qmod, + + # Only support simulations of 1) QLinear -> INT8, 2) nnLinear->FP8 for now + if isinstance(m, QLinear): + new_lin = QLinearINT8Deploy.from_fms_mo( + m, use_int_kernel="triton", use_dynamic_max_act_Qfunc=use_dyn_max_act, max_acc_bits=max_acc_bits, truncate_lsb=num_lsb_to_truncate, chunk_size=chunk_size, - ), + ) + else: + new_lin = LinearFPxAcc.from_nn( + m, + trun_bits=max_acc_bits, + chunk_size=chunk_size, + dynamic_fp8=linFP_dyn_code, + clamp_acc_to_dl16=clamp_acc_to_dl16, + ) + + setattr( + parent_mod, + module_name, + new_lin, ) logger.info(f"\nModel lowering with triton kernel is done.\n{model}") diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py index 345187b6..54b1efc6 100644 --- a/fms_mo/modules/linear.py +++ b/fms_mo/modules/linear.py @@ -1899,7 +1899,16 @@ class LinearFuncFPxFwdBwd(torch.autograd.Function): """ @staticmethod - def forward(ctx, x, weight, bias=None, trun_bits=0, chunk_size=16, fp8_dyn=False): + def forward( + ctx, + x, + weight, + bias=None, + trun_bits=0, + chunk_size=16, + fp8_dyn=False, + clamp_acc_to_dl16=False, + ): assert x.dtype in [torch.float, torch.bfloat16, torch.float16] # input can be 2D or 3D, need to reshape before tl_matmul org_dtype = x.dtype @@ -1916,6 +1925,7 @@ def forward(ctx, x, weight, bias=None, trun_bits=0, chunk_size=16, fp8_dyn=False ctx.trun_bits = trun_bits ctx.chunk_size = chunk_size ctx.fp8_dyn = fp8_dyn + ctx.clamp_acc_to_dl16 = clamp_acc_to_dl16 if fp8_dyn: # use Q/dQ simulation for now, meaning still compute in fp16/bf16 @@ -1936,6 +1946,7 @@ def forward(ctx, x, weight, bias=None, trun_bits=0, chunk_size=16, fp8_dyn=False weight.t().to(org_dtype), chunk_trun_bits=trun_bits, chunk_size=chunk_size, + clamp_acc_to_dl16=clamp_acc_to_dl16, ).reshape(target_shape_output) if bias is not None: @@ -1976,6 +1987,7 @@ def backward(ctx, grad_output): x, chunk_trun_bits=trun_bits, chunk_size=chunk_size, + clamp_acc_to_dl16=ctx.clamp_acc_to_dl16, ).to(weight.dtype) # Compute grad_input in 2D then reshape to target shape, could be 3D or 2D grad_input = ( @@ -1984,6 +1996,7 @@ def backward(ctx, grad_output): weight.to(dtype_input), chunk_trun_bits=trun_bits, chunk_size=chunk_size, + clamp_acc_to_dl16=ctx.clamp_acc_to_dl16, ) .reshape(target_shape_grad_input) .to(dtype_input) @@ -1994,7 +2007,7 @@ def backward(ctx, grad_output): else: grad_bias = grad_output_2D.sum(0).to(ctx.bias_dtype) - return grad_input, grad_weight, grad_bias, None, None, None + return grad_input, grad_weight, grad_bias, None, None, None, None class LinearFPxAcc(torch.nn.Linear): @@ -2016,6 +2029,10 @@ def from_nn(cls, nnlin, trun_bits=0, **kwargs): cls (class): The class to be created. nnlin (torch.nn.Linear): The original torch.nn.Linear module. trun_bits (int): truncate [0 to 22] LSBs from FP32 accumulation. + dynamic_fp8: whether to use dynamic quantization for fp8 activations, available options + are ["per_tensor", "per_token", False] + clamp_acc_to_dl16: clamp local accumulator into DL16 range, to simulate the effect of + this special dtype **kwargs: Additional keyword arguments. Returns: @@ -2037,7 +2054,7 @@ def from_nn(cls, nnlin, trun_bits=0, **kwargs): lin24acc.trun_bits = trun_bits lin24acc.chunk_size = kwargs.get("chunk_size", False) lin24acc.fp8_dyn = kwargs.get("dynamic_fp8", False) - # available options are ["per_tensor", "per_token"] + lin24acc.clamp_acc_to_dl16 = kwargs.get("clamp_acc_to_dl16", False) if nnlin.bias is not None: lin24acc.bias = nnlin.bias @@ -2052,6 +2069,7 @@ def forward(self, inputs): self.trun_bits, self.chunk_size, self.fp8_dyn, + self.clamp_acc_to_dl16, ) def extra_repr(self) -> str: From a8d6ea9def3ba0f684fc5c7fe38c65328b678231 Mon Sep 17 00:00:00 2001 From: cliu-us Date: Fri, 27 Jun 2025 03:07:26 +0000 Subject: [PATCH 02/11] zero out last 13 bits Signed-off-by: cliu-us --- fms_mo/custom_ext_kernels/triton_kernels.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fms_mo/custom_ext_kernels/triton_kernels.py b/fms_mo/custom_ext_kernels/triton_kernels.py index 0aa08421..eea50d92 100644 --- a/fms_mo/custom_ext_kernels/triton_kernels.py +++ b/fms_mo/custom_ext_kernels/triton_kernels.py @@ -448,10 +448,10 @@ def round_and_trun(x, round_bit, trun_mask): @triton.jit def fp32_clamp_to_dl16(x): """clamp FP32 (1-8-23) TENSOR x to DL16 (1-6-9) range.""" - # 1. rounding, add round bit to full uint representation + # 1. rounding: add round bit to full uint representation, zero out last 13 bits, back to float x = libdevice.float_as_uint(x) round_bit = 1 << (23 - 9 - 1) - x = libdevice.uint_as_float(x + round_bit) + x = libdevice.uint_as_float(((x + round_bit) >> 13) << 13) # 2. clamp to min/max: # max = 2^32 * 1.(1111 1111 0)_base2 => 2^32*1.(1111 1111 1) will become inf From 50648c7a4c2403a6435036add80c29c2c3d081b4 Mon Sep 17 00:00:00 2001 From: cliu-us Date: Fri, 27 Jun 2025 03:16:14 +0000 Subject: [PATCH 03/11] bug fix Signed-off-by: cliu-us --- fms_mo/custom_ext_kernels/triton_kernels.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fms_mo/custom_ext_kernels/triton_kernels.py b/fms_mo/custom_ext_kernels/triton_kernels.py index eea50d92..6e7b6ab7 100644 --- a/fms_mo/custom_ext_kernels/triton_kernels.py +++ b/fms_mo/custom_ext_kernels/triton_kernels.py @@ -184,7 +184,7 @@ def matmul_kernel( if chunk_trun_bits > 0: accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask) if clamp_acc_to_dl16: - accumulator = fp32_clamp_to_dl16(accumulator) + accumulator_inner = fp32_clamp_to_dl16(accumulator_inner) ## --------------------------------------------------------- if truncate_then_accumulate: accumulator += accumulator_inner @@ -411,7 +411,7 @@ def matmul_kernel_DABC( if chunk_trun_bits > 0: accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask) if clamp_acc_to_dl16: - accumulator = fp32_clamp_to_dl16(accumulator) + accumulator_inner = fp32_clamp_to_dl16(accumulator_inner) ## --------------------------------------------------------- if truncate_then_accumulate: accumulator += accumulator_inner From 81fe1bde39104673bbcee901588b0d855a5e641e Mon Sep 17 00:00:00 2001 From: cliu-us Date: Fri, 27 Jun 2025 03:32:15 +0000 Subject: [PATCH 04/11] simplify LinearFPx repr str Signed-off-by: cliu-us --- fms_mo/modules/linear.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py index 54b1efc6..74c46810 100644 --- a/fms_mo/modules/linear.py +++ b/fms_mo/modules/linear.py @@ -2076,10 +2076,17 @@ def extra_repr(self) -> str: """ Returns an alternative string representation of the object. """ - return ( - f"in={self.in_features}, out={self.out_features}, bias={self.bias is not None}, " - f"trun_bits={self.trun_bits},fp8_dyn={self.fp8_dyn},chunk_size={self.chunk_size}" - ) + repr_str = f"{self.in_features},{self.out_features}" + if self.bias is not None: + repr_str += f",bias={self.bias is not None}" + if self.trun_bits > 0: + repr_str += f",trun_bits={self.trun_bits}" + if self.fp8_dyn: + repr_str += f",fp8_dyn={self.fp8_dyn}" + if self.clamp_acc_to_dl16: + repr_str += f",use_DL16_acc" + repr_str += f",chunk_size={self.chunk_size}" + return repr_str class LinearFuncINT8FwdFP32Bwd(torch.autograd.Function): From c7cd06c8de018cec6b6a25935357674893de1340 Mon Sep 17 00:00:00 2001 From: cliu-us Date: Fri, 27 Jun 2025 03:48:58 +0000 Subject: [PATCH 05/11] lower_qmodel_triton() can skip layers if needed Signed-off-by: cliu-us --- fms_mo/custom_ext_kernels/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fms_mo/custom_ext_kernels/utils.py b/fms_mo/custom_ext_kernels/utils.py index 0ac95c6a..75e1493b 100644 --- a/fms_mo/custom_ext_kernels/utils.py +++ b/fms_mo/custom_ext_kernels/utils.py @@ -873,6 +873,7 @@ def lower_qmodel_triton( clamp_acc_to_dl16=False, num_lsb_to_truncate=0, chunk_size=32, + layer_to_exclude=[], ): """ Examplar GPU lowering function using triton. Only swap Linear/Qlinear in transformers. @@ -916,7 +917,7 @@ def lower_qmodel_triton( ) for name, m in model.named_modules(): - if not isinstance(m, (QLinear, torch.nn.Linear)): + if not isinstance(m, (QLinear, torch.nn.Linear)) or name in layer_to_exclude: continue parent_name, module_name = _parent_name(name) parent_mod = model.get_submodule(parent_name) From 774e68e6462046cb1ebde2df16a021e95ad9ba36 Mon Sep 17 00:00:00 2001 From: cliu-us Date: Fri, 27 Jun 2025 03:49:23 +0000 Subject: [PATCH 06/11] linting Signed-off-by: cliu-us --- fms_mo/custom_ext_kernels/utils.py | 5 ++++- fms_mo/modules/linear.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/fms_mo/custom_ext_kernels/utils.py b/fms_mo/custom_ext_kernels/utils.py index 75e1493b..c87d6aba 100644 --- a/fms_mo/custom_ext_kernels/utils.py +++ b/fms_mo/custom_ext_kernels/utils.py @@ -873,7 +873,7 @@ def lower_qmodel_triton( clamp_acc_to_dl16=False, num_lsb_to_truncate=0, chunk_size=32, - layer_to_exclude=[], + layer_to_exclude=None, ): """ Examplar GPU lowering function using triton. Only swap Linear/Qlinear in transformers. @@ -916,6 +916,9 @@ def lower_qmodel_triton( else False ) + if layer_to_exclude is None: + layer_to_exclude = [] + for name, m in model.named_modules(): if not isinstance(m, (QLinear, torch.nn.Linear)) or name in layer_to_exclude: continue diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py index 74c46810..ee1f7ca0 100644 --- a/fms_mo/modules/linear.py +++ b/fms_mo/modules/linear.py @@ -2084,7 +2084,7 @@ def extra_repr(self) -> str: if self.fp8_dyn: repr_str += f",fp8_dyn={self.fp8_dyn}" if self.clamp_acc_to_dl16: - repr_str += f",use_DL16_acc" + repr_str += ",use_DL16_acc" repr_str += f",chunk_size={self.chunk_size}" return repr_str From 68de6c1eec2e3b01a62d1f0706c049d311f11e5b Mon Sep 17 00:00:00 2001 From: cliu-us Date: Fri, 27 Jun 2025 04:00:51 +0000 Subject: [PATCH 07/11] bug fix Signed-off-by: cliu-us --- fms_mo/custom_ext_kernels/utils.py | 2 +- fms_mo/modules/linear.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fms_mo/custom_ext_kernels/utils.py b/fms_mo/custom_ext_kernels/utils.py index c87d6aba..203a4217 100644 --- a/fms_mo/custom_ext_kernels/utils.py +++ b/fms_mo/custom_ext_kernels/utils.py @@ -938,7 +938,7 @@ def lower_qmodel_triton( else: new_lin = LinearFPxAcc.from_nn( m, - trun_bits=max_acc_bits, + trun_bits=num_lsb_to_truncate, chunk_size=chunk_size, dynamic_fp8=linFP_dyn_code, clamp_acc_to_dl16=clamp_acc_to_dl16, diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py index ee1f7ca0..8bd4711e 100644 --- a/fms_mo/modules/linear.py +++ b/fms_mo/modules/linear.py @@ -2047,7 +2047,7 @@ def from_nn(cls, nnlin, trun_bits=0, **kwargs): nnlin.in_features, nnlin.out_features, bias=nnlin.bias is not None, - device=target_device, + device="meta", ) lin24acc.weight = nnlin.weight From 275d47df5322d90d9b36bd6f0b3b859ad78758d5 Mon Sep 17 00:00:00 2001 From: cliu-us Date: Mon, 7 Jul 2025 16:54:10 -0400 Subject: [PATCH 08/11] new triton verson doesn't like 0xFFFFFFFF as a const Signed-off-by: cliu-us --- fms_mo/custom_ext_kernels/triton_kernels.py | 9 ++++--- fms_mo/custom_ext_kernels/utils.py | 6 +++++ fms_mo/dq.py | 29 ++++++++++++--------- fms_mo/modules/linear.py | 8 ++++-- fms_mo/training_args.py | 11 ++++++-- 5 files changed, 42 insertions(+), 21 deletions(-) diff --git a/fms_mo/custom_ext_kernels/triton_kernels.py b/fms_mo/custom_ext_kernels/triton_kernels.py index 6e7b6ab7..3105ca44 100644 --- a/fms_mo/custom_ext_kernels/triton_kernels.py +++ b/fms_mo/custom_ext_kernels/triton_kernels.py @@ -164,7 +164,7 @@ def matmul_kernel( # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000 # 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080 - trun_mask = tl.cast((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits, tl.uint32) + trun_mask = ~tl.cast((1 << chunk_trun_bits) - 1, tl.uint32) round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0 ## --------------------------------------------------------- @@ -386,7 +386,7 @@ def matmul_kernel_DABC( # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000 # 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080 - trun_mask = tl.cast((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits, tl.uint32) + trun_mask = ~tl.cast((1 << chunk_trun_bits) - 1, tl.uint32) round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0 ## --------------------------------------------------------- @@ -448,10 +448,11 @@ def round_and_trun(x, round_bit, trun_mask): @triton.jit def fp32_clamp_to_dl16(x): """clamp FP32 (1-8-23) TENSOR x to DL16 (1-6-9) range.""" - # 1. rounding: add round bit to full uint representation, zero out last 13 bits, back to float + # 1. rounding: add round bit, zero out last 13 bits, back to float x = libdevice.float_as_uint(x) round_bit = 1 << (23 - 9 - 1) - x = libdevice.uint_as_float(((x + round_bit) >> 13) << 13) + mask_13x0 = ~tl.cast((1 << 13) - 1, tl.uint32) + x = libdevice.uint_as_float((x + round_bit) & mask_13x0) # 2. clamp to min/max: # max = 2^32 * 1.(1111 1111 0)_base2 => 2^32*1.(1111 1111 1) will become inf diff --git a/fms_mo/custom_ext_kernels/utils.py b/fms_mo/custom_ext_kernels/utils.py index 203a4217..65ca617a 100644 --- a/fms_mo/custom_ext_kernels/utils.py +++ b/fms_mo/custom_ext_kernels/utils.py @@ -918,6 +918,12 @@ def lower_qmodel_triton( if layer_to_exclude is None: layer_to_exclude = [] + elif isinstance(layer_to_exclude, str): + layer_to_exclude = [ + layer_to_exclude, + ] + elif not isinstance(layer_to_exclude, (list, tuple)): + raise RuntimeError("layer_to_exclude has to be either str, list, or tuple.") for name, m in model.named_modules(): if not isinstance(m, (QLinear, torch.nn.Linear)) or name in layer_to_exclude: diff --git a/fms_mo/dq.py b/fms_mo/dq.py index eef16049..6fd3bb27 100644 --- a/fms_mo/dq.py +++ b/fms_mo/dq.py @@ -216,17 +216,18 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): act_scales = get_act_scales(model, dq_dataloader, qcfg) torch.save(act_scales, scale_file) - qmodel_prep( - model, - dq_dataloader, - qcfg, - use_layer_name_pattern_matching=use_layer_name_pattern_matching, - use_dynamo=use_dynamo, - dev=dev, - save_fname="dq", - ) - logger.info(f"Quantized model {model}") - logger.info("==" * 20) + if fms_mo_args.aiu_sim_triton != "fp8": + qmodel_prep( + model, + dq_dataloader, + qcfg, + use_layer_name_pattern_matching=use_layer_name_pattern_matching, + use_dynamo=use_dynamo, + dev=dev, + save_fname="dq", + ) + logger.info(f"Quantized model {model}") + logger.info("==" * 20) if qcfg["smoothq"]: logger.info("Starting to apply smooth scale") @@ -260,14 +261,16 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): tokenizer.save_pretrained(opt_args.output_dir) if fms_mo_args.aiu_sim_triton: + # NOTE plz apply correct HW settings here, defaults are not real HW params lower_qmodel_triton( model, use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False, max_acc_bits=qcfg.get("max_acc_bits", 32), num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0), - chunk_size=qcfg.get("chunk_size", 1024), + chunk_size=qcfg.get("chunk_size", 32), # 1024 + clamp_acc_to_dl16=False, # fms_mo_args.aiu_sim_triton == "fp8" + # layer_to_exclude=["lm_head",] ) - if fms_mo_args.eval_ppl: path_test = Path(data_args.test_data_path) arrow_files = list(path_test.glob("*.arrow")) diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py index 8bd4711e..23457d0f 100644 --- a/fms_mo/modules/linear.py +++ b/fms_mo/modules/linear.py @@ -1934,8 +1934,12 @@ def forward( ctx.fp8_e4m3_max = torch.finfo(torch.float8_e4m3fn).max ctx.fp8_e5m2_max = torch.finfo(torch.float8_e5m2).max reduce_dim = None if fp8_dyn == "per_tensor" else 1 - x_scale = x.abs().amax(dim=reduce_dim) / ctx.fp8_e4m3_max - w_scale = weight.abs().amax(dim=reduce_dim) / ctx.fp8_e4m3_max + x_scale = ( + x.abs().amax(dim=reduce_dim, keepdim=True) / ctx.fp8_e4m3_max + ).clamp(min=1e-5) + w_scale = ( + weight.abs().amax(dim=reduce_dim, keepdim=True) / ctx.fp8_e4m3_max + ).clamp(min=1e-5) x = (x / x_scale).to(torch.float8_e4m3fn).to(org_dtype) * x_scale weight = (weight / w_scale).to(torch.float8_e4m3fn).to(org_dtype) * w_scale diff --git a/fms_mo/training_args.py b/fms_mo/training_args.py index e7beafc6..344cf244 100644 --- a/fms_mo/training_args.py +++ b/fms_mo/training_args.py @@ -181,8 +181,15 @@ class FMSMOArguments(TypeChecker): default=2048, metadata={"help": "input sequence length after tokenization"} ) eval_ppl: bool = field(default=False) - aiu_sim_triton: bool = field( - default=False, metadata={"help": ("AIU simulation with triton kernel")} + aiu_sim_triton: str = field( + default=None, + metadata={ + "help": ( + "AIU simulation with triton kernel. ['int8', 'fp8', None]\n" + "'int8' mode will trigger qmodel_prep() and swap QLinears" + "'fp8' mode will directly replace existing nn.Linears" + ) + }, ) recompute_narrow_weights: bool = field( default=False, From b685ea8b921b1aa9f555f89b387b36f2cc75d5c7 Mon Sep 17 00:00:00 2001 From: cliu-us Date: Wed, 9 Jul 2025 05:40:05 +0000 Subject: [PATCH 09/11] fix triton DL16 aiu sim with subnorm flushing Signed-off-by: cliu-us --- fms_mo/custom_ext_kernels/triton_kernels.py | 108 +++++++++++-------- fms_mo/dq.py | 31 +++++- fms_mo/modules/linear.py | 11 ++ fms_mo/quant/ptq.py | 113 +++++++++++++++++++- fms_mo/training_args.py | 2 +- fms_mo/utils/dq_utils.py | 28 +++-- 6 files changed, 229 insertions(+), 64 deletions(-) diff --git a/fms_mo/custom_ext_kernels/triton_kernels.py b/fms_mo/custom_ext_kernels/triton_kernels.py index 3105ca44..41d46ca2 100644 --- a/fms_mo/custom_ext_kernels/triton_kernels.py +++ b/fms_mo/custom_ext_kernels/triton_kernels.py @@ -160,13 +160,8 @@ def matmul_kernel( # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block # of fp32 values for higher accuracy. accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - ## ------ prepare LSB rounding/truncation masks ------- - # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b - # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000 - # 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080 - trun_mask = ~tl.cast((1 << chunk_trun_bits) - 1, tl.uint32) - round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0 - ## --------------------------------------------------------- + ## ------ prepare LSB rounding/truncation masks outside the for loop ------- + round_bit, trun_mask = round_and_trun_mask(chunk_trun_bits, clamp_acc_to_dl16) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the K dimension. @@ -181,10 +176,10 @@ def matmul_kernel( # tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp ## ------ add chunky LSB rounding/masking -------- - if chunk_trun_bits > 0: - accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask) - if clamp_acc_to_dl16: - accumulator_inner = fp32_clamp_to_dl16(accumulator_inner) + if clamp_acc_to_dl16 or chunk_trun_bits > 0: + accumulator_inner = round_and_trun( + accumulator_inner, round_bit, trun_mask, clamp_acc_to_dl16 + ) ## --------------------------------------------------------- if truncate_then_accumulate: accumulator += accumulator_inner @@ -382,13 +377,8 @@ def matmul_kernel_DABC( # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block # of fp32 values for higher accuracy, i.e. C should have been cast to fp32 already accumulator = tl.load(c_ptrs, mask=c_mask, other=0.0) - ## ------ prepare LSB rounding/truncation masks ------- - # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b - # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000 - # 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080 - trun_mask = ~tl.cast((1 << chunk_trun_bits) - 1, tl.uint32) - round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0 - ## --------------------------------------------------------- + ## ------ prepare LSB rounding/truncation masks outside the for loop ------- + round_bit, trun_mask = round_and_trun_mask(chunk_trun_bits, clamp_acc_to_dl16) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A, B, and C, generate a mask by checking the K dimension. @@ -408,10 +398,10 @@ def matmul_kernel_DABC( # precision as well, hence, could lose some precision! ## ------ add chunky LSB rounding/masking -------- - if chunk_trun_bits > 0: - accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask) - if clamp_acc_to_dl16: - accumulator_inner = fp32_clamp_to_dl16(accumulator_inner) + if clamp_acc_to_dl16 or chunk_trun_bits > 0: + accumulator_inner = round_and_trun( + accumulator_inner, round_bit, trun_mask, clamp_acc_to_dl16 + ) ## --------------------------------------------------------- if truncate_then_accumulate: accumulator += accumulator_inner @@ -440,34 +430,64 @@ def leaky_relu(x): @triton.jit -def round_and_trun(x, round_bit, trun_mask): - """Round and truncate (usually for accumulator).""" - return libdevice.uint_as_float((libdevice.float_as_uint(x) + round_bit) & trun_mask) +def round_and_trun_mask(chunk_trun_bits, clamp_acc_to_dl16): + """ + Rounding and LSB truncation masks only need to be generated once. + These mask will be applied on "inner" accumulator, which is alway FP32 (e8m23). We may truncate + up to 23b for mantissa. If DL16/DL8, pay attention to exponent bias. + Examples: 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000 + 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080 + """ + if clamp_acc_to_dl16: + # DL16 is e6m9, hence, truncate 23 - 9 = 14 bits + chunk_trun_bits = 14 + round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0 + trun_mask = ~tl.cast((1 << chunk_trun_bits) - 1, tl.uint32) + return round_bit, trun_mask @triton.jit -def fp32_clamp_to_dl16(x): - """clamp FP32 (1-8-23) TENSOR x to DL16 (1-6-9) range.""" - # 1. rounding: add round bit, zero out last 13 bits, back to float - x = libdevice.float_as_uint(x) - round_bit = 1 << (23 - 9 - 1) - mask_13x0 = ~tl.cast((1 << 13) - 1, tl.uint32) - x = libdevice.uint_as_float((x + round_bit) & mask_13x0) - - # 2. clamp to min/max: - # max = 2^32 * 1.(1111 1111 0)_base2 => 2^32*1.(1111 1111 1) will become inf - # (32 + 127) << 23 | (0xFF8 << (23 - 12)) in FP32 is 8581545984.0 - # min = 2^-31 * 1.(0000 0000 1)_base2 => set to 0 for those smaller than this - # (-31 + 127) << 23 | (1 << (23 - 9)) in FP32 is 4.665707820095122e-10 - dl16_max = 8581545984.0 - dl16_min = 4.665707820095122e-10 - x = tl.where(x >= dl16_max, float("inf"), x) - x = tl.where(x <= -dl16_max, float("-inf"), x) - x = tl.where(tl.abs(x) < dl16_min, 0, x) - +def round_and_trun(x, round_bit, trun_mask, clamp_acc_to_dl16): + """Round and truncate (usually for accumulator).""" + x = libdevice.uint_as_float((libdevice.float_as_uint(x) + round_bit) & trun_mask) + + if clamp_acc_to_dl16: + # clamp to DL16 min/max: + # max = 2^32 * 1.(1111 1111 0)_base2 = 2^32*(2 - 2^-9) = 8581545984.0 + # greater than this will become +inf (or -inf) + # min = 2^-31 * 1.(0000 0000 1)_base2 = 2^-31*(1 + 2^-9)> = 4.665707820095122e-10 + # smaller than this will become 0 + dl16_max = 8581545984.0 + dl16_min = 4.665707820095122e-10 + x = tl.where(x >= dl16_max, float("inf"), x) + x = tl.where(x <= -dl16_max, float("-inf"), x) + x = tl.where(tl.abs(x) < dl16_min, 0, x) return x +# @triton.jit +# def fp32_clamp_to_dl16(x): +# """clamp FP32 (1-8-23) TENSOR x to DL16 (1-6-9) range.""" +# # 1. rounding: add round bit, zero out last 13 bits, back to float +# x = libdevice.float_as_uint(x) +# round_bit = 1 << (23 - 9 - 1) +# mask_13x0 = ~tl.cast((1 << 13) - 1, tl.uint32) +# x = libdevice.uint_as_float((x + round_bit) & mask_13x0) + +# # 2. clamp to min/max: +# # max = 2^32 * 1.(1111 1111 0)_base2 => 2^32*1.(1111 1111 1) will become inf +# # (32 + 127) << 23 | (0xFF8 << (23 - 12)) in FP32 is 8581545984.0 +# # min = 2^-31 * 1.(0000 0000 1)_base2 => set to 0 for those smaller than this +# # (-31 + 127) << 23 | (1 << (23 - 9)) in FP32 is 4.665707820095122e-10 +# dl16_max = 8581545984.0 +# dl16_min = 4.665707820095122e-10 +# x = tl.where(x >= dl16_max, float("inf"), x) +# x = tl.where(x <= -dl16_max, float("-inf"), x) +# x = tl.where(tl.abs(x) < dl16_min, 0, x) + +# return x + + def tl_matmul_chunk_truncate( a, b, diff --git a/fms_mo/dq.py b/fms_mo/dq.py index 6fd3bb27..bee79fa9 100644 --- a/fms_mo/dq.py +++ b/fms_mo/dq.py @@ -161,6 +161,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): # config layers to skip, smooth scale config_quantize_smooth_layers(qcfg) + use_dynamo = True + # use dynamo as default unless really needed, False -> fallback to TorchScript tracing if any(x != 32 for x in attn_bits): logger.info("Quantize attention bmms or kvcache, will use dynamo for prep") use_layer_name_pattern_matching = False @@ -168,11 +170,9 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): assert ( qcfg["qlayer_name_pattern"] == [] ), "ensure nothing in qlayer_name_pattern when use dynamo" - use_dynamo = True else: logger.info("Attention bmms will not be quantized.") use_layer_name_pattern_matching = True - use_dynamo = False qcfg["seq_len"] = block_size qcfg["model"] = model_args.model_name_or_path @@ -271,6 +271,33 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): clamp_acc_to_dl16=False, # fms_mo_args.aiu_sim_triton == "fp8" # layer_to_exclude=["lm_head",] ) + # [CL] -------- record W, A, qW, qA with hooks ---------------- + # from fms_mo.modules.linear import QLinear, QLinearINT8Deploy + # from fms_mo.quant.ptq import HookRecPostQuantInOut + # cache_dict = {} + # hook_handles = [] + # for n, m in model.named_modules(): + # if not isinstance(m, (QLinear, QLinearINT8Deploy, torch.nn.Linear)): + # continue + + # m.mod_name = n + # hook_handles.append( + # m.register_forward_hook( HookRecPostQuantInOut(cache_dict, n)) + # ) + + # data_mb = next(iter(eval_dataloader)) + # with torch.no_grad(): + # model(**data_mb) + + # for h in hook_handles: + # h.remove() + + # torch.save( + # cache_dict, + # f"roberta_sqv2_data_dump_{qcfg['qa_mode']}_{qcfg['qw_mode']}_chunk64_lsb{args.aiu_int_lsb_trun}_dq.pt" + # ) + # return + if fms_mo_args.eval_ppl: path_test = Path(data_args.test_data_path) arrow_files = list(path_test.glob("*.arrow")) diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py index 124dfd31..65f9378d 100644 --- a/fms_mo/modules/linear.py +++ b/fms_mo/modules/linear.py @@ -1926,6 +1926,7 @@ def forward( ctx.chunk_size = chunk_size ctx.fp8_dyn = fp8_dyn ctx.clamp_acc_to_dl16 = clamp_acc_to_dl16 + ctx.dl8_min = 0.0087890625 if fp8_dyn: # use Q/dQ simulation for now, meaning still compute in fp16/bf16 @@ -1943,6 +1944,11 @@ def forward( x = (x / x_scale).to(torch.float8_e4m3fn).to(org_dtype) * x_scale weight = (weight / w_scale).to(torch.float8_e4m3fn).to(org_dtype) * w_scale + if clamp_acc_to_dl16: + # NOTE For DL8@DL8 acc in DL16, as DL8 doesn't support subnorm numbers like PyTorch + # (whose real min for e4m3fn is 2^-9), need to flush subnorm numbers to 0 + x.masked_fill_(x < ctx.dl8_min, 0) + weight.masked_fill_(weight < ctx.dl8_min, 0) # triton kernel assumes 2D inputs and cast the return to input.dtype output = tl_matmul( @@ -1983,6 +1989,11 @@ def backward(ctx, grad_output): grad_output_2D = (grad_output_2D / grad_out_scale).to(torch.float8_e5m2).to( grad_output.dtype ) * grad_out_scale + if ctx.clamp_acc_to_dl16: + # flush subnorm numbers to 0 as DL8 doesn't support it + x.masked_fill_(x < ctx.dl8_min, 0) + weight.masked_fill_(weight < ctx.dl8_min, 0) + grad_output_2D.masked_fill_(grad_output_2D < ctx.dl8_min, 0) # Compute grad_weight, shape = [out, in] # NOTE: this triton kernel requires A matrix to be contiguous diff --git a/fms_mo/quant/ptq.py b/fms_mo/quant/ptq.py index 7801284b..de2c5729 100644 --- a/fms_mo/quant/ptq.py +++ b/fms_mo/quant/ptq.py @@ -42,6 +42,7 @@ # Local from fms_mo.modules import QBmm, QLinear from fms_mo.modules.conv import QConv2dPTQv2 +from fms_mo.modules.linear import LinearFPxAcc, QLinearINT8Deploy from fms_mo.quant.quantizers import ( AdaRoundQuantizer, Qdynamic, @@ -481,8 +482,118 @@ def __call__(self, mod, inputs, *args, **_kwargs): assert not self.stop_after_rec -# this hook is meant for ptq_loss_func == 'fisher_diag' and to temp hold the "Q_out" of the module +class HookRecPostQuantInOut(torch.nn.Module): + """Another simplified hook to check post-quantized input/output, e.g. within +-127 for INT8.""" + + def __init__(self, cache_dict={}, mod_name=None): + super().__init__() + self.cache_dict = cache_dict + self.mod_name = mod_name + name_split = mod_name.split(".") + self.lay_idx = int(name_split[3]) + self.lay_key = name_split[6] + + self.cache_dev = "cpu" + # prepare empty dict for later use + self.cache_dict[mod_name] = {} + self.fwd_mapping = { + LinearFPxAcc: self.call_func_for_fpxacc, + QLinear: self.call_func_for_qlinear, + QLinearINT8Deploy: self.call_func_for_qlinear_int, + torch.nn.Linear: self.call_func_for_nnlinear, + } + + def call_func_for_fpxacc(self, mod, inputs, outputs, **_kwargs): + raise NotImplementedError + + def call_func_for_qlinear(self, mod, inputs, outputs, **_kwargs): + lay_idx = self.lay_idx + lay_key = self.lay_key + mod_name = self.mod_name + cache_dict = self.cache_dict + + act_max = inputs[0].abs().amax(dim=[d for d in range(len(inputs[0].shape) - 1)]) + # mod.smoothq_act_scale + w_max = mod.weight.abs().max(dim=0, keepdim=True)[0].clamp(min=1e-5) + is_smq_layer = not torch.all(act_max == 0).item() + # smoothQ scale = smoothq_act_scale**alpha / weight_scale**(1.0 - alpha) + # smoothq_scale = mod.get_smoothq_scale(inputs[0]) + smoothq_scale = getattr(mod, "smq_scale", 1.0) + # "smq_scale" only available in QLin_INT8 + + with torch.no_grad(): + smoothed_inp = inputs[0] / smoothq_scale + smoothed_w = mod.weight * smoothq_scale + + # this is assuming pertokenmax quantizer, NOTE calc quant scale after smoothing + absmax = smoothed_inp.abs().max(dim=-1, keepdim=True)[0] + qa_scale = absmax.clamp(min=1e-5) / 127 + qinput = torch.round(smoothed_inp / qa_scale).clamp(-127, 127) + # should clamp to -128? + if mod.qa_mode == "pertokenmax": + # doesnt implement dequant=False yet, do it manually + cva = mod.quantize_feature.clip_val + qa_scale = cva.clamp(min=1e-5).div(127) + qinput = smoothed_inp.div(qa_scale).round() + else: + mod.quantize_feature.dequantize = False + qinput = mod.quantize_feature(smoothed_inp) + mod.quantize_feature.dequantize = True + + # also record quantized, smoothed W in INT8, calc both maxperCh and SAWBperCh + cvw = mod.quantize_weight.clip_val + scale_w = cvw / 127 + mod.quantize_weight.dequantize = False + qw = mod.quantize_weight(smoothed_w) + mod.quantize_weight.dequantize = True + + # inputs is a tuple, QLinear only has 1 valid input + cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev) + cache_dict[mod_name]["cva"] = cva.to(self.cache_dev) + cache_dict[mod_name]["cvw"] = cvw.to(self.cache_dev) + cache_dict[mod_name]["smoothed_input"] = smoothed_inp.to(self.cache_dev) + cache_dict[mod_name]["smoothed_weight"] = smoothed_w.to(self.cache_dev) + cache_dict[mod_name]["qinput"] = qinput.to(self.cache_dev) + # NOTE in INT8, *scales if need dQ + cache_dict[mod_name]["qweight"] = qw.to(self.cache_dev) + # torch.round(smoothed_w.T/scale_w).clamp(-127, 127).to(self.cache_dev) + # cache_dict[mod_name]["qoutput"] = outputs.to(self.cache_dev) + + def call_func_for_qlinear_int(self, mod, inputs, outputs, **_kwargs): + smoothq_scale = getattr(mod, "smq_scale", 1.0) + mod_name = self.mod_name + cache_dict = self.cache_dict + with torch.no_grad(): + if mod.useDynMaxQfunc in [-1, -2]: + qinput = mod.qa_dynamic_max_qfunc(inputs[0]) + elif mod.use_fake_zero_shift: + qinput = mod.qa_dyn_max_fake_zero_shift(inputs[0]) + elif mod.usePTnativeQfunc: + qinput = mod.qa_raw_qfunc(inputs[0]) + else: + qinput = mod.qa_fmo_mo_qfunc(inputs[0]) + + # inputs is a tuple, QLinear only has 1 valid input + cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev) + cache_dict[mod_name]["cva"] = mod.cvs[0].to(self.cache_dev) + cache_dict[mod_name]["cvw"] = mod.cvs[2].to(self.cache_dev) + cache_dict[mod_name]["qinput"] = qinput.to(self.cache_dev) + cache_dict[mod_name]["qweight"] = mod.weight.to(self.cache_dev) + + def call_func_for_nnlinear(self, mod, inputs, outputs, **_kwargs): + mod_name = self.mod_name + cache_dict = self.cache_dict + cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev) + cache_dict[mod_name]["weight"] = mod.weight.to(self.cache_dev) + + def __call__(self, mod, inputs, outputs, **_kwargs): + self.fwd_mapping[type(mod)](mod, inputs, outputs, **_kwargs) + + class PTQHookRecQOut(nn.Module): + """This hook is for ptq_loss_func == 'fisher_diag' and will temporarily hold the "Q_out" of the + module""" + def __init__(self, qcfg): super().__init__() self.qcfg = qcfg diff --git a/fms_mo/training_args.py b/fms_mo/training_args.py index 54a4adb9..661f72bd 100644 --- a/fms_mo/training_args.py +++ b/fms_mo/training_args.py @@ -192,7 +192,7 @@ class FMSMOArguments(TypeChecker): default=2048, metadata={"help": "input sequence length after tokenization"} ) eval_ppl: bool = field(default=False) - aiu_sim_triton: str = field( + aiu_sim_triton: Optional[str] = field( default=None, metadata={ "help": ( diff --git a/fms_mo/utils/dq_utils.py b/fms_mo/utils/dq_utils.py index d5622864..2eb51caf 100644 --- a/fms_mo/utils/dq_utils.py +++ b/fms_mo/utils/dq_utils.py @@ -13,14 +13,20 @@ # limitations under the License. """Utility functions for Direct Quantization" (DQ).""" +# Standard +import logging + +logger = logging.getLogger(__name__) + def config_quantize_smooth_layers(qcfg: dict): """Update qcfg with model-dependent config parameters: - qlayer_name_pattern: identifier of transformer layers containing linear layers - to quantize (if any, tracing is bypassed) + to quantize (if any, tracing is bypassed) - qskip_layer_name: full name of linear layers that will not be quantized - smoothq_scale_layers: identifier of linear layers to apply smoothquant on - - smoothq_act_scale_path: path to save/load smoothquant activation scales + - smoothq_act_scale_path: path to save/load smoothquant activation scales, should be kept + Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt"), no need to specify here. Selected model is determined by comparing all architecture identifiers against `model` and `model_type` fields in qcfg. @@ -98,23 +104,10 @@ def config_quantize_smooth_layers(qcfg: dict): [31, 7], ] ] - qcfg["smoothq_act_scale_path"] = "./act_scales/Mixtral-8x7B-v0.1.pt" elif any(model in qcfg["model"] for model in bigcode_architecture): qcfg["qlayer_name_pattern"] = ["transformer.h"] qcfg["smoothq_scale_layers"] = ["c_attn", "c_fc"] # NOTE: supported bigcode models do not need layer skip for large magnitude - if "granite-3b-base-v2" in qcfg["model"]: - qcfg["smoothq_act_scale_path"] = "./act_scales/granite_3b_base_v2_500_nw.pt" - if "granite-13b-base-v2" in qcfg["model"]: - qcfg["smoothq_act_scale_path"] = "./act_scales/granite_13b_base_v2.pt" - if "granite-20b-code-base" in qcfg["model"]: - qcfg["smoothq_act_scale_path"] = "./act_scales/graniteCodeHF_20b_base12.pt" - if "granite-20b-code-instruct" in qcfg["model"]: - qcfg["smoothq_act_scale_path"] = "./act_scales/graniteCodeHF_20b_base12.pt" - if "granite-34b-code-base" in qcfg["model"]: - qcfg["smoothq_act_scale_path"] = "./act_scales/graniteCodeHF_34b_base12.pt" - if "granite-34b-code-instruct" in qcfg["model"]: - qcfg["smoothq_act_scale_path"] = "./act_scales/graniteCodeHF_34b_base12.pt" elif "roberta" in qcfg["model"]: qcfg["act_scale_path"] = "./act_scales" qcfg["smoothq_scale_layers"] = [ @@ -126,4 +119,7 @@ def config_quantize_smooth_layers(qcfg: dict): qcfg["qskip_layer_name"] = [] qcfg["qlayer_name_pattern"] = ["roberta.encoder"] else: - raise ValueError("The model architecture is not supported for DQ.") + logger.info( + "The model architecture is not supported for DQ. No architecture-specific settings is" + "applied. All Linear layers will be quantized, which may not yield the optimal results." + ) From 52833235f91f7209e7c534e6703636c384a83f3d Mon Sep 17 00:00:00 2001 From: cliu-us Date: Wed, 9 Jul 2025 14:16:41 -0400 Subject: [PATCH 10/11] fix DL8/DL16 bugs and a couple other minor bugs fix Signed-off-by: cliu-us --- fms_mo/dq.py | 28 +----------- fms_mo/fx/dynamo_utils.py | 12 ++++-- fms_mo/fx/utils.py | 4 +- fms_mo/modules/linear.py | 90 ++++++++++++++++++++++++--------------- 4 files changed, 67 insertions(+), 67 deletions(-) diff --git a/fms_mo/dq.py b/fms_mo/dq.py index bee79fa9..eb49bc30 100644 --- a/fms_mo/dq.py +++ b/fms_mo/dq.py @@ -268,35 +268,9 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): max_acc_bits=qcfg.get("max_acc_bits", 32), num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0), chunk_size=qcfg.get("chunk_size", 32), # 1024 - clamp_acc_to_dl16=False, # fms_mo_args.aiu_sim_triton == "fp8" + clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8", # layer_to_exclude=["lm_head",] ) - # [CL] -------- record W, A, qW, qA with hooks ---------------- - # from fms_mo.modules.linear import QLinear, QLinearINT8Deploy - # from fms_mo.quant.ptq import HookRecPostQuantInOut - # cache_dict = {} - # hook_handles = [] - # for n, m in model.named_modules(): - # if not isinstance(m, (QLinear, QLinearINT8Deploy, torch.nn.Linear)): - # continue - - # m.mod_name = n - # hook_handles.append( - # m.register_forward_hook( HookRecPostQuantInOut(cache_dict, n)) - # ) - - # data_mb = next(iter(eval_dataloader)) - # with torch.no_grad(): - # model(**data_mb) - - # for h in hook_handles: - # h.remove() - - # torch.save( - # cache_dict, - # f"roberta_sqv2_data_dump_{qcfg['qa_mode']}_{qcfg['qw_mode']}_chunk64_lsb{args.aiu_int_lsb_trun}_dq.pt" - # ) - # return if fms_mo_args.eval_ppl: path_test = Path(data_args.test_data_path) diff --git a/fms_mo/fx/dynamo_utils.py b/fms_mo/fx/dynamo_utils.py index 62967ec2..f7cd61ea 100644 --- a/fms_mo/fx/dynamo_utils.py +++ b/fms_mo/fx/dynamo_utils.py @@ -1180,14 +1180,20 @@ def cus_backend_model_analyzer( if is_transformers: # NOTE simplified method to determine 1st/last modules for transformers. # will not work if model has multiple parallel heads at the end, e.g. obj det - def call_seq_hook(mod, *_args, **_kwargs): - qcfg["mod_call_seq"].append(lut_weight2modname[mod.weight]) + def call_seq_hook(mod, *_args, **kwargs): + mod_name = kwargs.get("mod_name", lut_weight2modname.get(mod.weight, None)) + if mod_name is None: + raise RuntimeError("cannot determine module name, plz check model.") + + qcfg["mod_call_seq"].append(mod_name) h_hooks = [] qcfg["mod_call_seq"] = [] for n, m in model.named_modules(): if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)): - h_hooks.append(m.register_forward_hook(call_seq_hook)) + h_hooks.append( + m.register_forward_hook(partial(call_seq_hook, mod_name=n)) + ) with torch.no_grad(): run_fwd_once(model, sample_inp) diff --git a/fms_mo/fx/utils.py b/fms_mo/fx/utils.py index 357a877b..45abe184 100644 --- a/fms_mo/fx/utils.py +++ b/fms_mo/fx/utils.py @@ -461,14 +461,14 @@ def model_size_Wb(mod, unit="MB", print_to_file=True, show_details=False): w_mat.numel() * w_mat.element_size() + b_mat.numel() * b_mat.element_size() ) - w_dtype = w_mat.dtype + w_dtype = str(w_mat.dtype) w_shape = w_mat.shape elif isinstance(w, torch.Tensor): mem_use = w.numel() * w.element_size() if hasattr(m, "bias") and m.bias is not None: mem_use += m.bias.numel() * m.bias.element_size() - w_dtype = w.dtype + w_dtype = str(w.dtype) w_shape = w.shape if w_shape: diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py index 65f9378d..3a39bb30 100644 --- a/fms_mo/modules/linear.py +++ b/fms_mo/modules/linear.py @@ -1926,14 +1926,16 @@ def forward( ctx.chunk_size = chunk_size ctx.fp8_dyn = fp8_dyn ctx.clamp_acc_to_dl16 = clamp_acc_to_dl16 + ctx.fp8_e4m3_max = torch.finfo(torch.float8_e4m3fn).max + ctx.fp8_e5m2_max = torch.finfo(torch.float8_e5m2).max ctx.dl8_min = 0.0087890625 + x_scale = torch.tensor(1.0, device=x.device, dtype=org_dtype) + w_scale = x_scale.clone() if fp8_dyn: # use Q/dQ simulation for now, meaning still compute in fp16/bf16 # if choose per_token for input, use per_channel for W # (W saved as [out, in], reduce inCh-dim, => reduce_dim=1) - ctx.fp8_e4m3_max = torch.finfo(torch.float8_e4m3fn).max - ctx.fp8_e5m2_max = torch.finfo(torch.float8_e5m2).max reduce_dim = None if fp8_dyn == "per_tensor" else 1 x_scale = ( x.abs().amax(dim=reduce_dim, keepdim=True) / ctx.fp8_e4m3_max @@ -1942,22 +1944,30 @@ def forward( weight.abs().amax(dim=reduce_dim, keepdim=True) / ctx.fp8_e4m3_max ).clamp(min=1e-5) - x = (x / x_scale).to(torch.float8_e4m3fn).to(org_dtype) * x_scale - weight = (weight / w_scale).to(torch.float8_e4m3fn).to(org_dtype) * w_scale + x = (x / x_scale).to(torch.float8_e4m3fn).to(torch.float32) + weight = (weight / w_scale).to(torch.float8_e4m3fn).to(torch.float32) if clamp_acc_to_dl16: - # NOTE For DL8@DL8 acc in DL16, as DL8 doesn't support subnorm numbers like PyTorch - # (whose real min for e4m3fn is 2^-9), need to flush subnorm numbers to 0 - x.masked_fill_(x < ctx.dl8_min, 0) - weight.masked_fill_(weight < ctx.dl8_min, 0) + # at this point, x and W are clamped to PT's FP8 range (2^-9 to 448). But since DL8 + # doesn't support subnorm like PyTorch, need to flush subnorms to 0 BEFORE descaling + x.masked_fill_(x.abs() < ctx.dl8_min, 0) + weight.masked_fill_(weight.abs() < ctx.dl8_min, 0) # triton kernel assumes 2D inputs and cast the return to input.dtype - output = tl_matmul( - x, - weight.t().to(org_dtype), - chunk_trun_bits=trun_bits, - chunk_size=chunk_size, - clamp_acc_to_dl16=clamp_acc_to_dl16, - ).reshape(target_shape_output) + output = ( + ( + tl_matmul( + x, + weight.t(), + chunk_trun_bits=trun_bits, + chunk_size=chunk_size, + clamp_acc_to_dl16=clamp_acc_to_dl16, + ) + * x_scale + * w_scale.t() + ) + .to(org_dtype) + .reshape(target_shape_output) + ) if bias is not None: output = output + bias.to(org_dtype) @@ -1977,6 +1987,8 @@ def backward(ctx, grad_output): target_shape_grad_input = grad_output.shape[:-1] + (in_dim,) grad_output_2D = grad_output.reshape(-1, out_dim).to(dtype_input) + x_scale = torch.tensor(1.0, device=x.device, dtype=dtype_input) + w_scale = x_scale.clone() if ctx.fp8_dyn: reduce_dim = None if ctx.fp8_dyn == "per_tensor" else 1 x_scale = x.abs().amax(dim=reduce_dim) / ctx.fp8_e5m2_max @@ -1984,37 +1996,45 @@ def backward(ctx, grad_output): # always assume perT in this case grad_out_scale = grad_output_2D.abs().amax(dim=None) / ctx.fp8_e5m2_max - x = (x / x_scale).to(torch.float8_e5m2).to(dtype_input) * x_scale - weight = (weight / w_scale).to(torch.float8_e5m2).to(weight.dtype) * w_scale - grad_output_2D = (grad_output_2D / grad_out_scale).to(torch.float8_e5m2).to( - grad_output.dtype - ) * grad_out_scale + x = (x / x_scale).to(torch.float8_e5m2).to(torch.float) + weight = (weight / w_scale).to(torch.float8_e5m2).to(torch.float) + grad_output_2D = ( + (grad_output_2D / grad_out_scale).to(torch.float8_e5m2).to(torch.float) + ) if ctx.clamp_acc_to_dl16: # flush subnorm numbers to 0 as DL8 doesn't support it - x.masked_fill_(x < ctx.dl8_min, 0) - weight.masked_fill_(weight < ctx.dl8_min, 0) - grad_output_2D.masked_fill_(grad_output_2D < ctx.dl8_min, 0) + x.masked_fill_(x.abs() < ctx.dl8_min, 0) + weight.masked_fill_(weight.abs() < ctx.dl8_min, 0) + grad_output_2D.masked_fill_(grad_output_2D.abs() < ctx.dl8_min, 0) # Compute grad_weight, shape = [out, in] # NOTE: this triton kernel requires A matrix to be contiguous - grad_weight = tl_matmul( - grad_output_2D.transpose(0, 1).contiguous(), - x, - chunk_trun_bits=trun_bits, - chunk_size=chunk_size, - clamp_acc_to_dl16=ctx.clamp_acc_to_dl16, - ).to(weight.dtype) - # Compute grad_input in 2D then reshape to target shape, could be 3D or 2D - grad_input = ( + grad_weight = ( tl_matmul( - grad_output_2D, - weight.to(dtype_input), + grad_output_2D.transpose(0, 1).contiguous(), + x, chunk_trun_bits=trun_bits, chunk_size=chunk_size, clamp_acc_to_dl16=ctx.clamp_acc_to_dl16, ) - .reshape(target_shape_grad_input) + * grad_out_scale.t() + * x_scale + ).to(weight.dtype) + # Compute grad_input in 2D then reshape to target shape, could be 3D or 2D + grad_input = ( + ( + tl_matmul( + grad_output_2D, + weight, + chunk_trun_bits=trun_bits, + chunk_size=chunk_size, + clamp_acc_to_dl16=ctx.clamp_acc_to_dl16, + ) + * grad_out_scale + * w_scale + ) .to(dtype_input) + .reshape(target_shape_grad_input) ) if not ctx.has_bias: From 1bbf139571712a2892e57fa03699cf0c40fdbad9 Mon Sep 17 00:00:00 2001 From: cliu-us Date: Thu, 10 Jul 2025 15:49:18 +0000 Subject: [PATCH 11/11] cleaned up debug codes Signed-off-by: cliu-us --- fms_mo/custom_ext_kernels/triton_kernels.py | 23 --------------------- 1 file changed, 23 deletions(-) diff --git a/fms_mo/custom_ext_kernels/triton_kernels.py b/fms_mo/custom_ext_kernels/triton_kernels.py index 41d46ca2..bcba22ca 100644 --- a/fms_mo/custom_ext_kernels/triton_kernels.py +++ b/fms_mo/custom_ext_kernels/triton_kernels.py @@ -465,29 +465,6 @@ def round_and_trun(x, round_bit, trun_mask, clamp_acc_to_dl16): return x -# @triton.jit -# def fp32_clamp_to_dl16(x): -# """clamp FP32 (1-8-23) TENSOR x to DL16 (1-6-9) range.""" -# # 1. rounding: add round bit, zero out last 13 bits, back to float -# x = libdevice.float_as_uint(x) -# round_bit = 1 << (23 - 9 - 1) -# mask_13x0 = ~tl.cast((1 << 13) - 1, tl.uint32) -# x = libdevice.uint_as_float((x + round_bit) & mask_13x0) - -# # 2. clamp to min/max: -# # max = 2^32 * 1.(1111 1111 0)_base2 => 2^32*1.(1111 1111 1) will become inf -# # (32 + 127) << 23 | (0xFF8 << (23 - 12)) in FP32 is 8581545984.0 -# # min = 2^-31 * 1.(0000 0000 1)_base2 => set to 0 for those smaller than this -# # (-31 + 127) << 23 | (1 << (23 - 9)) in FP32 is 4.665707820095122e-10 -# dl16_max = 8581545984.0 -# dl16_min = 4.665707820095122e-10 -# x = tl.where(x >= dl16_max, float("inf"), x) -# x = tl.where(x <= -dl16_max, float("-inf"), x) -# x = tl.where(tl.abs(x) < dl16_min, 0, x) - -# return x - - def tl_matmul_chunk_truncate( a, b,