Fix FP8 block scaling with sequence parallel #2637
base: main
Changes from all commits: 3ba991b, 390b2e1, 637ba0f, 9fe572b, 19fd927, adfe33b, e6d1559
```diff
@@ -1100,6 +1100,9 @@ def _start_all_gather_fp8_blockwise(
     # Fall back to high-precision all-gather if FP8 is not supported
     if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
+        warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
+        if isinstance(inp, QuantizedTensorStorage):
+            inp = inp.dequantize(dtype=dtype)  # Dequantize if needed
         out = torch.empty(out_shape, dtype=dtype, device=device)
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
         out = quantizer(out)
```
Review comment on lines 1102 to 1108 (Contributor): Non-contiguous gather input: in the new high-precision fallback (…)
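For reference, here is a minimal sketch (not the PR's code) of how the high-precision fallback above could pick up the `.contiguous()` call this comment asks for. The names `inp`, `quantizer`, `out_shape`, `dtype`, `device`, and `process_group` are taken from the diff; the `hasattr(inp, "dequantize")` check is a stand-in for the `isinstance(inp, QuantizedTensorStorage)` test so the sketch stays self-contained.

```python
# Hedged sketch of the high-precision fallback with the reviewer's suggested
# .contiguous() applied; variable names mirror the diff above.
import warnings

import torch


def high_precision_all_gather_fallback(inp, quantizer, out_shape, dtype, device, process_group):
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if hasattr(inp, "dequantize"):  # stands in for isinstance(inp, QuantizedTensorStorage)
        inp = inp.dequantize(dtype=dtype)
    out = torch.empty(out_shape, dtype=dtype, device=device)
    # Gather a contiguous buffer: transposed or sliced inputs would otherwise
    # trip the collective's contiguity requirement.
    torch.distributed.all_gather_into_tensor(
        out, inp.contiguous(), group=process_group, async_op=False
    )
    # Re-quantize the gathered high-precision tensor with the caller's quantizer.
    return quantizer(out)
```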
```diff
@@ -1115,7 +1118,7 @@ def _start_all_gather_fp8_blockwise(
                 "Input and quantizer do not have matching usages. "
                 "Dequantizing and requantizing to Float8BlockwiseQTensor."
             )
-            inp = quantizer(inp.dequantize())
+            inp = quantizer(inp.dequantize(dtype=dtype))

     # Construct Float8BlockwiseQTensor output tensor
     out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
```
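As an aside, a hedged sketch of the dequantize-then-requantize round trip this one-line change touches. The reading that passing `dtype=dtype` pins the intermediate high-precision tensor to the dtype the surrounding all-gather code expects is an inference from the diff, not something the PR states.

```python
# Hedged sketch of the usage-mismatch path: warn, then round-trip through high
# precision so the quantizer can rebuild the tensor with its own usages.
import warnings


def requantize_if_usages_mismatch(inp, quantizer, dtype):
    warnings.warn(
        "Input and quantizer do not have matching usages. "
        "Dequantizing and requantizing to Float8BlockwiseQTensor."
    )
    # dequantize(dtype=dtype) yields a plain tensor in the requested dtype;
    # quantizer(...) then produces a quantized tensor matching the quantizer.
    return quantizer(inp.dequantize(dtype=dtype))
```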
```diff
@@ -1338,6 +1341,9 @@ def _all_gather_nvfp4(
         and quantizer is not None
         and not quantizer.is_quantizable(inp)
     ):
+        warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
+        if isinstance(inp, QuantizedTensorStorage):
+            inp = inp.dequantize(dtype=dtype)  # Dequantize if needed
         out = torch.empty(
             out_shape,
             dtype=dtype,
```
```diff
@@ -1358,7 +1364,7 @@ def _all_gather_nvfp4(
                 "Input and quantizer do not have matching usages. "
                 "Dequantizing and requantizing to NVFP4."
             )
-            inp = quantizer(inp.dequantize())
+            inp = quantizer(inp.dequantize(dtype=dtype))

     # Construct NVFP4 output tensor
     out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
```
```diff
@@ -1505,6 +1511,9 @@ def _all_gather_mxfp8(
         and quantizer is not None
         and not quantizer.is_quantizable(inp)
     ):
+        warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
+        if isinstance(inp, QuantizedTensorStorage):
+            inp = inp.dequantize(dtype=dtype)  # Dequantize if needed
         out = torch.empty(
             out_shape,
             dtype=dtype,
```
```diff
@@ -1525,7 +1534,7 @@ def _all_gather_mxfp8(
                 "Input and quantizer do not have matching usages. "
                 "Dequantizing and requantizing to MXFP8."
            )
-            inp = quantizer(inp.dequantize())
+            inp = quantizer(inp.dequantize(dtype=dtype))

     # Construct MXFP8 output tensor
     out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
```
Review comment: Missing `.contiguous()` call on `inp` before all-gather. Other all-gather paths in this file use `.contiguous()` (lines 1739, 1033). Non-contiguous tensors (from transpose/slicing) can cause runtime errors.
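Illustrative only (not part of the PR): a small sketch of the failure mode this comment describes and the suggested guard, assuming an already-initialized `torch.distributed` process group (e.g. launched with `torchrun` on an NCCL backend). The helper name `gather_rows` is hypothetical.

```python
# Hypothetical helper: a transposed or sliced tensor is a non-contiguous view,
# so pass a contiguous copy to the gather, as the review comment suggests.
import torch
import torch.distributed as dist


def gather_rows(inp: torch.Tensor, group=None) -> torch.Tensor:
    """All-gather `inp` along dim 0, tolerating non-contiguous input views."""
    world_size = dist.get_world_size(group)
    out = torch.empty(
        (world_size * inp.shape[0], *inp.shape[1:]), dtype=inp.dtype, device=inp.device
    )
    # .contiguous() returns the tensor unchanged if it is already contiguous and
    # copies otherwise, so it is a cheap guard against transpose/slice views.
    dist.all_gather_into_tensor(out, inp.contiguous(), group=group)
    return out


if __name__ == "__main__" and dist.is_available() and dist.is_initialized():
    x = torch.randn(8, 4, device="cuda").t()  # .t() yields a non-contiguous view
    print(gather_rows(x).shape)  # (4 * world_size, 8)
```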