[bugfix][NPU]: Fix bug that correctly obtains device type #1229
base: main
Changes from all commits
```diff
@@ -50,7 +50,7 @@ def rope_apply(x, freqs, num_heads):
     sp_rank = get_sequence_parallel_rank()
     freqs = pad_freqs(freqs, s_per_rank * sp_size)
     freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
-    freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device == "npu" else freqs_rank
+    freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device.type == "npu" else freqs_rank
```
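The one-line change swaps a comparison against a string for a comparison against `torch.device.type`. A minimal sketch of the difference, runnable on CPU (the `"npu"` string is simply replaced by `"cpu"` here):

```python
import torch

t = torch.randn(4)             # lives on CPU here; on Ascend it would be "npu"

print(t.device)                # device(type='cpu') -- a torch.device, not a str
print(t.device == "cpu")       # False on the PyTorch builds I've tried:
                               # comparing a torch.device to a str never matches
print(t.device.type)           # 'cpu' -- the plain string the code wants
print(t.device.type == "cpu")  # True -- the check this PR switches to
```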
Contributor

Similar to the other file, while this change is correct, there's a potential mixed-precision issue on NPU devices. During the multiplication `x_out * freqs_rank`, PyTorch's type promotion can cast the `complex64` operand back up to `complex128`, so to fully leverage `complex64` on NPU the other operand needs to be created as `complex64` as well (see the more detailed comment below).
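The promotion this comment describes is easy to reproduce on CPU with stand-in tensors; the shapes are arbitrary, only the dtypes matter:

```python
import torch

x_out = torch.randn(2, 3, dtype=torch.complex128)       # stands in for x_out
freqs_rank = torch.randn(2, 3, dtype=torch.complex64)   # after .to(torch.complex64)

# The product is promoted back to the wider dtype, so the complex64 cast
# alone does not keep the multiplication in 64-bit complex arithmetic.
print(torch.result_type(x_out, freqs_rank))  # torch.complex128
print((x_out * freqs_rank).dtype)            # torch.complex128
```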
The rest of the hunk is unchanged context:

```diff
     x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
     return x_out.to(x.dtype)
```
While this change correctly fixes the device type check, there might be an underlying issue with the mixed-precision logic on NPU devices.

On line 94, `x_out` is created as a `complex128` tensor (from `torch.float64`). With this change, `freqs` is converted to `complex64` on NPU devices.

When `x_out * freqs` is executed on line 97, PyTorch's type promotion will likely cast `freqs` back to `complex128` to match `x_out`'s dtype. This could render the conversion to `complex64` ineffective. If `complex128` operations are unsupported on NPU, this could still lead to errors.

For this to work as intended (using `complex64` on NPU), `x_out` should probably also be created as `complex64`. This would involve changing `x.to(torch.float64)` on line 94 to use `torch.float32` when on an NPU device.
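A minimal sketch of what the reviewer is suggesting, using a condensed stand-in for the `rope_apply` body: the reshape and the `num_heads` argument follow the hunk above, but the function name `rope_apply_sketch`, the toy shapes, and the exact layout of `freqs_rank` are assumptions for illustration only.

```python
import torch

def rope_apply_sketch(x, freqs_rank, num_heads):
    # Hypothetical condensed body: only the dtype handling is the point.
    on_npu = x.device.type == "npu"

    # Reviewer's suggestion: build the complex view in float32/complex64 on
    # NPU instead of float64/complex128, so the multiplication below is not
    # silently promoted back to complex128.
    real_dtype = torch.float32 if on_npu else torch.float64
    x_out = torch.view_as_complex(
        x.to(real_dtype).reshape(x.size(0), x.size(1), num_heads, -1, 2)
    )

    complex_dtype = torch.complex64 if on_npu else torch.complex128
    freqs_rank = freqs_rank.to(complex_dtype)

    # Both operands now share a dtype, so no silent promotion occurs.
    x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
    return x_out.to(x.dtype)

# Toy shapes: batch=1, seq=4, heads=2, head_dim=6 (3 complex pairs per head).
x = torch.randn(1, 4, 2 * 6)
freqs = torch.polar(torch.ones(4, 2, 3, dtype=torch.float64),
                    torch.randn(4, 2, 3, dtype=torch.float64))
print(rope_apply_sketch(x, freqs, num_heads=2).shape)  # torch.Size([1, 4, 12])
```

With both operands forced to the same complex dtype, the NPU path stays entirely in `complex64`, which is the behaviour the original `.to(torch.complex64)` cast appears to be aiming for.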