diff --git a/fms_mo/modules/bmm.py b/fms_mo/modules/bmm.py
index ff4f026f..083d9651 100644
--- a/fms_mo/modules/bmm.py
+++ b/fms_mo/modules/bmm.py
@@ -82,7 +82,13 @@ def __init__(
         self.m2_bounded = m2_bounded
         self.qm1_mode = qm1_mode
         self.qm2_mode = qm2_mode
-
+        self.smooth_attn = qcfg.get("smooth_attn", False)
+        self.smooth_attn_alpha = qcfg.get("smooth_attn_alpha", 0.5)
+        if self.smooth_attn_alpha < 0 or self.smooth_attn_alpha > 1:
+            raise ValueError(
+                "smooth_attn_alpha must be in range [0,1] "
+                f"(given: {self.smooth_attn_alpha})"
+            )
         self.m1_clip_init_val = kwargs.get(
             "m1_clip_init_val", qcfg.get("m1_clip_init_val", 1.0)
         )
@@ -191,6 +197,12 @@ def forward(self, m1, m2):
         Returns:
             torch.Tensor: Output tensor after quantized bmm.
         """
+        if self.smooth_attn:
+            attn_scales = m2.abs().amax(dim=(0, 1, 3)).clamp(min=1e-5)
+            attn_scales = attn_scales.pow(self.smooth_attn_alpha)
+            m1 = m1 * attn_scales
+            m2 = m2 / attn_scales.reshape(1, 1, m2.shape[2], 1)
+
         # pylint: disable = access-member-before-definition
         if self.calib_counter:
             with torch.no_grad():
diff --git a/fms_mo/quant/quantizers.py b/fms_mo/quant/quantizers.py
index 0d77f8b8..4e3220b1 100644
--- a/fms_mo/quant/quantizers.py
+++ b/fms_mo/quant/quantizers.py
@@ -123,12 +123,24 @@ def get_activation_quantizer(
             )
         elif qa_mode == "dorefa":
             act_quantizer = dorefa_quantize_activation
-        elif (
-            qa_mode == "max"
-        ):  # NOTE Need to be careful using this for activation, particular to 1 sided.
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
-        elif qa_mode == "minmax":
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+
+        elif "max" in qa_mode:
+            # NOTE Need to be careful using this for activation, particularly 1-sided.
+            if "min" in qa_mode:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+            elif "pertoken" in qa_mode or "perToken" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-1)
+            elif "per_channel" in qa_mode or "perCh" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-2)
+            elif "sym" in qa_mode:
+                act_quantizer = Qmax(
+                    nbits,
+                    align_zero=True,
+                    minmax=False,
+                    extend_act_range=extend_act_range,
+                )
+            else:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
         elif qa_mode == "fix":
             act_quantizer = QFixSymmetric(
                 nbits, init_clip_val=clip_val, align_zero=align_zero
@@ -140,13 +152,7 @@ def get_activation_quantizer(
                 minmax=False,
                 extend_act_range=extend_act_range,
             )
-        elif qa_mode == "pactsym":
-            act_quantizer = PACT2Sym(
-                nbits,
-                init_clip_val=clip_val,
-                dequantize=True,
-                inplace=False,
-            )
+
         elif qa_mode == "pactsym+":
             act_quantizer = PACTplusSym(
                 nbits,
@@ -179,8 +185,6 @@ def get_activation_quantizer(
                 perToken=perToken,
                 emulate=True,
             )
-        elif qa_mode == "pertokenmax":
-            act_quantizer = PerTokenMax(nbits)
         else:
             raise ValueError(f"unrecognized activation quantization mode {qa_mode}")
     else:  # swcap-compatible activation quantizers
@@ -3488,6 +3492,42 @@ def __repr__(self):
         return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"


+class QMaxDynamic(nn.Module):
+    def __init__(self, num_bits, dim=-1):
+        """
+        For per-token or per-channel quantization using abs().max() as the scale, usually for
+        activations; could be used for Qbmm M2 as well.
+        (reduce) dim = -1 -> abs() will output a column vector (if the input is 2D) => per-token
+                 dim = -2 -> per-channel
+        Zero is aligned so that the levels are symmetric around zero (losing one level).
+        Since the token length is unknown before running, the quantizer can only calculate the
+        scales dynamically at run time, meaning no trainable quantization scale is allowed
+        (unless the input seq length is always the same, not just padded to a fixed length).
+        """
+        super().__init__()
+        self.num_bits = num_bits
+        self.levels = 2 ** (self.num_bits - 1) - 1
+        if isinstance(dim, str):
+            if "perCh" in dim or "per_channel" in dim:
+                dim = -2
+            elif "perToken" in dim or "per_token" in dim or "per_Token" in dim:
+                dim = -1
+        if dim in [-1, -2]:
+            self.reduce_dim = dim
+        else:
+            raise ValueError(
+                f"Reduce dim can only be [-1, -2] or ['perCh', 'perToken'] but found {dim}"
+            )
+
+    def forward(self, input_tensor):
+        amax_dim = input_tensor.abs().max(dim=self.reduce_dim, keepdim=True)[0]
+        scales = amax_dim.clamp(min=1e-5).div(self.levels)
+        return input_tensor.div(scales).round().mul(scales)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"
+
+
 class Qdynamic(nn.Module):
     def __init__(
         self,
diff --git a/fms_mo/training_args.py b/fms_mo/training_args.py
index e7beafc6..6d713cfa 100644
--- a/fms_mo/training_args.py
+++ b/fms_mo/training_args.py
@@ -164,6 +164,8 @@ class FMSMOArguments(TypeChecker):
     bmm2_qm1_mode: str = field(default="pact", metadata={"help": ("bmm2.m1 quanitzer")})
     bmm2_qm2_mode: str = field(default="pact", metadata={"help": ("bmm2.m1 quanitzer")})
     smoothq_alpha: float = field(default=0.65, metadata={"help": "smooth quant alpha"})
+    smooth_attn_alpha: float = field(default=0.5, metadata={"help": "smooth attention alpha"})
+    smooth_attn: bool = field(default=False, metadata={"help": "enable smooth attention"})
     qmodel_calibration: int = field(
         default=0,
         metadata={"help": "Num of batches for Qmodel calibration, using model copy."},
diff --git a/fms_mo/utils/qconfig_utils.py b/fms_mo/utils/qconfig_utils.py
index e2c13355..f21cdc9b 100644
--- a/fms_mo/utils/qconfig_utils.py
+++ b/fms_mo/utils/qconfig_utils.py
@@ -149,6 +149,7 @@ def config_defaults() -> dict:
         "smoothq": False,
         "smoothq_scale_layers": [],
         "smoothq_act_scale_path": None,
+        "smooth_attn": False,
         # Other vars
         "which2patch_contextmanager": None,
         "force_stop_if_qbmm_auto_check_failed": False,
@@ -940,11 +941,16 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
         "pactsym+",
         "max",
         "minmax",
+        "maxbmm",
         "maxsym",
         "pertokenmax",
         "lsq+",
         "fix",
         "brecq",
+    ]
+    shared_modes = [
+        "max_perToken",
+        "max_perCh",
         # fp8_e4m3
         "fp8_e4m3_sat",
         "fp8_e4m3_scale",
@@ -981,33 +987,34 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
         "brecq",
         "adaround",
         "pertokenmax",
-        # fp8_e4m3
-        "fp8_e4m3_sat",
-        "fp8_e4m3_scale",
-        "fp8_e4m3_sat_perCh",
-        "fp8_e4m3_scale_perCh",
-        "fp8_e4m3_sat_perToken",
-        "fp8_e4m3_scale_perToken",
-        # fp8_e5m2
-        "fp8_e5m2_sat",
-        "fp8_e5m2_scale",
-        "fp8_e5m2_sat_perCh",
-        "fp8_e5m2_scale_perCh",
-        "fp8_e5m2_sat_perToken",
-        "fp8_e5m2_scale_perToken",
+        # # fp8_e4m3
+        # "fp8_e4m3_sat",
+        # "fp8_e4m3_scale",
+        # "fp8_e4m3_sat_perCh",
+        # "fp8_e4m3_scale_perCh",
+        # "fp8_e4m3_sat_perToken",
+        # "fp8_e4m3_scale_perToken",
+        # # fp8_e5m2
+        # "fp8_e5m2_sat",
+        # "fp8_e5m2_scale",
+        # "fp8_e5m2_sat_perCh",
+        # "fp8_e5m2_scale_perCh",
+        # "fp8_e5m2_sat_perToken",
+        # "fp8_e5m2_scale_perToken",
     ]
     bmm_mode_settings = [
         "pact",
"pactsym", "pactsym+", "maxsym", + "maxbmm", "max", "minmax", "pertokenmax", - "fp8_e4m3_sat", - "fp8_e4m3_scale_perToken", - "fp8_e5m2_sat", - "fp8_e5m2_scale_perToken", + # "fp8_e4m3_sat", + # "fp8_e4m3_scale_perToken", + # "fp8_e5m2_sat", + # "fp8_e5m2_scale_perToken", ] # Get strings in config for qa_modes, qw_modes, bmm_modes @@ -1043,7 +1050,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None: # Check each for correct ranges for qa_mode_str in qa_modes_str: qa_mode = config.get(qa_mode_str, "pact+") - if not qa_mode in (qa_mode_settings + mx_spec_config_modes): + if not qa_mode in (qa_mode_settings + mx_spec_config_modes + shared_modes): raise ValueError( f"{qa_mode_str} = {qa_mode} is not set to one of the following: " f"{qa_mode_settings + mx_spec_config_modes}" @@ -1051,7 +1058,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None: for qw_mode_str in qw_modes_str: qw_mode = config.get(qw_mode_str, "sawb+") - if not qw_mode in (qw_mode_settings + mx_spec_config_modes): + if not qw_mode in (qw_mode_settings + mx_spec_config_modes + shared_modes): raise ValueError( f"{qw_mode_str} = {qw_mode} is not set to one of the following: " f"{qw_mode_settings + mx_spec_config_modes}" @@ -1063,7 +1070,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None: bmm_mode_consistency += bmm_mode.startswith("mx_") # mx_specs doesn't have 4 individual bmmX_qmY_modes, it re-uses w and a fmt instead. # We will keep them in qcfg (with "mx_" prefix NOT removed). - if not bmm_mode in (bmm_mode_settings + mx_spec_config_modes): + if not bmm_mode in (bmm_mode_settings + mx_spec_config_modes + shared_modes): raise ValueError( f"{bmm_mode_str} = {bmm_mode} is not set to one of the following: " f"{bmm_mode_settings + mx_spec_config_modes}" @@ -1101,6 +1108,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None: "qskip_large_mag_layers", "recompute_narrow_weights", "smoothq", + "smooth_attn", ] for boolean_var_str in boolean_vars_str: boolean_var = config.get(