From 0f511730e27ba4a010b9a3552ee9ebc74b8cc796 Mon Sep 17 00:00:00 2001
From: Joshua Rosenkranz
Date: Thu, 11 Apr 2024 09:22:52 -0400
Subject: [PATCH] added gpt_bigcode 20b speculator variant

---
 fms_extras/models/speculator.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/fms_extras/models/speculator.py b/fms_extras/models/speculator.py
index 3df5238..8284a88 100644
--- a/fms_extras/models/speculator.py
+++ b/fms_extras/models/speculator.py
@@ -264,6 +264,13 @@ def flatten_batch(inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.
 
 _llama_13b = {"emb_dim": 5120, "vocab_size": 32000, "n_predict": 3, "inner_dim": 4096}
 
+_gpt_bigcode_20b = {
+    "emb_dim": 6144,
+    "vocab_size": 49152,
+    "n_predict": 4,
+    "inner_dim": 4096,
+}
+
 _architecture_name = "mlp_speculator"
 
 
@@ -280,6 +287,11 @@ def factory(**user_kwargs):
 models.register_model(
     _architecture_name, "llama.13b", _mlp_speculator_factory_factory(_llama_13b)
 )
+models.register_model(
+    _architecture_name,
+    "gpt_bigcode.ibm.20b",
+    _mlp_speculator_factory_factory(_gpt_bigcode_20b),
+)
 
 
 def _rename_hf_weights_to_fms(orig_sd):