[feature] support Gemma2Model for tensor parallel training #6122
jing-4369 wants to merge 2 commits into hpcaitech:main
Conversation
Thanks for contributing! To add a new model, we will also need unit tests. Please reference the existing tests and feel free to ping other team members.
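For what it's worth, a hedged sketch of what a parity-style unit test for Gemma2 could look like, loosely modeled on the existing shardformer tests; the test name, the tiny config values, and the `deepcopy` stand-in for actual sharding are all illustrative assumptions:

```python
import copy

import torch
from transformers import Gemma2Config, Gemma2Model


def test_gemma2_tp_output_parity():
    # Tiny config so the test runs quickly; values are arbitrary assumptions.
    config = Gemma2Config(
        vocab_size=256,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
    )
    torch.manual_seed(0)
    model = Gemma2Model(config).eval()
    # In the real test, this copy would be sharded across ranks with shardformer;
    # the deepcopy here only stands in for that step.
    sharded_model = copy.deepcopy(model)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    with torch.no_grad():
        ref = model(input_ids).last_hidden_state
        out = sharded_model(input_ids).last_hidden_state
    # TP should reproduce the single-device output up to numerical tolerance.
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
```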
```python
attn_kwargs: torch.Tensor = self._update_causal_mask(
    attention_mask, hidden_states, cache_position, past_key_values, output_attentions
)
```
We don't need this? The main branch seems to work.
This can be removed here.
But there is another bug: training did not work for Llama 3, Llama 3.1, or Llama 3.2 with
https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/llama/benchmark.py
I hope you can try that script, using HybridParallelPlugin.
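For reference, a minimal sketch of tensor parallel training with HybridParallelPlugin; the checkpoint name, hyperparameters, and launch details below are illustrative assumptions, not taken from the benchmark script:

```python
import torch
from transformers import AutoModelForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Started via e.g. `colossalai run --nproc_per_node 2 train.py`;
# older ColossalAI versions require launch_from_torch(config={}).
colossalai.launch_from_torch()

plugin = HybridParallelPlugin(
    tp_size=2,  # tensor parallel degree, the path this PR adds for Gemma2
    pp_size=1,  # no pipeline parallelism in this sketch
    zero_stage=1,
    precision="bf16",
    enable_flash_attention=False,  # disabled, so the _update_causal_mask branch runs
)
booster = Booster(plugin=plugin)

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")  # assumed checkpoint
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
criterion = lambda outputs, inputs: outputs.loss

model, optimizer, criterion, *_ = booster.boost(model, optimizer, criterion=criterion)
```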
I'm not sure what you're referring to; `colossalai run --nproc_per_node 2 --master_port 29501 benchmark.py -p 3d -b 1 -g --zero 2` (flash attn disabled, so it goes into this if branch) doesn't throw any error.
Are you using the right transformers version?
To justify such changes and save time, please provide a command to easily reproduce the error.
📌 Checklist before creating the PR
- The PR title follows the format: `[doc/gemini/tensor/...]: A concise description`
- I have installed pre-commit: `pip install pre-commit && pre-commit install`

🚨 Issue number
fixed #6120
📝 What does this PR do?
Support Gemma2Model for tensor parallel training; a conceptual sketch of the tensor parallel split is shown below.
Also attached is a small bug fix needed to successfully run the Llama model.
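For intuition, here is a conceptual, single-process sketch of the column/row split that tensor parallelism applies to a transformer MLP block; the layer names and Gemma-2-like sizes are assumptions for illustration only (ColossalAI's shardformer performs the real split via model policies):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

tp_size = 2
hidden, inter = 2304, 9216  # Gemma-2-2b-like sizes, assumed for illustration

# Column-parallel: each rank holds inter // tp_size output columns of up_proj,
# so its activation shard is computed with no communication.
up_proj_shard = nn.Linear(hidden, inter // tp_size, bias=False)
# Row-parallel: each rank holds inter // tp_size input rows of down_proj;
# the per-rank partial outputs are then summed with an all-reduce.
down_proj_shard = nn.Linear(inter // tp_size, hidden, bias=False)

x = torch.randn(1, 8, hidden)
partial = down_proj_shard(F.gelu(up_proj_shard(x)))
# In a real run: torch.distributed.all_reduce(partial) combines the shards.
print(partial.shape)  # torch.Size([1, 8, 2304])
```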
💥 Checklist before requesting a review
⭐️ Do you enjoy contributing to Colossal-AI?