From 54edfbc7079db38221ff0a94a8ba378395e15a5d Mon Sep 17 00:00:00 2001
From: BruceFeiFei
Date: Mon, 16 Dec 2024 11:10:57 +0800
Subject: [PATCH] Register the Telechat2 model and complete MindIE serving deployment
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../mindformers/models/__init__.py            |   6 +
 .../mindformers/models/telechat2/__init__.py  |   5 +
 .../mindformers/models/telechat2/telechat.py  | 501 ++++++++++++++++++
 .../models/telechat2/telechat_config.py       | 210 ++++++++
 .../telechat2/tokenization_telechat2.py       | 224 ++++++++
 5 files changed, 946 insertions(+)
 create mode 100644 mindformers-telechat/mindformers/models/telechat2/__init__.py
 create mode 100644 mindformers-telechat/mindformers/models/telechat2/telechat.py
 create mode 100644 mindformers-telechat/mindformers/models/telechat2/telechat_config.py
 create mode 100644 mindformers-telechat/mindformers/models/telechat2/tokenization_telechat2.py

diff --git a/mindformers-telechat/mindformers/models/__init__.py b/mindformers-telechat/mindformers/models/__init__.py
index 260b4d0..35bf987 100644
--- a/mindformers-telechat/mindformers/models/__init__.py
+++ b/mindformers-telechat/mindformers/models/__init__.py
@@ -182,6 +182,11 @@
     CogVLM2VideoLMModel,
     LlamaForCausalLMForCogVLM2Image
 )
+from .telechat2 import (
+    TelechatForCausalLM,
+    Telechat2Tokenizer,
+    TelechatConfig
+)
 from .eva02 import (
     EVA02Config,
     EVAModel
@@ -236,3 +241,4 @@
 __all__.extend(multi_modal.__all__)
 __all__.extend(configuration_utils.__all__)
 __all__.extend(modeling_utils.__all__)
+__all__.extend(telechat2.__all__)
diff --git a/mindformers-telechat/mindformers/models/telechat2/__init__.py b/mindformers-telechat/mindformers/models/telechat2/__init__.py
new file mode 100644
index 0000000..6d4ec90
--- /dev/null
+++ b/mindformers-telechat/mindformers/models/telechat2/__init__.py
@@ -0,0 +1,5 @@
+from .telechat import TelechatForCausalLM
+from .tokenization_telechat2 import Telechat2Tokenizer
+from .telechat_config import TelechatConfig
+
+__all__ = ['TelechatForCausalLM', 'Telechat2Tokenizer', 'TelechatConfig']
\ No newline at end of file
diff --git a/mindformers-telechat/mindformers/models/telechat2/telechat.py b/mindformers-telechat/mindformers/models/telechat2/telechat.py
new file mode 100644
index 0000000..751aafd
--- /dev/null
+++ b/mindformers-telechat/mindformers/models/telechat2/telechat.py
@@ -0,0 +1,501 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""Telechat models' APIs.""" +import copy +import numpy as np + +import mindspore.common.dtype as mstype +from mindspore import Tensor, nn, mint +from mindspore.context import ParallelMode +from mindspore.ops import operations as P +from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation + +from mindformers.core.loss.loss import CrossEntropyLoss +from mindformers.models.modeling_utils import PreTrainedModel +from mindformers.models.utils import LayerSetting, lazy_inline, check_fine_grain_interleave_valid +from mindformers.models.llama.llama_layer import LlamaRMSNorm +from mindformers.modules.layers import Linear, FreqsMgr, Dropout +from mindformers.modules.transformer import LowerTriangularMaskWithDynamic +from mindformers.modules.transformer.op_parallel_config import _check_config +from mindformers.tools.logger import logger +from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister +from mindformers.tools.utils import get_disable_custom_fa, get_predict_run_mode, get_use_rope_self_define + +from research.telechat2.telechat_transformer import TelechatDecodeLayer +from research.telechat2.telechat_interleave import TelechatDecodeLayerInterleave +from research.telechat2.telechat_layer import TelechatEmbedding +from research.telechat2.telechat_config import TelechatConfig + + +class TelechatPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = TelechatConfig + base_model_prefix = "telechat" + + +class TelechatModel(TelechatPreTrainedModel): + r""" + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`TelechatDecoderLayer`] + Args: + config(TelechatConfig): the config of network + + Returns: + output: Tensor, the output of telechat decoderlayer + + Examples: + >>> from mindformers import TelechatModel + >>> network = TelechatModel.from_pretrained('telechat_115b') + >>> type(network) + + """ + + def __init__(self, + config: TelechatConfig = None): + super().__init__(config, auto_prefix=True) + _check_config(config.parallel_config) + self.dtype = config.compute_dtype + self.hidden_size = config.hidden_size + self.num_layers = config.num_layers + self.n_head = config.num_heads + self.head_dim = self.hidden_size // self.n_head + self.pad_token_id = config.pad_token_id + self.is_first_iteration = True + self.use_past = config.use_past + self.use_flash_attention = config.use_flash_attention + + self.embed_dropout_prob = config.embed_dropout_prob + self.embeddings_dropout = Dropout(1 - self.embed_dropout_prob) + + self.concat = P.Concat(-1) + self.cast = P.Cast() + self.shape = P.Shape() + self.reshape = P.Reshape() + # default open internal kernel boost + self.use_rope_self_define = get_use_rope_self_define() + self.disable_custom_fa = get_disable_custom_fa() + logger.info("disable custom flash attention score op:{}".format(self.disable_custom_fa)) + if self.disable_custom_fa: + self.prefill_flatten_mask = Tensor(np.triu(np.ones(shape=(128, 128), dtype=np.float16), 1)) + + self.freqs_mgr = FreqsMgr(head_dim=self.head_dim, + seq_length=config.seq_length, + max_position_embedding=config.max_position_embedding, + rotary_dtype=config.rotary_dtype, + theta=config.theta, + scaling_factor=config.scaling_factor, + extend_method=config.extend_method, + parallel_config=config.parallel_config) + self.casual_mask = LowerTriangularMaskWithDynamic(seq_length=config.seq_length, + compute_type=config.compute_dtype, + is_dynamic=config.is_dynamic, + pad_token_id=config.pad_token_id, + use_flash_attention=config.use_flash_attention, + use_attn_mask_compression=config.use_attn_mask_compression) + self.tok_embeddings = TelechatEmbedding(vocab_table_size=config.vocab_size, + sigma=config.sigma, + mean=config.mean, + embedding_size=config.hidden_size, + param_init_type=config.embedding_init_type, + parallel_optimizer=config.parallel_optimizer) + self.fine_grain_interleave = check_fine_grain_interleave_valid(config.fine_grain_interleave, + config.parallel_config) + self.layers = nn.CellList() + self.layer_setting = LayerSetting(config.num_layers, + config.offset, + config.parallel_config, + config.pp_interleave_num) + for layer_id in range(config.num_layers): + if self.fine_grain_interleave: + layer = TelechatDecodeLayerInterleave(config.seq_length, + layer_id, + dim=config.hidden_size, + n_heads=config.num_heads, + num_layers=config.num_layers, + n_kv_heads=config.n_kv_heads, + hidden_dropout_prob=config.hidden_dropout_prob, + attention_dropout_prob=config.attention_dropout_prob, + intermediate_size=config.intermediate_size, + ffn_dim_multiplier=config.ffn_dim_multiplier, + norm_eps=config.rms_norm_eps, + qkv_has_bias=config.qkv_has_bias, + wo_has_bias=config.wo_has_bias, + compute_dtype=config.compute_dtype, + layernorm_compute_dtype=config.layernorm_compute_type, + softmax_compute_dtype=config.softmax_compute_type, + rotary_dtype=config.rotary_dtype, + param_init_type=config.param_init_type, + res_dtype=config.res_dtype, + use_flash_attention=config.use_flash_attention, + is_dynamic=config.is_dynamic, + use_rope_slice=config.use_rope_slice, + fine_grain_interleave=config.fine_grain_interleave, 
+ parallel_config=config.parallel_config) + else: + layer = TelechatDecodeLayer(layer_id, + dim=config.hidden_size, + n_heads=config.num_heads, + n_kv_heads=config.n_kv_heads, + sigma=config.sigma, + mean=config.mean, + hidden_dropout_prob=config.hidden_dropout_prob, + attention_dropout_prob=config.attention_dropout_prob, + intermediate_size=config.intermediate_size, + multiple_of=config.multiple_of, + ffn_dim_multiplier=config.ffn_dim_multiplier, + norm_eps=config.rms_norm_eps, + qkv_has_bias=config.qkv_has_bias, + wo_has_bias=config.wo_has_bias, + compute_dtype=config.compute_dtype, + layernorm_compute_dtype=config.layernorm_compute_type, + softmax_compute_dtype=config.softmax_compute_type, + rotary_dtype=config.rotary_dtype, + param_init_type=config.param_init_type, + res_dtype=config.res_dtype, + use_past=config.use_past, + use_flash_attention=config.use_flash_attention, + use_attn_mask_compression=config.use_attn_mask_compression, + block_size=config.block_size, + num_blocks=config.num_blocks, + is_dynamic=config.is_dynamic, + use_rope_slice=config.use_rope_slice, + parallel_config=config.parallel_config) + self.layer_setting(layer, layer_id) + self.layers.append(layer) + self.norm_out = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps, + compute_type=config.layernorm_compute_type) + dp = config.parallel_config.data_parallel + if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): + self.tok_embeddings.pipeline_stage = 0 + if config.parallel_config.pipeline_stage > 1: + self.norm_out.pipeline_stage = config.parallel_config.pipeline_stage - 1 + self.tok_embeddings.set_comm_fusion(2) + self.norm_out.set_comm_fusion(2) + else: + self.tok_embeddings.set_comm_fusion(config.parallel_config.gradient_aggregation_group) + self.norm_out.set_comm_fusion(config.parallel_config.gradient_aggregation_group) + + self.tok_embeddings.shard(config.parallel_config) + self.casual_mask.shard(config.parallel_config) + self.concat.shard(((dp, 1, 1, 1), (dp, 1, 1, 1))) + if self.fine_grain_interleave: + self.norm_out.shard((dp, 1)) + else: + self.norm_out.shard((dp, 1, 1)) + + # pylint: disable=W0613 + def construct(self, tokens: Tensor, batch_valid_length=None, batch_index=None, zactivate_len=None, + block_tables=None, slot_mapping=None, prefix_keys_values=None): + """ + Forward of telechat model. + + Args: + tokens: the tokenized inputs with datatype int32 + batch_valid_length(Tensor): the past calculated the index with datatype int32, used for incremental + prediction. Tensor of shape :math:`(batch_size,)`. Default None. + block_tables (Tensor[int64]): Store mapping tables for each sequence. + slot_mapping (Tensor[int32]): Store token cache physical slot index. 
+ Returns: + output: Tensor, the output of telechat decoderlayer + """ + # preprocess + bs, seq_len = self.shape(tokens) + mask = None + if self.use_past: + if self.is_first_iteration: + if self.use_rope_self_define: + freqs_cis = self.freqs_mgr(seq_len) + else: + freqs_cis = self.freqs_mgr.prefill(bs, seq_len) + + if self.use_flash_attention: + if self.disable_custom_fa: # only support fp16 + mask = self.prefill_flatten_mask + freqs_cis = self.freqs_mgr.prefill_flatten() + else: + mask = self.casual_mask(tokens) # mask: [bs, seq, seq] + + if prefix_keys_values is not None: + if mask is None: + mask = self.casual_mask(tokens) + prefix_length = prefix_keys_values[0].shape[2] + prefix_mask = Tensor(np.zeros((bs, 1, seq_len, prefix_length)), dtype=mask.dtype) + mask = self.concat((prefix_mask, mask)) + else: + freqs_cis = self.freqs_mgr.increment(batch_valid_length) + else: + mask = self.casual_mask(tokens) + freqs_cis = self.freqs_mgr(seq_len) + if prefix_keys_values is not None: + prefix_length = prefix_keys_values[0].shape[2] + prefix_mask = Tensor(np.zeros((bs, 1, seq_len, prefix_length)), dtype=mask.dtype) + mask = self.concat((prefix_mask, mask)) + + # tokens: [bs, seq/1] + h = self.tok_embeddings(tokens) + h = self.embeddings_dropout(h) + h = self.reshape(h, (bs, seq_len, self.hidden_size)) + # h: [bs, seq/1, hidden_dim] + for i in range(self.num_layers): + prefix_kv = prefix_keys_values[i] if prefix_keys_values is not None else None + h = self.layers[i](h, freqs_cis, mask, batch_valid_length=batch_valid_length, block_tables=block_tables, + slot_mapping=slot_mapping, prefix_keys_values=prefix_kv) + output = self.norm_out(h) + return output + + +@MindFormerRegister.register(MindFormerModuleType.MODELS) +class TelechatForCausalLM(TelechatPreTrainedModel): + """ + Provide telechat training loss or logits through network. + + Args: + config (TelechatConfig): The config of telechat model. 
+ + Returns: + output: Tensor, the output of telechat decoderlayer + + Examples: + >>> from mindformers.models.telechat import TelechatConfig, TelechatForCausalLM + >>> config = TelechatConfig(batch_size=2) + >>> network = TelechatForCausalLM(config=config) + >>> type(network) + + >>> from mindformers import TelechatForCausalLM + >>> network = TelechatForCausalLM.from_pretrained('telechat_115b') + >>> type(network) + + """ + + @lazy_inline + def __init__(self, config: TelechatConfig = None): + super(TelechatForCausalLM, self).__init__(config, auto_prefix=True) + _check_config(config.parallel_config) + self.config = config + self.ignore_token_id = config.ignore_token_id + self.pad_token_id = config.pad_token_id + self.use_past = config.use_past + self.vocab_size = config.vocab_size + self.is_first_iteration = True + self.disable_custom_fa = get_disable_custom_fa() + + self.shape = P.Shape() + self.reshape = P.Reshape() + self.cast = P.Cast() + self.slice = P.StridedSlice() + self.not_equal = P.NotEqual() + self.mul = P.Mul() + self.add = P.Add() + self.ones = P.Ones() + self.gather = P.Gather(1) + self.prefill_gather_flatten = P.Gather() + self.sub_batch_valid_len = P.Sub() + self.model = TelechatModel(config=config) + self.lm_head = Linear(in_channels=config.hidden_size, + out_channels=config.vocab_size, + has_bias=False, + compute_dtype=config.compute_dtype, + param_init_type=config.param_init_type, + weight_init="normal") # meta default: xavier_normal + + mp = config.parallel_config.model_parallel + vocab_size = config.vocab_size + loss_parallel_config = copy.deepcopy(config.parallel_config) + if vocab_size % mp != 0: + logger.warning("The vocab size of Loss is: %s, it is not divide by model_parallel: %s", + vocab_size, mp) + logger.warning("Now, the model_parallel num of Loss will be changed: mp = 1") + loss_parallel_config.model_parallel = 1 + loss_parallel_config.data_parallel *= loss_parallel_config.context_parallel + check_for_nan_in_loss_and_grad = getattr(config, "check_for_nan_in_loss_and_grad", False) + self.loss = CrossEntropyLoss(parallel_config=loss_parallel_config, + check_for_nan_in_loss_and_grad=check_for_nan_in_loss_and_grad) + + dp = config.parallel_config.data_parallel + mp = config.parallel_config.model_parallel + if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): + self.slice.shard(((dp, 1),)) + self.not_equal.shard(((dp, 1), ())) + self.mul.shard(((dp, 1), (dp, 1))) + self.add.shard(((dp, 1), ())) + self.gather.shard(((dp, 1, 1), (dp,))) + self.prefill_gather_flatten.shard(((dp, 1, 1), (dp,))) + self.sub_batch_valid_len.shard(((1,), ())) + if config.parallel_config.vocab_emb_dp or (vocab_size % mp != 0): + self.lm_head.shard(strategy_matmul=((dp, 1), (1, 1))) + else: + self.lm_head.shard(strategy_matmul=((dp, 1), (mp, 1))) + if config.parallel_config.pipeline_stage > 1: + self.lm_head.pipeline_stage = config.parallel_config.pipeline_stage - 1 + + if config.quant == "w8a16": + logger.info("Using RoundToNearest to quant TelechatForCausalLM.") + from mindspore_gs.ptq import PTQConfig, PTQMode + from mindspore_gs.common import BackendTarget + from mindspore_gs.ptq import RoundToNearest as RTN + cfg = PTQConfig(mode=PTQMode.DEPLOY, backend=BackendTarget.ASCEND) + ptq = RTN(config=cfg) + self.model = ptq.apply(self.model) + self.model = ptq.convert(self.model) + + self.predict_run_mode = get_predict_run_mode() + + logger.info("Predict run mode:{}".format(self.predict_run_mode)) + + def prepare_inputs_for_prefill_flatten(self, 
input_ids, batch_valid_length, slot_mapping, model_inputs): + """prepare inputs ids for prefill flatten""" + batch_valid_length_bs = batch_valid_length.shape[0] + input_ids_bs = input_ids.shape[0] + if batch_valid_length_bs == input_ids_bs and batch_valid_length_bs > 1: + input_ids_list = [] + for i in range(batch_valid_length_bs): + context_len = batch_valid_length[i] + input_ids_list.append(input_ids[i][:context_len]) + input_ids = np.concatenate(input_ids_list, 0) + input_ids = input_ids.reshape((1, -1)) + slot_mapping = np.delete(slot_mapping, np.where(slot_mapping == -1)) + model_inputs["input_ids"] = Tensor.from_numpy(input_ids.astype(np.int32)) + model_inputs["slot_mapping"] = Tensor.from_numpy(slot_mapping) + return model_inputs + + + def prepare_inputs_for_generation(self, input_ids, **kwargs): + """Return model inputs for generation""" + model_inputs = {} + if self.config.is_dynamic and "origin_inputs" in kwargs: + input_ids = kwargs["origin_inputs"] + model_inputs["input_ids"] = Tensor.from_numpy( + input_ids.astype(np.int32)) + + prefill = kwargs.get("prefill") + if self.disable_custom_fa and prefill: + batch_valid_length = kwargs.get("valid_length_each_example") + slot_mapping = kwargs.get("slot_mapping") + model_inputs = self.prepare_inputs_for_prefill_flatten(input_ids, batch_valid_length, slot_mapping, + model_inputs) + return model_inputs + + # pylint: disable=W0613 + def prepare_inputs_for_predict_layout(self, input_ids, **kwargs): + """Get Telechat model input tuple for transform ckpt.""" + input_ids = Tensor(input_ids, mstype.int32) + labels = Tensor(kwargs["labels"]) if "labels" in kwargs else None + bs, seq = input_ids.shape[0], input_ids.shape[1] + slot_mapping = Tensor(np.ones(shape=tuple([bs * seq])), mstype.int32) + prefix_keys_values = Tensor(kwargs["prefix_keys_values"]) if "prefix_keys_values" in kwargs else None + return input_ids, labels, None, None, None, None, None, None, None, None, None, slot_mapping, prefix_keys_values + + def set_dynamic_inputs(self, **kwargs): + """Set dynamic inputs""" + dynamic_input_ids = Tensor(shape=[None, None], dtype=mstype.int32) + dynamic_batch_valid_length = Tensor(shape=[None, None], dtype=mstype.int32) + dynamic_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) + dynamic_slot_mapping = Tensor(shape=[None], dtype=mstype.int32) + have_prefix_keys_values = getattr(kwargs, "have_prefix_keys_values", False) + if have_prefix_keys_values: + dynamic_prefix_keys_values = Tensor(shape=[2, None, None, None, None], dtype=mstype.float16) + self.set_inputs(dynamic_input_ids, None, None, None, None, None, None, + dynamic_batch_valid_length, None, None, dynamic_block_tables, + dynamic_slot_mapping, dynamic_prefix_keys_values) + else: + self.set_inputs(dynamic_input_ids, None, None, None, None, None, None, + dynamic_batch_valid_length, None, None, dynamic_block_tables, + dynamic_slot_mapping, None) + logger.info("Set dynamic input for telechat.") + + def add_flags_custom(self, is_first_iteration): + """Add customized attributes for specific cells in the model.""" + self.add_flags(is_first_iteration=is_first_iteration) + self.model.add_flags(is_first_iteration=is_first_iteration) + for layer in self.model.layers: + layer.add_flags(is_first_iteration=is_first_iteration) + layer.attention.infer_attention.add_flags(is_first_iteration=is_first_iteration) + + # pylint: disable=W0613 + def construct(self, input_ids, labels=None, input_position=None, position_ids=None, attention_mask=None, + input_embeds=None, init_reset=None, 
batch_valid_length=None, batch_index=None, zactivate_len=None, + block_tables=None, slot_mapping=None, prefix_keys_values=None): + r""" + TelechatForCausalLM forward. + + Args: + input_ids(Tensor): the tokenized inputs with datatype int32, Tensor of shape :math:`(batch, seq\_length)`. + labels(Tensor): the tokenized labels with datatype int32, Tensor of shape :math:`(batch, seq\_length)`. + input_position(Tensor): current position, used by model.predict. + position_ids(Tensor): Reserved param, not used. + attention_mask(Tensor): Reserved param, not used. + input_embeds(Tensor): Reserved param, not used. + init_reset(bool, optional): A bool tensor with shape [1], used to clear the past key parameter and + past value parameter used in the incremental prediction. Default True. + batch_valid_length(Tensor): the past calculated the index with datatype int32, used for incremental + prediction. Tensor of shape :math:`(batch_size,)`. Default None. + block_tables (Tensor[int64]): Store mapping tables for each sequence. + slot_mapping (Tensor[int32]): Store token cache physical slot index. + Returns: + Tensor: The loss or (logits, tokens, input_mask) of the network. + """ + bsz, seqlen = self.shape(input_ids) + if self.use_past: + if not isinstance(batch_valid_length, Tensor): + batch_valid_length = self.ones((bsz,), mstype.int32) + if self.training: + tokens = self.slice(input_ids, (0, 0), (bsz, seqlen - 1), (1, 1)) + else: + tokens = input_ids + if batch_valid_length is not None: + batch_valid_length = self.reshape(batch_valid_length, (-1,)) + output = self.model(tokens, batch_valid_length, batch_index, zactivate_len, block_tables, \ + slot_mapping, prefix_keys_values) + pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None + if pre_gather: + if self.disable_custom_fa: + batch_valid_length = mint.cumsum(batch_valid_length, 0) + output = self.prefill_gather_flatten(output, self.sub_batch_valid_len(batch_valid_length, 1), 1) + else: + output = self.gather(output, self.sub_batch_valid_len(batch_valid_length, 1), 1) + logits = self.lm_head(output) + + input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), mstype.float32) + if labels is None: + labels = self.slice(input_ids, (0, 1), (bsz, seqlen), (1, 1)) + else: + if labels.ndim > 1: + if self.training: + labels = self.slice(labels, (0, 1), (bsz, seqlen), (1, 1)) + label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), mstype.float32) + input_mask = self.mul(input_mask, label_mask) + + if not self.training: + logits = self.cast(logits, mstype.float32) + if self.predict_run_mode: + logits = self.reshape(logits, (-1, logits.shape[-1])) + return logits + return logits, tokens, input_mask + + if logits.ndim > 2: + logits = self.reshape(logits, (-1, logits.shape[-1])) + logits = self.cast(logits, mstype.float32) + labels = self.reshape(labels, (-1,)) + input_mask = self.reshape(input_mask, (-1,)) + loss = self.loss(logits, labels, input_mask) + return loss + + def kvcache(self, layer_idx): + key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache + value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache + return key_cache, value_cache diff --git a/mindformers-telechat/mindformers/models/telechat2/telechat_config.py b/mindformers-telechat/mindformers/models/telechat2/telechat_config.py new file mode 100644 index 0000000..3522f14 --- /dev/null +++ 
b/mindformers-telechat/mindformers/models/telechat2/telechat_config.py @@ -0,0 +1,210 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Telechat Config API.""" + +from typing import Optional, Union + +from mindspore._checkparam import args_type_check + +from mindformers.modules.transformer.transformer import default_transformer_config, \ + TransformerOpParallelConfig +from mindformers.tools.register import MindFormerRegister, MindFormerModuleType +from mindformers.models.configuration_utils import PretrainedConfig +from mindformers.models.utils import convert_mstype + + +@MindFormerRegister.register(MindFormerModuleType.CONFIG) +class TelechatConfig(PretrainedConfig): + """ + Telechat config class which defines the model size. + + Args: + batch_size (Optional[int]): batch size for input data, use in predict. + seq_length (Optional[int]): The sequence length of input_ids, default is 1024. + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the BERT model. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + multiple_of (Optional[int]): Define SwiGLU hidden layer size multiples, default 256. + n_kv_heads (Optional[int]): Define multi group head attention heads number, default None. + ffn_dim_multiplier (Optional[int]): Define ffn layer dim multiples, default None. + rms_norm_eps (Optional[float]): The epsilon value of the denominator. Default 1e-5. + bos_token_id (Optional[int]): The id of the *beginning-of-sequence* token. + eos_token_id (Optional[int]): The id of the *end-of-sequence* token. + pad_token_id (Optional[int]): The id of the *padding* token. + ignore_token_id (Optional[int]): The id of the *ignoring* token. + compute_dtype (Optional[str]): + Linear layer compute dtype, default is "float16". + layernorm_compute_type (Optional[str]): + layernorm compute dtype, default is "float32". + softmax_compute_type (Optional[str]): + softmax compute dtype, default is "float32". + rotary_dtype (Optional[str]): + rope compute dtype, default is "float32". + param_init_type (Optional[str]): + parameter initial dtype, default is "float16". + qkv_has_bias (Optional[bool]): + Whether the Query, Key, and Value projection has bias. + use_past (`bool`, *optional*, defaults to `False`): + Whether the model should use the past last key/values attentions + (if applicable to the model) to speed up decoding. + parallel_config(TransformerOpParallelConfig): + The parallel configure. Default `default_transformer_config`, + an instance of `TransformerOpParallelConfig` with default args. + extend_method(str): The extend method of seq length of inferencem,default None. 
+ use_flash_attention(bool): Whether enable flash attention ops, default False. + offset(int): Offset of transformer layer when set pipeline stage number. + checkpoint_name_or_path (Optional[str]): + checkpoint path or name used to load to the network. + repetition_penalty (`float`, *optional*, defaults to 1.0): + The parameter for repetition penalty. 1.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + max_decode_length (`int`, *optional*, defaults to 1024): + The maximum length the generated tokens can have. Corresponds to the length of the input prompt + + `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set. + top_k (`int`, *optional*, defaults to 5): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`, *optional*, defaults to 1.0): + If set to float < 1, only the smallest set of most probable tokens with probabilities + that add up to `top_p` or higher are kept for generation. + do_sample (`bool`, *optional*, defaults to `False`): + Whether or not to use sampling ; use greedy decoding otherwise. + block_size (`int`, *optional*, defaults to 16): + The maximum number of tokens in one block can have when using paged attention. + num_blocks (`int`, *optional*, defaults to 512): + The maximum number of blocks when using paged attention. + Returns: + Class, TelechatConfig. + """ + + model_type = "telechat" + + @args_type_check(parallel_config=(dict, TransformerOpParallelConfig)) + def __init__(self, + batch_size: int = 1, + seq_length: int = 2048, + hidden_size: int = 4096, + num_layers: int = 32, + num_heads: int = 32, + embed_dropout_prob: float = 0.0, + hidden_dropout_prob: float = 0.0, + attention_dropout_prob: float = 0.0, + n_kv_heads: Optional[int] = None, + max_position_embedding: Optional[int] = None, + intermediate_size: Optional[int] = None, + vocab_size: int = 32000, # defined later by tokenizer + multiple_of: int = 256, # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[int] = None, + rms_norm_eps: float = 1e-5, + bos_token_id: int = 1, + eos_token_id: int = 2, + pad_token_id: int = 0, + ignore_token_id: int = -100, + theta: float = 10000.0, + compute_dtype: str = "float16", + layernorm_compute_type: str = "float32", + softmax_compute_type: str = "float32", + rotary_dtype: str = "float32", + param_init_type: str = "float16", + embedding_init_type=None, + res_dtype: str = "float32", + qkv_has_bias: bool = False, + wo_has_bias: bool = True, + parallel_config: Union[dict, TransformerOpParallelConfig] = default_transformer_config, + use_past: bool = False, + extend_method: str = "None", + scaling_factor: float = 1.0, + is_dynamic: bool = False, + use_rope_slice: bool = False, + use_flash_attention: bool = False, + use_attn_mask_compression: bool = False, + parallel_optimizer: bool = False, + fine_grain_interleave: int = 1, + pp_interleave_num: int = 1, + offset: int = 0, + checkpoint_name_or_path: str = "", + repetition_penalty: float = 1.0, + max_decode_length: int = 1024, + block_size: int = 16, + num_blocks: int = 512, + top_k: int = 5, + top_p: float = 1.0, + do_sample: bool = True, + quant: str = "", + sigma: float = 0.0048, + mean: float = 0.0, + **kwargs): + super(TelechatConfig, self).__init__(**kwargs) + if isinstance(parallel_config, dict): + parallel_config = TransformerOpParallelConfig(**parallel_config) + self.batch_size = batch_size + self.seq_length = seq_length + self.vocab_size = vocab_size + self.hidden_size 
= hidden_size + self.num_layers = num_layers + self.num_heads = num_heads + self.embed_dropout_prob = embed_dropout_prob + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_dropout_prob = attention_dropout_prob + self.max_position_embedding = max_position_embedding if max_position_embedding else seq_length + self.intermediate_size = intermediate_size + self.multiple_of = multiple_of + self.n_kv_heads = n_kv_heads + self.ffn_dim_multiplier = ffn_dim_multiplier + self.rms_norm_eps = rms_norm_eps + self.wo_has_bias = wo_has_bias + self.param_init_type = convert_mstype(param_init_type) + if embedding_init_type is not None: + self.embedding_init_type = convert_mstype(embedding_init_type) + else: + self.embedding_init_type = self.param_init_type + self.qkv_has_bias = qkv_has_bias + self.layernorm_compute_type = convert_mstype(layernorm_compute_type) + self.softmax_compute_type = convert_mstype(softmax_compute_type) + self.rotary_dtype = convert_mstype(rotary_dtype) + self.compute_dtype = convert_mstype(compute_dtype) + self.res_dtype = convert_mstype(res_dtype) + self.parallel_config = parallel_config + self.checkpoint_name_or_path = checkpoint_name_or_path + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.ignore_token_id = ignore_token_id + self.use_past = use_past + self.extend_method = extend_method + self.scaling_factor = scaling_factor + self.is_dynamic = is_dynamic + self.use_rope_slice = use_rope_slice + self.use_flash_attention = use_flash_attention + self.use_attn_mask_compression = use_attn_mask_compression + self.parallel_optimizer = parallel_optimizer + self.fine_grain_interleave = fine_grain_interleave + self.offset = offset + self.repetition_penalty = repetition_penalty + self.max_decode_length = max_decode_length + self.pp_interleave_num = pp_interleave_num + self.top_k = top_k + self.top_p = top_p + self.do_sample = do_sample + self.sigma = sigma + self.mean = mean + self.theta = theta + self.block_size = block_size + self.num_blocks = num_blocks + self.quant = quant diff --git a/mindformers-telechat/mindformers/models/telechat2/tokenization_telechat2.py b/mindformers-telechat/mindformers/models/telechat2/tokenization_telechat2.py new file mode 100644 index 0000000..fb49d4a --- /dev/null +++ b/mindformers-telechat/mindformers/models/telechat2/tokenization_telechat2.py @@ -0,0 +1,224 @@ +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple +import sentencepiece as spm +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer +from transformers.utils import logging +from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +# TODO: when we get download url from huggingface, refresh the map +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": {}, + "tokenizer_file": {}, +} + +@MindFormerRegister.register(MindFormerModuleType.TOKENIZER) +class Telechat2Tokenizer(PreTrainedTokenizer): + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + unk_token="", + bos_token="<_start>", + eos_token="<_end>", + pad_token="<_pad>", + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + clean_up_tokenization_spaces=False, + **kwargs, + ): + self.sp_model_kwargs = 
{} if sp_model_kwargs is None else sp_model_kwargs + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + sp_model_kwargs=self.sp_model_kwargs, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + @property + def vocab(self): + return self.get_vocab() + + def _tokenize(self, text): + """Returns a tokenized string.""" + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + # prev_is_special = False + for i, token in enumerate(tokens): + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + # if not prev_is_special and i != 0: + # out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + # prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + # prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string + + def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. 
+ """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + bos_token_id = [1] if self.add_bos_token else [] + eos_token_id = [1] if self.add_eos_token else [] + + if token_ids_1 is None: + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + return ( + bos_token_id + + ([0] * len(token_ids_0)) + + eos_token_id + + bos_token_id + + ([0] * len(token_ids_1)) + + eos_token_id + ) + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + if token_ids_1 is None, only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of ids. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) + + if token_ids_1 is not None: + output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) + + return output