39 changes: 28 additions & 11 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -44,6 +44,7 @@
from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload

from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
from ..quantized_tensor import (
QuantizedTensorStorage,
Quantizer,
@@ -143,7 +144,12 @@ def forward(
inp_view = inp.reshape(-1, in_features)
inputmats: list
if fp8 and not debug:
inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
if isinstance(inp_view, Float8BlockwiseQTensor):
inputmats = inp_view.split_scaling_aware_fp8_transpose(
m_splits, input_quantizers
)
else:
inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
elif debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
inp_view, input_quantizers, m_splits, activation_dtype
@@ -343,18 +349,28 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
# Unfused bias grad and multi-tensor quantize
for i in range(ctx.num_gemms):
grad_biases[i] = grad_output_mats[i].sum(dim=0)
if isinstance(grad_output_view, Float8BlockwiseQTensor):
grad_output = grad_output_view.split_scaling_aware_fp8_transpose(
ctx.m_splits, ctx.grad_output_quantizers
)
else:
grad_output = tex.split_quantize(
grad_output_view,
ctx.m_splits,
ctx.grad_output_quantizers,
)
else:
# Multi-tensor quantize
if isinstance(grad_output_view, Float8BlockwiseQTensor):
grad_output = grad_output_view.split_scaling_aware_fp8_transpose(
ctx.m_splits, ctx.grad_output_quantizers
)
else:
grad_output = tex.split_quantize(
grad_output_view,
ctx.m_splits,
ctx.grad_output_quantizers,
)
else:
# Multi-tensor quantize
grad_output = tex.split_quantize(
grad_output_view,
ctx.m_splits,
ctx.grad_output_quantizers,
)
elif ctx.debug:
grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
for i in range(ctx.num_gemms):
@@ -781,9 +797,10 @@ def forward(
"""
debug = self.is_debug_iter()

assert not isinstance(
inp, QuantizedTensorStorage
), "GroupedLinear doesn't support input tensor in FP8."
if not isinstance(inp, Float8BlockwiseQTensor):
assert not isinstance(
inp, QuantizedTensorStorage
), "GroupedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."

is_grad_enabled = torch.is_grad_enabled()
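Note: the forward and backward hunks above apply the same dispatch. A minimal sketch of the pattern (not part of the diff; the names are the forward-path ones from the changed code):

if isinstance(inp_view, Float8BlockwiseQTensor):
    # Input is already blockwise-quantized FP8: split it per GEMM and build
    # GEMM-ready (rowwise plus transposed/columnwise) chunks directly from the FP8 data.
    inputmats = inp_view.split_scaling_aware_fp8_transpose(m_splits, input_quantizers)
else:
    # High-precision input: quantize each split with its per-GEMM quantizer.
    inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)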
14 changes: 10 additions & 4 deletions transformer_engine/pytorch/permutation.py
@@ -224,7 +224,10 @@ def forward(
fake_dtype = inp.dtype
# blockwise scaling
if blockwise_recipe:
fp8_scale = inp._rowwise_scale_inv.T.contiguous()
if inp._rowwise_data.shape[0] == inp._rowwise_scale_inv.shape[0]:
fp8_scale = inp._rowwise_scale_inv
else:
fp8_scale = inp._rowwise_scale_inv.T.contiguous()
scale_hidden_dim = fp8_scale.shape[1]
assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
inp = inp._rowwise_data
@@ -275,7 +278,7 @@ def forward(
shape=output.shape,
dtype=fake_dtype,
rowwise_data=output,
rowwise_scale_inv=permuted_scale.T.contiguous(),
rowwise_scale_inv=permuted_scale,
columnwise_data=None,
columnwise_scale_inv=None,
fp8_dtype=fp8_dtype,
@@ -423,7 +426,10 @@ def backward(ctx, unpermuted_act_grad):
unpermuted_act_grad = unpermuted_act_grad._data
# blockwise scaling
elif blockwise_recipe:
fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous()
if unpermuted_act_grad._rowwise_data.shape[0] == unpermuted_act_grad._rowwise_scale_inv.shape[0]:
fp8_scale = unpermuted_act_grad._rowwise_scale_inv
else:
fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous()
unpermuted_act_grad = unpermuted_act_grad._rowwise_data
scale_hidden_dim = fp8_scale.shape[1]
assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
@@ -485,7 +491,7 @@ def backward(ctx, unpermuted_act_grad):
shape=act_grad.shape,
dtype=fake_dtype,
rowwise_data=act_grad,
rowwise_scale_inv=permuted_scale.T.contiguous(),
rowwise_scale_inv=permuted_scale,
columnwise_data=None,
columnwise_scale_inv=None,
fp8_dtype=fp8_dtype,
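All four permutation.py hunks above make the same adjustment: the blockwise scale may now arrive either with one row per token (matching the rowwise FP8 data) or transposed, and only the transposed case still needs the .T.contiguous() fix-up; the permuted output scale is likewise stored without transposing. A hedged sketch of the check as a standalone helper (the helper name is hypothetical; the attributes are those used in the diff):

def rowwise_scale_for_permute(t):
    # The permute/unpermute path expects the scale with the same leading
    # (token) dimension as the rowwise FP8 data.
    if t._rowwise_data.shape[0] == t._rowwise_scale_inv.shape[0]:
        # Scale is already token-major: use it as-is.
        return t._rowwise_scale_inv
    # Scale is stored transposed: flip it back and make it contiguous.
    return t._rowwise_scale_inv.T.contiguous()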
91 changes: 91 additions & 0 deletions transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
@@ -17,6 +17,9 @@
from ..quantized_tensor import QuantizedTensor, Quantizer
from ._quantization_helpers import _IdentityFunc
from ..utils import devices_match, round_up_to_nearest_multiple
from ..triton.blockwise_scaling_aware_fp8_transpose import (
blockwise_scaling_aware_fp8_transpose,
)

aten = torch.ops.aten

@@ -437,6 +440,94 @@ def untyped_storage(self) -> torch.UntypedStorage:
return data.untyped_storage()
return torch.UntypedStorage(0, device=self.device)

def split_scaling_aware_fp8_transpose(self, m_splits, quantizers):
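"""Split a COMPACT blockwise-FP8 tensor along its first (row) dimension.

The rowwise data is split into len(m_splits) chunks; each non-empty chunk is
passed through the Triton-based blockwise_scaling_aware_fp8_transpose helper to
produce GEMM-ready rowwise and columnwise (transposed) data with matching
scale inverses. Zero-row splits produce empty tensors, so the returned list
always contains one Float8BlockwiseQTensorStorage per entry of m_splits.
"""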
assert (
self._rowwise_data is not None and self._rowwise_scale_inv is not None
), "split_transpose_quantize only supports rowwise_data inputs."

# Temporary solution for the performance FP8 flow: tensors tagged as GEMM-ready
# whose scale rows already match the data rows are re-tagged as COMPACT.
assert (
self._rowwise_scale_inv.shape[0] == self._rowwise_data.shape[0]
), "rowwise_data and rowwise_scale_inv must have same M (rows)."
if (
self._is_gemm_ready_format()
and self._rowwise_data.shape[0] == self._rowwise_scale_inv.shape[0]
):
self._data_format = tex.Float8BlockScaleTensorFormat.COMPACT
assert (
not self._is_gemm_ready_format()
), "Only COMPACT input format is supported."

rowwise_usage = quantizers[0].rowwise_usage
device = self._rowwise_data.device
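# Zero-row splits are skipped here and filled in with empty tensors below.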
kept = [i for i, m in enumerate(m_splits) if m > 0]
m_splits_kept = [m_splits[i] for i in kept]

if len(m_splits_kept) > 0:
(
rowwise_data_list,
rowwise_scale_inv_t_list,
columnwise_data_list,
columnwise_scale_inv_list,
) = blockwise_scaling_aware_fp8_transpose(
self._rowwise_data, self._rowwise_scale_inv, m_splits_kept
)

if len(m_splits_kept) != len(m_splits):
K = self._rowwise_data.shape[1]
empty_rw_data = (
torch.empty((0, K), dtype=self._rowwise_data.dtype, device=device)
if rowwise_usage
else None
)
empty_rw_si_t = (
torch.empty(
(self._rowwise_scale_inv.shape[1], 0),
dtype=self._rowwise_scale_inv.dtype,
device=device,
)
if rowwise_usage
else None
)
empty_cw_data = torch.empty(
(K, 0), dtype=self._rowwise_data.dtype, device=device
)
empty_cw_si = torch.empty(
(0, K), dtype=self._rowwise_scale_inv.dtype, device=device
)

results = []
kept_idx = 0
for i, m in enumerate(m_splits):
if m == 0:
rowwise_data = empty_rw_data
rowwise_scale_inv_t = empty_rw_si_t
columnwise_data = empty_cw_data
columnwise_scale_inv = empty_cw_si
else:
rowwise_data = rowwise_data_list[kept_idx] if rowwise_usage else None
rowwise_scale_inv_t = (
rowwise_scale_inv_t_list[kept_idx] if rowwise_usage else None
)
columnwise_data = columnwise_data_list[kept_idx]
columnwise_scale_inv = columnwise_scale_inv_list[kept_idx]
kept_idx += 1

results.append(
Float8BlockwiseQTensorStorage(
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv_t,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=self._fp8_dtype,
quantizer=quantizers[i],
is_2D_scaled=self._is_2D_scaled,
data_format=tex.Float8BlockScaleTensorFormat.GEMM_READY,
)
)

return results

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):

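A minimal usage sketch of the new method, assuming x is a Float8BlockwiseQTensor in COMPACT layout with sum(m_splits) rows and quantizers is a list with one blockwise quantizer per split (variable names are illustrative):

m_splits = [1024, 0, 2048]
# One Float8BlockwiseQTensorStorage per split, already tagged GEMM_READY;
# zero-row splits come back as empty tensors instead of being dropped.
chunks = x.split_scaling_aware_fp8_transpose(m_splits, quantizers)
assert len(chunks) == len(m_splits)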