5 changes: 5 additions & 0 deletions docs/examples/te_llama/requirements.txt
@@ -0,0 +1,5 @@
transformers==4.57.0
accelerate==1.10.0
peft==0.15.2
datasets==4.0.0
sentencepiece==0.2.1
15 changes: 10 additions & 5 deletions docs/examples/te_llama/te_llama.py
@@ -72,10 +72,15 @@ def forward(self, hidden_states, *args, attention_mask, **kwargs):
forward pass of the `TransformerLayer`. Also, make sure the output
format matches the output of the HF's `LlamaDecoderLayer`.
"""
return (
super().forward(
hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
),
# Handle case where hidden_states might be a tuple (from previous layer output)
# This can happen with older versions of HuggingFace transformers
if isinstance(hidden_states, tuple):
hidden_states = hidden_states[0]
Comment on lines +75 to +78
Contributor

The comment is misleading about when tuple unpacking is needed.

With the new code that returns the tensor directly, hidden_states should never be a tuple when called from HuggingFace's LlamaModel forward loop (in any version). The old code returned (tensor,) for transformers < 4.57, but HF's loop extracted it with layer_outputs[0] before passing it to the next layer.

This check appears to be defensive programming rather than addressing a real backward compatibility scenario. Consider clarifying the comment to explain this is a safety check rather than expected behavior.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


# Return tensor directly for HuggingFace transformers >= 4.57
# (older versions wrapped output in tuple and extracted with layer_outputs[0])
return super().forward(
hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
)
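
To make the reviewer's point concrete, here is a minimal sketch of the consumption pattern being described. It is a simplified illustration only, not the actual HuggingFace transformers source; the function and variable names (`run_decoder_layers_old_style`, `layer_outputs`, etc.) are hypothetical.

```python
def run_decoder_layers_old_style(layers, hidden_states, attention_mask):
    # transformers < 4.57 style as described in the review: each layer returns
    # a tuple, and the caller unpacks element 0 before calling the next layer,
    # so the layer itself never receives a tuple as hidden_states.
    for layer in layers:
        layer_outputs = layer(hidden_states, attention_mask=attention_mask)
        hidden_states = layer_outputs[0]
    return hidden_states


def run_decoder_layers_new_style(layers, hidden_states, attention_mask):
    # transformers >= 4.57 style: each layer returns the tensor directly,
    # so no unpacking happens between layers.
    for layer in layers:
        hidden_states = layer(hidden_states, attention_mask=attention_mask)
    return hidden_states
```

In both sketches the per-layer input is a plain tensor, which is why the `isinstance(hidden_states, tuple)` branch reads as a safety check rather than a version-compatibility path.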


@@ -162,7 +167,7 @@ def replace_params(hf_state_dict, te_state_dict, config):
# collect all layer prefixes to update
all_layer_prefixes = set()
for param_key in hf_state_dict.keys():
layer_prefix_pat = "model.layers.\d+."
layer_prefix_pat = r"model.layers.\d+."
m = re.match(layer_prefix_pat, param_key)
if m is not None:
all_layer_prefixes.add(m.group())
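
The change on this line only adds the raw-string prefix, which avoids the invalid `\d` escape sequence in a normal string literal (a SyntaxWarning on recent Python versions); the matching behavior is unchanged. A standalone sketch of what the pattern extracts, using a hypothetical parameter key chosen to look like a typical HF Llama state-dict entry:

```python
import re

layer_prefix_pat = r"model.layers.\d+."
param_key = "model.layers.0.self_attn.q_proj.weight"  # hypothetical example key

m = re.match(layer_prefix_pat, param_key)
if m is not None:
    print(m.group())  # -> "model.layers.0."
```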