74 changes: 52 additions & 22 deletions fms_mo/custom_ext_kernels/triton_kernels.py
@@ -114,6 +114,7 @@ def matmul_kernel(
stride_cn,
chunk_trun_bits,
max_acc_bits, # pylint: disable=unused-argument
clamp_acc_to_dl16,
truncate_then_accumulate,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
@@ -159,13 +160,8 @@ def matmul_kernel(
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
## ------ prepare LSB rounding/truncation masks -------
# NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b
# e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
# 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
trun_mask = tl.cast((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits, tl.uint32)
round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
## ---------------------------------------------------------
## ------ prepare LSB rounding/truncation masks outside the for loop -------
round_bit, trun_mask = round_and_trun_mask(chunk_trun_bits, clamp_acc_to_dl16)

for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
@@ -180,8 +176,10 @@
# tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp

## ------ add chunky LSB rounding/masking --------
if chunk_trun_bits > 0:
accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask)
if clamp_acc_to_dl16 or chunk_trun_bits > 0:
accumulator_inner = round_and_trun(
accumulator_inner, round_bit, trun_mask, clamp_acc_to_dl16
)
## ---------------------------------------------------------
if truncate_then_accumulate:
accumulator += accumulator_inner
@@ -226,6 +224,7 @@ def imatmul_kernel(
stride_cn,
chunk_trun_bits,
max_acc_bits,
clamp_acc_to_dl16, # pylint: disable=unused-argument
truncate_then_accumulate,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
@@ -324,6 +323,7 @@ def matmul_kernel_DABC(
stride_cn,
chunk_trun_bits,
max_acc_bits, # pylint: disable=unused-argument
clamp_acc_to_dl16,
truncate_then_accumulate,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
@@ -377,13 +377,8 @@ def matmul_kernel_DABC(
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy, i.e. C should have been cast to fp32 already
accumulator = tl.load(c_ptrs, mask=c_mask, other=0.0)
## ------ prepare LSB rounding/truncation masks -------
# NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b
# e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
# 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
trun_mask = tl.cast((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits, tl.uint32)
round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
## ---------------------------------------------------------
## ------ prepare LSB rounding/truncation masks outside the for loop -------
round_bit, trun_mask = round_and_trun_mask(chunk_trun_bits, clamp_acc_to_dl16)

for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A, B, and C, generate a mask by checking the K dimension.
@@ -403,8 +398,10 @@
# precision as well, hence, could lose some precision!

## ------ add chunky LSB rounding/masking --------
if chunk_trun_bits > 0:
accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask)
if clamp_acc_to_dl16 or chunk_trun_bits > 0:
accumulator_inner = round_and_trun(
accumulator_inner, round_bit, trun_mask, clamp_acc_to_dl16
)
## ---------------------------------------------------------
if truncate_then_accumulate:
accumulator += accumulator_inner
@@ -433,9 +430,39 @@ def leaky_relu(x):


@triton.jit
def round_and_trun(x, round_bit, trun_mask):
def round_and_trun_mask(chunk_trun_bits, clamp_acc_to_dl16):
"""
Rounding and LSB truncation masks only need to be generated once.
These masks will be applied to the "inner" accumulator, which is always FP32 (e8m23), so we may
truncate up to 23 mantissa bits. If DL16/DL8, pay attention to the exponent bias.
Examples: 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
"""
if clamp_acc_to_dl16:
# DL16 is e6m9, hence, truncate 23 - 9 = 14 bits
chunk_trun_bits = 14
round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
trun_mask = ~tl.cast((1 << chunk_trun_bits) - 1, tl.uint32)
return round_bit, trun_mask


@triton.jit
def round_and_trun(x, round_bit, trun_mask, clamp_acc_to_dl16):
"""Round and truncate (usually for accumulator)."""
return libdevice.uint_as_float((libdevice.float_as_uint(x) + round_bit) & trun_mask)
x = libdevice.uint_as_float((libdevice.float_as_uint(x) + round_bit) & trun_mask)

if clamp_acc_to_dl16:
# clamp to DL16 min/max:
# max = 2^32 * 1.(1111 1111 0)_base2 = 2^32*(2 - 2^-9) = 8581545984.0
# greater than this will become +inf (or -inf)
# min = 2^-31 * 1.(0000 0000 1)_base2 = 2^-31*(1 + 2^-9) = 4.665707820095122e-10
# smaller than this will become 0
dl16_max = 8581545984.0
dl16_min = 4.665707820095122e-10
x = tl.where(x >= dl16_max, float("inf"), x)
x = tl.where(x <= -dl16_max, float("-inf"), x)
x = tl.where(tl.abs(x) < dl16_min, 0, x)
return x


def tl_matmul_chunk_truncate(
@@ -448,6 +475,7 @@
max_acc_bits=32,
truncate_then_accumulate=True,
cast_output_to_input_dtype=None,
clamp_acc_to_dl16=False,
):
"""Triton matmul for HW behavior simulation. Supports float and int8.
i. variable chunk size (i.e., BLOCK_SIZE_K)
@@ -461,7 +489,8 @@ def tl_matmul_chunk_truncate(
chunk_size (int, optional): BLOCK_SIZE_K, some HW has a specific chunk size. Must be >= 16.
max_acc_bits (int, optional): num of bits for the accumulator, e.g. if INT24 is used, will
clamp each chunk of a*b to [-2**23, 2**23 - 1].
(assuming no inf when overflow)
(only used by INT)
clamp_acc_to_dl16 (bool, optional): only used by FP8; whether to clamp the local accumulator (FP32) to DL16.
truncate_then_accumulate (bool, optional): if True, c = truncate(a*b) + c, otherwise
c = truncate(a*b+c)
cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually
@@ -473,7 +502,7 @@

NOTE:
use empirical way to determine BLOCK sizes, may not be optimal. But need to avoid autotune for
real model inference. otherwise auto-tune will be triggered in every forward call.
real model inference. otherwise auto-tune may be triggered in every forward call.
"""

# Check constraints.
@@ -584,6 +613,7 @@ def grid(META):
c.stride(1),
chunk_trun_bits=chunk_trun_bits,
max_acc_bits=max_acc_bits,
clamp_acc_to_dl16=clamp_acc_to_dl16,
truncate_then_accumulate=truncate_then_accumulate,
ACTIVATION=activation,
**kernel_config, # if using auto-tune, comment this line out.
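For reference, the rounding/truncation path added above can be mirrored on the host with plain Python. The sketch below is illustrative only (the helpers f32_bits, bits_f32, ref_round_and_trun_mask, and ref_round_and_trun are hypothetical names, not part of this PR); it reproduces the mask construction and the DL16 clamp so the constants in the docstrings can be sanity-checked offline.

import struct


def f32_bits(x):
    # reinterpret a float as its 32-bit IEEE-754 pattern (host-side float_as_uint)
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_f32(b):
    # reinterpret a 32-bit pattern as a float (host-side uint_as_float)
    return struct.unpack("<f", struct.pack("<I", b & 0xFFFFFFFF))[0]


def ref_round_and_trun_mask(chunk_trun_bits, clamp_acc_to_dl16=False):
    # host-side mirror of round_and_trun_mask
    if clamp_acc_to_dl16:
        chunk_trun_bits = 14  # DL16 is e6m9 -> drop 23 - 9 = 14 mantissa bits
    round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
    trun_mask = ~((1 << chunk_trun_bits) - 1) & 0xFFFFFFFF
    return round_bit, trun_mask


def ref_round_and_trun(x, round_bit, trun_mask, clamp_acc_to_dl16=False):
    # host-side mirror of round_and_trun: round, drop mantissa LSBs, optionally clamp to DL16
    y = bits_f32((f32_bits(x) + round_bit) & trun_mask)
    if clamp_acc_to_dl16:
        dl16_max = 8581545984.0           # 2^32 * (2 - 2^-9)
        dl16_min = 4.665707820095122e-10  # 2^-31 * (1 + 2^-9)
        if y >= dl16_max:
            y = float("inf")
        elif y <= -dl16_max:
            y = float("-inf")
        elif abs(y) < dl16_min:
            y = 0.0
    return y


# the docstring examples check out
assert ref_round_and_trun_mask(20) == (0x00080000, 0xFFF00000)
assert ref_round_and_trun_mask(8) == (0x00000080, 0xFFFFFF00)
assert ref_round_and_trun_mask(0, clamp_acc_to_dl16=True) == (0x00002000, 0xFFFFC000)

# 1 + 2^-10 sits halfway between two e6m9 neighbors and rounds up to 1 + 2^-9
rb, tm = ref_round_and_trun_mask(0, clamp_acc_to_dl16=True)
assert ref_round_and_trun(1.0 + 2**-10, rb, tm, clamp_acc_to_dl16=True) == 1.0 + 2**-9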
57 changes: 46 additions & 11 deletions fms_mo/custom_ext_kernels/utils.py
@@ -870,14 +870,16 @@ def lower_qmodel_triton(
model: torch.nn.Module,
use_dyn_max_act=False,
max_acc_bits=32,
clamp_acc_to_dl16=False,
num_lsb_to_truncate=0,
chunk_size=32,
layer_to_exclude=None,
):
"""
Examplar GPU lowering function using triton. Only swap Qlinears in transformers, nothing else.
Exemplar GPU lowering function using triton. Only swaps Linear/QLinear layers in transformers.
Triton kernel can be used to:
1. test INT8 or FP8 HW performance (kernel is not optimized)
2. simulate MSB/LSB truncation effect
2. simulate MSB/LSB truncation effect or special dtype (DL16) accumulation

Args:
model: nn.Module. should be a fms_mo Qmodel, will do inplace layer swapping, no deepcopy
@@ -888,6 +890,8 @@
efficiency at the expense of higher chance of accumulation "overflow".
For example, an INT24 accumulator can only hold values ranging from -2^23 to
2^23 - 1, as opposed to the typical INT32 range of -2^31 to 2^31 - 1.
clamp_acc_to_dl16: clamp the local accumulator to the DL16 (1-6-9) range, to simulate this special
dtype's effect on accumulation.
num_lsb_to_truncate: number of bits to truncate from LSB side. For example, given fp32 is
s1e8m23, if we choose to truncate 13 mantissa bits from right most side,
i.e. LSB, the resulting number will be s1e8m10, which is TF32.
@@ -900,25 +904,56 @@ def lower_qmodel_triton(
from torch.ao.quantization.utils import _parent_name

# Local
from fms_mo.modules.linear import QLinear, QLinearINT8Deploy
from fms_mo.modules.linear import LinearFPxAcc, QLinear, QLinearINT8Deploy

# Currently QLinearINT8 has more options for dynamic quantization than LinearFP. Here we resolve
# the differences as a patch solution (will unify the code in a future release)
linFP_dyn_code = (
"per_token"
if use_dyn_max_act in [-1, -2]
else "per_tensor"
if use_dyn_max_act
else False
)

if layer_to_exclude is None:
layer_to_exclude = []
elif isinstance(layer_to_exclude, str):
layer_to_exclude = [
layer_to_exclude,
]
elif not isinstance(layer_to_exclude, (list, tuple)):
raise RuntimeError("layer_to_exclude has to be either str, list, or tuple.")

for name, m in model.named_modules():
if not isinstance(m, QLinear):
if not isinstance(m, (QLinear, torch.nn.Linear)) or name in layer_to_exclude:
continue
parent_name, module_name = _parent_name(name)
parent_mod = model.get_submodule(parent_name)
qmod = getattr(parent_mod, module_name)
setattr(
parent_mod,
module_name,
QLinearINT8Deploy.from_fms_mo(
qmod,

# Only supports simulations of 1) QLinear -> INT8, 2) nn.Linear -> FP8 for now
if isinstance(m, QLinear):
new_lin = QLinearINT8Deploy.from_fms_mo(
m,
use_int_kernel="triton",
use_dynamic_max_act_Qfunc=use_dyn_max_act,
max_acc_bits=max_acc_bits,
truncate_lsb=num_lsb_to_truncate,
chunk_size=chunk_size,
),
)
else:
new_lin = LinearFPxAcc.from_nn(
m,
trun_bits=num_lsb_to_truncate,
chunk_size=chunk_size,
dynamic_fp8=linFP_dyn_code,
clamp_acc_to_dl16=clamp_acc_to_dl16,
)

setattr(
parent_mod,
module_name,
new_lin,
)

logger.info(f"\nModel lowering with triton kernel is done.\n{model}")
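As a usage sketch (assumptions: the function is importable from fms_mo.custom_ext_kernels.utils, matching the file path above, and `model` is an fms_mo Qmodel for the INT8 path or a plain transformers model whose nn.Linear layers should get the FP8/DL16 simulation kernel), the updated entry point can be driven like this:

from fms_mo.custom_ext_kernels.utils import lower_qmodel_triton

lower_qmodel_triton(
    model,
    use_dyn_max_act=-1,            # -1/-2 -> per-token dynamic max; False -> no dynamic max
    max_acc_bits=32,               # INT path only, e.g. 24 simulates an INT24 accumulator
    clamp_acc_to_dl16=True,        # FP8 path only: clamp the local FP32 accumulator to DL16
    num_lsb_to_truncate=0,         # mantissa LSBs to drop from each chunk's accumulator
    chunk_size=32,                 # BLOCK_SIZE_K of the triton kernel
    layer_to_exclude=["lm_head"],  # layers (by name) to leave untouched
)

The swap happens in place; excluded names and non-Linear modules are skipped.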
32 changes: 18 additions & 14 deletions fms_mo/dq.py
@@ -161,18 +161,18 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
# config layers to skip, smooth scale
config_quantize_smooth_layers(qcfg)

use_dynamo = True
# use dynamo by default unless really needed; False -> fall back to TorchScript tracing
if any(x != 32 for x in attn_bits):
logger.info("Quantize attention bmms or kvcache, will use dynamo for prep")
use_layer_name_pattern_matching = False
qcfg["qlayer_name_pattern"] = []
assert (
qcfg["qlayer_name_pattern"] == []
), "ensure nothing in qlayer_name_pattern when use dynamo"
use_dynamo = True
else:
logger.info("Attention bmms will not be quantized.")
use_layer_name_pattern_matching = True
use_dynamo = False

qcfg["seq_len"] = block_size
qcfg["model"] = model_args.model_name_or_path
@@ -216,17 +216,18 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
act_scales = get_act_scales(model, dq_dataloader, qcfg)
torch.save(act_scales, scale_file)

qmodel_prep(
model,
dq_dataloader,
qcfg,
use_layer_name_pattern_matching=use_layer_name_pattern_matching,
use_dynamo=use_dynamo,
dev=dev,
save_fname="dq",
)
logger.info(f"Quantized model {model}")
logger.info("==" * 20)
if fms_mo_args.aiu_sim_triton != "fp8":
qmodel_prep(
model,
dq_dataloader,
qcfg,
use_layer_name_pattern_matching=use_layer_name_pattern_matching,
use_dynamo=use_dynamo,
dev=dev,
save_fname="dq",
)
logger.info(f"Quantized model {model}")
logger.info("==" * 20)

if qcfg["smoothq"]:
logger.info("Starting to apply smooth scale")
@@ -260,12 +261,15 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
tokenizer.save_pretrained(opt_args.output_dir)

if fms_mo_args.aiu_sim_triton:
# NOTE: please apply correct HW settings here; the defaults are not real HW params
lower_qmodel_triton(
model,
use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False,
max_acc_bits=qcfg.get("max_acc_bits", 32),
num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0),
chunk_size=qcfg.get("chunk_size", 1024),
chunk_size=qcfg.get("chunk_size", 32), # 1024
clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8",
# layer_to_exclude=["lm_head",]
)

if fms_mo_args.eval_ppl:
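The qcfg keys read by the lowering call above are worth spelling out. The values below are placeholders (not real HW parameters, as the NOTE in the diff stresses) and the dict update is illustrative only:

qcfg.update(
    {
        "qa_mode": "pertokenmax",  # -> use_dyn_max_act=-1 (per-token dynamic activation max)
        "max_acc_bits": 32,        # INT8 path: simulated accumulator width
        "lsb_trun_bits": 0,        # FP path: mantissa LSBs truncated per chunk
        "chunk_size": 32,          # BLOCK_SIZE_K (fallback default here changed from 1024 to 32)
    }
)
# fms_mo_args.aiu_sim_triton == "fp8" additionally skips qmodel_prep and enables clamp_acc_to_dl16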
12 changes: 9 additions & 3 deletions fms_mo/fx/dynamo_utils.py
@@ -1180,14 +1180,20 @@ def cus_backend_model_analyzer(
if is_transformers:
# NOTE simplified method to determine 1st/last modules for transformers.
# will not work if model has multiple parallel heads at the end, e.g. obj det
def call_seq_hook(mod, *_args, **_kwargs):
qcfg["mod_call_seq"].append(lut_weight2modname[mod.weight])
def call_seq_hook(mod, *_args, **kwargs):
mod_name = kwargs.get("mod_name", lut_weight2modname.get(mod.weight, None))
if mod_name is None:
raise RuntimeError("cannot determine module name, plz check model.")

qcfg["mod_call_seq"].append(mod_name)

h_hooks = []
qcfg["mod_call_seq"] = []
for n, m in model.named_modules():
if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)):
h_hooks.append(m.register_forward_hook(call_seq_hook))
h_hooks.append(
m.register_forward_hook(partial(call_seq_hook, mod_name=n))
)

with torch.no_grad():
run_fwd_once(model, sample_inp)
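The hook rewrite above uses the common functools.partial trick to bind each module's registered name at hook-registration time instead of looking it up from the weight tensor. A self-contained toy version of the same pattern (module and variable names here are illustrative, not from the PR):

from functools import partial

import torch

call_seq = []


def record_call(mod, _inputs, _output, mod_name=None):
    # mod_name is bound at registration time, so no weight-to-name lookup is needed
    call_seq.append(mod_name)


toy = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
hooks = [
    m.register_forward_hook(partial(record_call, mod_name=n))
    for n, m in toy.named_modules()
    if isinstance(m, torch.nn.Linear)
]

with torch.no_grad():
    toy(torch.randn(1, 4))

for h in hooks:
    h.remove()

print(call_seq)  # -> ['0', '2']: the Linear layers in call order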
4 changes: 2 additions & 2 deletions fms_mo/fx/utils.py
@@ -461,14 +461,14 @@ def model_size_Wb(mod, unit="MB", print_to_file=True, show_details=False):
w_mat.numel() * w_mat.element_size()
+ b_mat.numel() * b_mat.element_size()
)
w_dtype = w_mat.dtype
w_dtype = str(w_mat.dtype)
w_shape = w_mat.shape

elif isinstance(w, torch.Tensor):
mem_use = w.numel() * w.element_size()
if hasattr(m, "bias") and m.bias is not None:
mem_use += m.bias.numel() * m.bias.element_size()
w_dtype = w.dtype
w_dtype = str(w.dtype)
w_shape = w.shape

if w_shape: