diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/__init__.py b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/__init__.py
index 72383c1d..6c0e3c65 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/__init__.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/__init__.py
@@ -22,6 +22,7 @@
 from .gpt_bigcode import GPTBigCodeGPTQ
 from .gpt_neox import GPTNeoXGPTQ
 from .granite import GraniteGPTQ
+from .granitemoe import GraniteMoeGPTQ
 from .llama import LlamaGPTQ
 from .mistral import MistralGPTQ
 from .mixtral import MixtralGPTQ
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/_const.py b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/_const.py
index 23c4baa3..087dd034 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/_const.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/_const.py
@@ -28,6 +28,7 @@
     "granite",
     "gemma",
     "dbrx_converted",
+    "granitemoe",
 ]
 
 EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/auto.py b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/auto.py
index 23a61a87..d0caf2a4 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/auto.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/auto.py
@@ -29,6 +29,7 @@
 from .gpt_bigcode import GPTBigCodeGPTQ
 from .gpt_neox import GPTNeoXGPTQ
 from .granite import GraniteGPTQ
+from .granitemoe import GraniteMoeGPTQ
 from .llama import LlamaGPTQ
 from .mistral import MistralGPTQ
 from .mixtral import MixtralGPTQ
@@ -43,6 +44,7 @@
     "granite": GraniteGPTQ,
     "dbrx": DbrxGPTQ,
     "dbrx_converted": DbrxConvertedGPTQ,
+    "granitemoe": GraniteMoeGPTQ,
 }
 
 at_least_one_cuda_v6 = any(
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/base.py b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/base.py
index 07a7f772..3ee3f4fb 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/base.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/base.py
@@ -15,7 +15,8 @@
 ###############################################################################
 # Standard
 from os.path import isfile, join
-from typing import Dict, List, Optional, Union
+from types import MethodType
+from typing import Callable, Dict, List, Optional, Union
 import copy
 import json
 import logging
@@ -74,6 +75,7 @@
     move_to,
     nested_move_to,
     pack_model,
+    replace_3d_parameters_with_module_list,
     simple_dispatch_model,
     verify_model_hash,
     verify_sharded_model_hashes,
@@ -94,6 +96,12 @@ class BaseGPTQModel(nn.Module):
     # does not include the node which holds all the repeating layers
     base_modules: List[str] = None
 
+    # whether 3D (per-expert) parameters should be converted into ModuleLists of nn.Linear
+    convert3dparameters: bool = False
+
+    # user-provided forward functions, keyed by module class name, that replace the default forward
+    update_forwards: Dict[str, Callable] = None
+
     # name of lm_head
     lm_head: str = "lm_head"
 
@@ -128,6 +136,13 @@ def __init__(
         super().__init__()
 
         self.model = model
+        if self.convert3dparameters:
+            replace_3d_parameters_with_module_list(model)
+            for mod in model.modules():
+                forward = self.update_forwards.get(mod.__class__.__name__)
+                if forward is not None:
+                    mod.forward = MethodType(forward, mod)
+
         self.model_type = self.model.config.model_type
         self._quantized = quantized
         self.quantize_config = quantize_config
@@ -561,7 +576,7 @@ def save_quantized(
         self.quantize_config.meta_set_versionable(
             key=META_FIELD_QUANTIZER,
             value=META_QUANTIZER_GPTQMODEL,
-            version=__version__,
+            version="1.0.0",
         )
 
         # The config, quantize_config and model may be edited in place in save_quantized.
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/granitemoe.py b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/granitemoe.py
new file mode 100644
index 00000000..87eb3103
--- /dev/null
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/granitemoe.py
@@ -0,0 +1,55 @@
+###############################################################################
+# Adapted from https://github.com/ModelCloud/GPTQModel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+###############################################################################
+# Third Party
+import torch
+
+# Local
+from .base import BaseGPTQModel
+
+
+def new_forward(self, inputs, expert_size):
+    """
+    Replacement forward pass for the GraniteMoeParallelExperts module.
+    Args:
+        inputs (Tensor):
+            Tokens routed to the experts, concatenated along dim 0.
+        expert_size:
+            Number of tokens assigned to each expert, used to split inputs.
+    Returns:
+        Tensor: Output tensor.
+    """
+    input_list = inputs.split(expert_size, dim=0)
+    output_list = []
+    for i in range(self.num_experts):
+        # key change: call the Linear module that replaced this expert's 3D weight slice
+        output_list.append(self.weight[i](input_list[i]))
+    results = torch.cat(output_list, dim=0)
+    return results
+
+
+class GraniteMoeGPTQ(BaseGPTQModel):
+    base_modules = ["model.embed_tokens", "model.norm"]
+    convert3dparameters = True
+    update_forwards = {"GraniteMoeParallelExperts": new_forward}
+
+    layers_node = "model.layers"
+    layer_type = "GraniteMoeDecoderLayer"
+    layer_modules = [
+        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
+        ["self_attn.o_proj"],
+        [f"block_sparse_moe.input_linear.weight.{i}" for i in range(40)],
+        [f"block_sparse_moe.output_linear.weight.{i}" for i in range(40)],
+    ]
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/utils/model.py b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/utils/model.py
index d51e0e60..e6b0c289 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/utils/model.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/utils/model.py
@@ -715,3 +715,34 @@ def get_moe_layer_modules(layer_modules: List, num_experts: int) -> List:
                 new_inside_layer_modules[-1].append(n)
 
     return new_inside_layer_modules
+
+
+def replace_3d_parameters_with_module_list(
+    model: torch.nn.Module,
+):
+
+    for name, module in model.named_modules():
+        # snapshot the parameters: the loop body removes entries from module._parameters
+        for param_name, param in list(module.named_parameters(recurse=False)):
+            if len(param.shape) == 3:
+                device = param.device
+                dtype = param.dtype
+                num, in_features, out_features = param.shape
+
+                module_list = []
+                for i in range(num):
+                    linear = torch.nn.Linear(
+                        in_features=in_features,
+                        out_features=out_features,
+                        device=device,
+                        dtype=dtype,
+                        bias=False,  # FIXME: how to support bias?
+                    )
+                    linear.weight.data = param.data[i]
+                    module_list.append(linear)
+
+                module_list = torch.nn.ModuleList(module_list)
+
+                # replace
+                delattr(module, param_name)
+                setattr(module, param_name, module_list)
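
Illustration (not part of the patch): a minimal sketch of what replace_3d_parameters_with_module_list is expected to do. ToyExperts, its shapes, and the import path are assumptions made for this example; the toy module only mimics the (num_experts, out_features, in_features) weight layout of GraniteMoeParallelExperts.

# illustrative sketch only; ToyExperts and the import path are assumptions
import torch

from fms_acceleration_peft.gptqmodel.utils.model import (
    replace_3d_parameters_with_module_list,
)


class ToyExperts(torch.nn.Module):
    """Stand-in for GraniteMoeParallelExperts: a single 3D expert weight."""

    def __init__(self, num_experts: int, out_features: int, in_features: int):
        super().__init__()
        # same (num_experts, out_features, in_features) layout as
        # GraniteMoeParallelExperts.weight
        self.weight = torch.nn.Parameter(
            torch.randn(num_experts, out_features, in_features)
        )


experts = ToyExperts(num_experts=4, out_features=8, in_features=16)
original = experts.weight.data.clone()

replace_3d_parameters_with_module_list(experts)

# the 3D parameter has been replaced by a ModuleList of per-expert nn.Linear
# modules, which the GPTQ machinery already knows how to quantize
assert isinstance(experts.weight, torch.nn.ModuleList)

x = torch.randn(5, 16)
for i, linear in enumerate(experts.weight):
    # each replacement Linear reproduces F.linear against the original slice
    assert torch.allclose(linear(x), torch.nn.functional.linear(x, original[i]))

The forward replacement in BaseGPTQModel.__init__ uses types.MethodType, so new_forward is bound per instance and only modules whose class name appears in update_forwards are affected; the class's original forward is left untouched.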