diff --git a/fms_mo/quant/quantizers.py b/fms_mo/quant/quantizers.py
index 0d77f8b8..4b61eb91 100644
--- a/fms_mo/quant/quantizers.py
+++ b/fms_mo/quant/quantizers.py
@@ -123,23 +123,28 @@ def get_activation_quantizer(
             )
         elif qa_mode == "dorefa":
             act_quantizer = dorefa_quantize_activation
-        elif (
-            qa_mode == "max"
-        ):  # NOTE Need to be careful using this for activation, particular to 1 sided.
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
-        elif qa_mode == "minmax":
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+
+        elif "max" in qa_mode:
+            # NOTE: Be careful using this for activations, particularly 1-sided ones.
+            if "min" in qa_mode:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+            elif "pertoken" in qa_mode or "perToken" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-1)
+            elif "per_channel" in qa_mode or "perCh" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-2)
+            elif "sym" in qa_mode:
+                act_quantizer = Qmax(
+                    nbits,
+                    align_zero=True,
+                    minmax=False,
+                    extend_act_range=extend_act_range,
+                )
+            else:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
         elif qa_mode == "fix":
             act_quantizer = QFixSymmetric(
                 nbits, init_clip_val=clip_val, align_zero=align_zero
             )
-        elif qa_mode == "maxsym":
-            act_quantizer = Qmax(
-                nbits,
-                align_zero=True,
-                minmax=False,
-                extend_act_range=extend_act_range,
-            )
         elif qa_mode == "pactsym":
             act_quantizer = PACT2Sym(
                 nbits,
@@ -179,8 +184,6 @@ def get_activation_quantizer(
                 perToken=perToken,
                 emulate=True,
             )
-        elif qa_mode == "pertokenmax":
-            act_quantizer = PerTokenMax(nbits)
         else:
             raise ValueError(f"unrecognized activation quantization mode {qa_mode}")
     else:  # swcap-compatible activation quantizers
@@ -3488,6 +3491,42 @@ def __repr__(self):
         return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"
 
 
+class QMaxDynamic(nn.Module):
+    def __init__(self, num_bits, dim=-1):
+        """
+        Per-token or per-channel quantization that uses abs().max() as the scale, usually for
+        activations, and could be used for Qbmm M2 as well.
+        (reduce) dim = -1 -> abs() outputs a column vector (for 2D input) => per-token
+                 dim = -2 -> per-channel
+        Zero is aligned so that the levels are symmetric around zero (losing one level).
+        Since the token length is unknown before runtime, the quantizer can only calculate the
+        scales dynamically at run time, meaning no trainable quantization scale is allowed
+        (unless the input seq length is always the same, not just padded to a fixed length).
+        """
+        super().__init__()
+        self.num_bits = num_bits
+        self.levels = 2 ** (self.num_bits - 1) - 1
+        if isinstance(dim, str):
+            if "perCh" in dim or "per_channel" in dim:
+                dim = -2
+            elif "perToken" in dim or "per_token" in dim or "per_Token" in dim:
+                dim = -1
+        if dim in [-1, -2]:
+            self.reduce_dim = dim
+        else:
+            raise ValueError(
+                f"Reduce dim can only be [-1, -2] or ['perCh', 'perToken'] but found {dim}"
+            )
+
+    def forward(self, input_tensor):
+        amax_dim = input_tensor.abs().max(dim=self.reduce_dim, keepdim=True)[0]
+        scales = amax_dim.clamp(min=1e-5).div(self.levels)
+        return input_tensor.div(scales).round().mul(scales)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(num_bits={self.num_bits}, dim={self.reduce_dim})"
+
+
 class Qdynamic(nn.Module):
     def __init__(
         self,
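
As a quick sanity check for reviewers, below is a minimal standalone sketch of the fake-quantization math that QMaxDynamic performs (plain PyTorch, independent of fms_mo; the 8-bit setting and tensor shape are illustrative assumptions, not part of the patch):

import torch

# Hypothetical reproduction of QMaxDynamic's forward pass, assuming num_bits=8.
num_bits = 8
levels = 2 ** (num_bits - 1) - 1  # 127 levels, symmetric around zero

x = torch.randn(4, 16)  # (tokens, channels); shape chosen only for illustration

# Per-token (reduce dim=-1): one scale per token, computed from the current batch.
scale_tok = x.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5).div(levels)
x_fq_tok = x.div(scale_tok).round().mul(scale_tok)

# Per-channel (reduce dim=-2): one scale per channel.
scale_ch = x.abs().max(dim=-2, keepdim=True)[0].clamp(min=1e-5).div(levels)
x_fq_ch = x.div(scale_ch).round().mul(scale_ch)

print((x - x_fq_tok).abs().max().item(), (x - x_fq_ch).abs().max().item())

Because the scales are recomputed from each incoming tensor, nothing here is trainable, which matches the dynamic-scale caveat in the class docstring.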