diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index c8bd9d47c31..897d9f683e2 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1290,6 +1290,59 @@ def test_quantize_dact_dbias_mxfp8_scaling( ) +class TestQuantizeWithVmap: + """Test vmap support for quantization primitives.""" + + @pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16]) + @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) + @pytest_parametrize_wrapper("q_layout", [QuantizeLayout.ROWWISE]) + def test_vmap_quantize(self, in_dtype, scaling_mode, q_layout): + """Test that vmap works with tex.quantize using the general batcher.""" + # Determine q_dtype based on scaling mode + if scaling_mode.is_nvfp4_scaling: + q_dtype = jnp.float4_e2m1fn + else: + q_dtype = jnp.float8_e4m3fn + + # Create batched input (E, M, K) - E experts + E, M, K = 4, 64, 128 + key = jax.random.PRNGKey(0) + batched_input = jax.random.uniform(key, (E, M, K), in_dtype) + + # Create per-expert quantizers + quantizers = [ + QuantizerFactory.create( + q_dtype=q_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + ) + for _ in range(E) + ] + + # Stack quantizers for vmap + stacked_quantizers = jax.tree_util.tree_map(lambda *args: jnp.stack(args), *quantizers) + + # Vmap over expert dimension + def quantize_single(x, quantizer): + return tex.quantize(x, quantizer=quantizer, flatten_axis=-1) + + vmapped_quantize = jax.vmap(quantize_single, in_axes=(0, 0)) + result = vmapped_quantize(batched_input, stacked_quantizers) + + # Verify shapes + assert result.data.shape == (E, M, K) + assert result.scale_inv.shape[0] == E # Per-expert scales + + # Compare with calling quantize for each expert individually + individual_results = [] + for i in range(E): + res_i = tex.quantize(batched_input[i], quantizer=quantizers[i], flatten_axis=-1) + individual_results.append(res_i.data) + + expected = jnp.stack(individual_results, axis=0) + assert_allclose(result.data, expected, dtype=quantizers[0].q_dtype) + + valid_fp8_gemm_operand_types = [ (jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e5m2, jnp.float8_e4m3fn), diff --git a/tests/jax/test_einsum.py b/tests/jax/test_einsum.py new file mode 100644 index 00000000000..7580a14638f --- /dev/null +++ b/tests/jax/test_einsum.py @@ -0,0 +1,221 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Tests for TE einsum operation with FP8 quantization.""" + +import jax +import jax.numpy as jnp +import pytest +from jax import value_and_grad + +from utils import assert_allclose, pytest_parametrize_wrapper +from transformer_engine.jax.einsum import einsum +from transformer_engine.jax.quantize import ( + QuantizerFactory, + QuantizeMeta, + QuantizeMetaSet, +) +from transformer_engine.jax.quantize import helper + + +# Test parameters +DTYPES = [jnp.bfloat16] +# (B, S, M, E, C, H) +# B: Batch size +# S: Sequence length (number of tokens) +# M: Model dimension (hidden size) +# E: Number of experts +# C: Capacity (max tokens per expert) +# H: Hidden dimension (MLP intermediate size) +MOE_CASES = [ + (2, 32, 128, 4, 32, 64), +] + +# Get supported recipes +supported_recipes = helper.get_supported_quantization_recipes() +supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes] + + +@pytest.fixture(autouse=True, scope="module") +def init(): + """WAR for CUDA uninitialize error""" + # Calling customcalls before jax may cause CUDA uninitialize error + _ = jnp.zeros(0) + yield + + +class TestMoEMLPWithRecipes: + """Test MoE MLP operations with different FP8 recipes and gradients.""" + + def _get_quantizer_sets(self, recipe, num_experts): + return QuantizerFactory.create_set( + n_quantizer_sets=num_experts, + fp8_recipe=recipe, + quantize_meta_set=QuantizeMetaSet( + x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() + ), + ) + + def _einsum(self, equation, *operands, quantizer_sets=None, quantizer_dim=None, fallback=False): + out = einsum( + equation, + *operands, + quantizer_sets=quantizer_sets, + quantizer_dim=quantizer_dim, + fallback=fallback, + ) + return jnp.mean(out) + + def _ref_einsum(self, equation, *operands): + out = jnp.einsum(equation, *operands) + return jnp.mean(out) + + @pytest_parametrize_wrapper("B,S,M,E,C,H", MOE_CASES) + @pytest_parametrize_wrapper("recipe", supported_recipes) + def test_mlp_up_grad(self, B, S, M, E, C, H, recipe): + """Test MLP up: EBCM,EMH->EBCH with gradients and different recipes.""" + # Create per-expert quantizers + quantizer_sets = self._get_quantizer_sets(recipe, E) + dispatched = jax.random.normal( + jax.random.PRNGKey(0), (E, B, C, M), dtype=jnp.bfloat16 + ) / jnp.sqrt(M) + weights = jax.random.normal(jax.random.PRNGKey(1), (E, M, H), dtype=jnp.bfloat16) + + # Compute with TE einsum with quantization + loss_te, grads_te = value_and_grad(self._einsum, argnums=(1, 2))( + "EBCM,EMH->EBCH", dispatched, weights, quantizer_sets=quantizer_sets, quantizer_dim="E" + ) + + # Compute reference (BF16) + loss_ref, grads_ref = value_and_grad(self._ref_einsum, argnums=(1, 2))( + "EBCM,EMH->EBCH", dispatched, weights + ) + + # Verify shapes and no NaNs + assert grads_te[0].shape == dispatched.shape + assert grads_te[1].shape == weights.shape + assert not jnp.isnan(loss_te) + assert jnp.all(jnp.isfinite(grads_te[0])) + assert jnp.all(jnp.isfinite(grads_te[1])) + + # Compare with reference (with FP8 tolerance) + assert_allclose(loss_te, loss_ref, dtype=quantizer_sets[0].x.q_dtype) + assert_allclose(grads_te[0], grads_ref[0], dtype=quantizer_sets[0].dgrad.q_dtype) + assert_allclose(grads_te[1], grads_ref[1], dtype=quantizer_sets[0].dgrad.q_dtype) + + @pytest_parametrize_wrapper("B,S,M,E,C,H", MOE_CASES) + @pytest_parametrize_wrapper("recipe", supported_recipes) + def test_mlp_down_grad(self, B, S, M, E, C, H, recipe): + """Test MLP down: EBCH,EHM->EBCM with gradients and different recipes.""" + # Create per-expert quantizers + quantizer_sets = self._get_quantizer_sets(recipe, E) + + hidden = jax.random.normal( + jax.random.PRNGKey(0), (E, B, C, H), dtype=jnp.bfloat16 + ) / jnp.sqrt(H) + weights = jax.random.normal(jax.random.PRNGKey(1), (E, H, M), dtype=jnp.bfloat16) + + # Compute with TE einsum with quantization + loss_te, grads_te = value_and_grad(self._einsum, argnums=(1, 2))( + "EBCH,EHM->EBCM", hidden, weights, quantizer_sets=quantizer_sets, quantizer_dim="E" + ) + + # Compute reference (BF16) + loss_ref, grads_ref = value_and_grad(self._ref_einsum, argnums=(1, 2))( + "EBCH,EHM->EBCM", hidden, weights + ) + + # Verify shapes and no NaNs + assert grads_te[0].shape == hidden.shape + assert grads_te[1].shape == weights.shape + assert not jnp.isnan(loss_te) + assert jnp.all(jnp.isfinite(grads_te[0])) + assert jnp.all(jnp.isfinite(grads_te[1])) + + # Compare with reference (with FP8 tolerance) + assert_allclose(loss_te, loss_ref, dtype=quantizer_sets[0].x.q_dtype) + assert_allclose(grads_te[0], grads_ref[0], dtype=quantizer_sets[0].dgrad.q_dtype) + assert_allclose(grads_te[1], grads_ref[1], dtype=quantizer_sets[0].dgrad.q_dtype) + + @pytest_parametrize_wrapper("B,S,M,E,C,H", MOE_CASES) + @pytest_parametrize_wrapper("recipe", supported_recipes) + def test_full_moe_grad(self, B, S, M, E, C, H, recipe): + """Test full MoE pipeline (all 4 einsums) with gradients and different recipes.""" + # Create per-expert quantizers for each einsum + mlp_up_quantizer_sets = self._get_quantizer_sets(recipe, E) + mlp_down_quantizer_sets = self._get_quantizer_sets(recipe, E) + + tokens = jax.random.normal(jax.random.PRNGKey(0), (B, S, M), dtype=jnp.bfloat16) / jnp.sqrt( + M + ) + routing = jax.random.normal(jax.random.PRNGKey(1), (B, S, E, C), dtype=jnp.bfloat16) + routing = jax.nn.softmax(routing, axis=-1) # Normalize routing weights + up_weights = jax.random.normal( + jax.random.PRNGKey(2), (E, M, H), dtype=jnp.bfloat16 + ) / jnp.sqrt(H) + down_weights = jax.random.normal( + jax.random.PRNGKey(3), (E, H, M), dtype=jnp.bfloat16 + ) / jnp.sqrt(M) + + # TE implementation with quantization + def full_moe_te(tokens, routing, up_w, down_w): + """Complete MoE pipeline with TE einsum.""" + dispatched = einsum("BSM,BSEC->EBCM", tokens, routing, fallback=True) + hidden = einsum( + "EBCM,EMH->EBCH", + dispatched, + up_w, + quantizer_sets=mlp_up_quantizer_sets, + quantizer_dim="E", + ) + expert_out = einsum( + "EBCH,EHM->EBCM", + hidden, + down_w, + quantizer_sets=mlp_down_quantizer_sets, + quantizer_dim="E", + ) + output = einsum("EBCM,BSEC->BSM", expert_out, routing, fallback=True) + return jnp.sum(output) + + # Reference implementation with jnp.einsum + def full_moe_ref(tokens, routing, up_w, down_w): + """Complete MoE pipeline with jnp.einsum.""" + dispatched = jnp.einsum("BSM,BSEC->EBCM", tokens, routing) + hidden = jnp.einsum("EBCM,EMH->EBCH", dispatched, up_w) + expert_out = jnp.einsum("EBCH,EHM->EBCM", hidden, down_w) + output = jnp.einsum("EBCM,BSEC->BSM", expert_out, routing) + return jnp.sum(output) + + loss_te, grads_te = value_and_grad(full_moe_te, argnums=(0, 1, 2, 3))( + tokens, routing, up_weights, down_weights + ) + + loss_ref, grads_ref = value_and_grad(full_moe_ref, argnums=(0, 1, 2, 3))( + tokens, routing, up_weights, down_weights + ) + + # Verify all gradient shapes + assert grads_te[0].shape == tokens.shape, f"tokens grad shape mismatch" + assert grads_te[1].shape == routing.shape, f"routing grad shape mismatch" + assert grads_te[2].shape == up_weights.shape, f"up_weights grad shape mismatch" + assert grads_te[3].shape == down_weights.shape, f"down_weights grad shape mismatch" + + # Verify no NaNs or Infs + assert not jnp.isnan(loss_te), "Loss is NaN" + assert jnp.isfinite(loss_te), "Loss is Inf" + assert jnp.all(jnp.isfinite(grads_te[0])), "tokens grad has NaN/Inf" + assert jnp.all(jnp.isfinite(grads_te[1])), "routing grad has NaN/Inf" + assert jnp.all(jnp.isfinite(grads_te[2])), "up_weights grad has NaN/Inf" + assert jnp.all(jnp.isfinite(grads_te[3])), "down_weights grad has NaN/Inf" + + # Compare with reference (with FP8 tolerance) + assert_allclose(loss_te, loss_ref, dtype=mlp_up_quantizer_sets[0].x.q_dtype) + assert_allclose(grads_te[0], grads_ref[0], dtype=mlp_up_quantizer_sets[0].dgrad.q_dtype) + assert_allclose(grads_te[1], grads_ref[1], dtype=mlp_up_quantizer_sets[0].dgrad.q_dtype) + assert_allclose(grads_te[2], grads_ref[2], dtype=mlp_down_quantizer_sets[0].x.q_dtype) + assert_allclose(grads_te[3], grads_ref[3], dtype=mlp_down_quantizer_sets[0].dgrad.q_dtype) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/transformer_engine/jax/cpp_extensions/amax.py b/transformer_engine/jax/cpp_extensions/amax.py index afc248a0ad4..8d166717ba6 100644 --- a/transformer_engine/jax/cpp_extensions/amax.py +++ b/transformer_engine/jax/cpp_extensions/amax.py @@ -160,6 +160,18 @@ def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types output_spec = (f"{prefix}_amax",) return SdyShardingRule((input_spec,), (output_spec,)) + @staticmethod + def batcher(batched_args, batch_dims, *, amax_scope, transpose_batch_sequence): + """Batcher for amax calculation - returns single amax value.""" + return AmaxCalculationPrimitive.batcher_impl( + batched_args, + batch_dims, + static_kwargs={ + "amax_scope": amax_scope, + "transpose_batch_sequence": transpose_batch_sequence, + }, + ) + register_primitive(AmaxCalculationPrimitive, outer_only=True) @@ -370,6 +382,30 @@ def shardy_sharding_rule( output_post_rht_amax_spec = (f"{prefix}_post_rht_amax",) return SdyShardingRule((input_spec,), (output_amax_spec, output_post_rht_amax_spec)) + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + ): + """Batcher for RHT amax calculation - returns 2 amax values.""" + return RHTAmaxCalculationPrimitive.batcher_impl( + batched_args, + batch_dims, + static_kwargs={ + "amax_scope": amax_scope, + "transpose_batch_sequence": transpose_batch_sequence, + "rht_matrix_random_sign_mask_t": rht_matrix_random_sign_mask_t, + "produce_regular_amax": produce_regular_amax, + "flatten_axis": flatten_axis, + }, + ) + register_primitive(RHTAmaxCalculationPrimitive) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 61deab5b804..22a4b7dda40 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -7,13 +7,14 @@ import warnings from abc import ABCMeta, abstractmethod from functools import partial +from typing import Any, Sequence, Union, Tuple from jax.extend import core from jax.interpreters import xla, mlir from jax.experimental.custom_partitioning import custom_partitioning from jax._src.interpreters import batching from jax._src import dispatch -from jax import ffi +from jax import ffi, numpy as jnp import transformer_engine_jax @@ -168,6 +169,92 @@ def shardy_sharding_rule(*args): del args return "... -> ..." + @classmethod + def batcher_impl( + cls, + batched_args: Sequence[Any], + batch_dims: Sequence[Union[int, None]], + static_kwargs: dict, + ) -> Tuple[Tuple[Any, ...], Tuple[Union[int, None], ...]]: + """Batcher implementation for JAX primitives. + + Implements the standard batching pattern: loop over batch dimension, + call primitive for each slice, and stack results. + + Args: + batched_args: Tuple of input tensors (some may be batched) + batch_dims: Tuple indicating batch dimension for each arg (None if not batched) + static_kwargs: Dictionary of static arguments to pass to primitive.bind() + + Returns: + Tuple of (output_tensors, output_batch_dims) + + Example: + @staticmethod + def batcher(batched_args, batch_dims, *, arg1, arg2, arg3): + return MyPrimitive.batcher_impl( + batched_args, batch_dims, + static_kwargs={'arg1': arg1, 'arg2': arg2, 'arg3': arg3}, + ) + """ + from jax import lax + + # Find batch dimension and validate all batched args have the same batch_dim + batch_dim = None + batch_size = None + for arg, bdim in zip(batched_args, batch_dims): + if bdim is not None: + if batch_dim is None: + batch_dim = bdim + batch_size = arg.shape[bdim] + elif bdim != batch_dim: + raise ValueError( + "All batched arguments must have the same batch dimension. " + f"Got batch_dims={batch_dims}" + ) + assert batch_dim is not None and batch_size is not None, "Invalid batching config!" + + # Loop over batch dimension and collect results + all_results = [] + + for i in range(batch_size): + # Extract slice for each argument + sliced_args = [] + for arg, bdim in zip(batched_args, batch_dims): + if bdim is not None: + slice_i = lax.index_in_dim(arg, i, bdim, keepdims=False) + sliced_args.append(slice_i) + else: # For empty args + sliced_args.append(arg) + + # Call primitive with unbatched slices + result_i = cls.outer_primitive.bind(*sliced_args, **static_kwargs) + + # Normalize to tuple + if not isinstance(result_i, (tuple, list)): + result_i = (result_i,) + elif isinstance(result_i, list): + result_i = tuple(result_i) + + all_results.append(result_i) + + # Transpose: from list of tuples to tuple of lists + # all_results = [(out0_0, out1_0, ...), (out0_1, out1_1, ...), ...] + # transposed = ([out0_0, out0_1, ...], [out1_0, out1_1, ...], ...) + transposed = tuple(zip(*all_results)) + + # Stack each output along the batch dimension + stacked_results = tuple( + jnp.stack(list(out_list), axis=batch_dim) for out_list in transposed + ) + + # Single output: return unwrapped result + if len(stacked_results) == 1: + return stacked_results[0], batch_dim + + # Multiple outputs: return tuple of results + return stacked_results, [batch_dim for _ in stacked_results] + # Registry to store all registered primitive classes _primitive_registry = {} diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 76a8b225ba8..55a17008380 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -808,40 +808,33 @@ def batcher( sequence_dim, is_outer, ): - del transpose_batch_sequence, sequence_dim, is_outer assert GemmPrimitive.outer_primitive is not None lhs_bdims, _, rhs_bdims, *_ = batch_dims - # Batched GEMM is not supported - assert ( - lhs_bdims is None and rhs_bdims is None - ), f"(Batching is not supported, got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims})" - out_bdims = (None,) - - # Bias gradient is never batched - bias_bdims = (None,) - - # Pre-GeLU output, if exists, is batched like GEMM output - pre_gelu_bdims = (None,) - if fuse_gelu and not grad: - pre_gelu_bdims = out_bdims + # Validate batch dimensions + if lhs_bdims is not None or rhs_bdims is not None: + assert lhs_bdims == rhs_bdims, ( + "Batched GEMM requires matching batch dimensions, " + f"got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims}" + ) - return ( - GemmPrimitive.outer_primitive.bind( - *batched_args, - out_dtype=out_dtype, - contracting_dims=contracting_dims, - scaling_mode=scaling_mode, - fuse_bias=fuse_bias, - fuse_gelu=fuse_gelu, - grad=grad, - use_split_accumulator=use_split_accumulator, - collective_op=collective_op, - transpose_batch_sequence=transpose_batch_sequence, - sequence_dim=sequence_dim, - is_outer=is_outer, - ), - (out_bdims, bias_bdims, pre_gelu_bdims), + # Use general batcher from BasePrimitive + return GemmPrimitive.batcher_impl( + batched_args, + batch_dims, + static_kwargs={ + "out_dtype": out_dtype, + "contracting_dims": contracting_dims, + "scaling_mode": scaling_mode, + "fuse_bias": fuse_bias, + "fuse_gelu": fuse_gelu, + "grad": grad, + "use_split_accumulator": use_split_accumulator, + "collective_op": collective_op, + "transpose_batch_sequence": transpose_batch_sequence, + "sequence_dim": sequence_dim, + "is_outer": is_outer, + }, ) @staticmethod diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index b3f24e9337e..53c6937fb4a 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -361,34 +361,24 @@ def batcher( stochastic_rounding, use_rht, ): - """ - to describe batch rules for vmap - """ - del is_outer + """Batch rule for quantization primitive using general batcher.""" check_valid_batch_dims(batch_dims) assert BaseDBiasQuantizePrimitive.outer_primitive is not None - x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix = batched_args - x_bdim, scale_bdim, amax_bdim, _, _, _ = batch_dims - out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim - return ( - BaseDBiasQuantizePrimitive.outer_primitive.bind( - x, - scale, - amax, - sr_rng_state, - post_rht_amax, - rht_matrix, - out_dtype=out_dtype, - scaling_mode=scaling_mode, - q_layout=q_layout, - flatten_axis=flatten_axis, - scale_dtype=scale_dtype, - is_dbias=is_dbias, - stochastic_rounding=stochastic_rounding, - use_rht=use_rht, - ), - out_bdims, + return BaseDBiasQuantizePrimitive.batcher_impl( + batched_args, + batch_dims, + static_kwargs={ + "out_dtype": out_dtype, + "scaling_mode": scaling_mode, + "q_layout": q_layout, + "flatten_axis": flatten_axis, + "scale_dtype": scale_dtype, + "is_dbias": is_dbias, + "is_outer": is_outer, + "stochastic_rounding": stochastic_rounding, + "use_rht": use_rht, + }, ) @staticmethod diff --git a/transformer_engine/jax/einsum.py b/transformer_engine/jax/einsum.py new file mode 100644 index 00000000000..20084c77ea1 --- /dev/null +++ b/transformer_engine/jax/einsum.py @@ -0,0 +1,424 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Einsum operation with FP8 quantization support for Transformer Engine in JAX. + +This module provides an einsum implementation that decomposes einsum operations into +a sequence of GEMMs, each with its own quantizer for FP8 support. It follows the +pattern of jax.numpy.einsum but uses TE's optimized GEMM operations. + +This module provides an einsum implementation optimized for Mixture-of-Experts (MoE) +models with per-expert quantization support. It leverages JAX's vmap and TE's dense +layer to efficiently handle tensor contractions with a single batch dimension. + +Key Features: + - **Per-expert quantization**: Each expert can have independent scaling and quantization parameters + - **Automatic differentiation**: Full gradient support via dense layer's VJP + - **Single batch dimension**: Optimized for MoE patterns (expert dimension) + - **Explicit API**: Requires quantizer_dim when using quantization + +Limitations: + - **NN layout only**: LHS last dim must contract, RHS last dim must not contract + - **Single batch dimension**: Only one batch dimension supported + - **2-operand only**: Only supports binary operations + - **Explicit quantizer_dim**: Required when quantizer_sets is provided + + For operations that don't meet these requirements (e.g., routing operations + like "BSM,BSEC->EBCM"), use jnp.einsum instead, or set fallback=True to + automatically fall back to jnp.einsum when the operation is not supported. + +Example - MoE Forward Pass with Per-Expert FP8: + ```python + from transformer_engine.jax.einsum import einsum + from transformer_engine.jax.quantize import QuantizerFactory, QuantizeMeta, QuantizeMetaSet + + # Create per-expert quantizers (E experts) + quantizer_sets = [ + QuantizerFactory.create_set( + fp8_recipe=recipe, + quantize_meta_set=QuantizeMetaSet( + x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() + ) + ) for _ in range(num_experts) + ] + + # MoE pipeline with per-expert quantization, + # 1. Dispatch: BSM,BSEC -> EBCM (no quantization - routing operation) + dispatched = jnp.einsum("BSM,BSEC->EBCM", tokens, routing) + # Or with fallback: + # dispatched = einsum("BSM,BSEC->EBCM", tokens, routing, fallback=True) + + # 2. MLP Up: EBCM,EMH -> EBCH (per-expert quantization) + hidden = einsum("EBCM,EMH->EBCH", dispatched, expert_up_weights, + quantizer_sets=expert_quantizers, quantizer_dim='E') + + # 3. MLP Down: EBCH,EHM -> EBCM (per-expert quantization) + expert_out = einsum("EBCH,EHM->EBCM", hidden, expert_down_weights, + quantizer_sets=expert_quantizers, quantizer_dim='E') + + # 4. Combine: EBCM,BSEC -> BSM (no quantization - routing operation) + output = jnp.einsum("EBCM,BSEC->BSM", expert_out, routing) + # Or with fallback: + # output = einsum("EBCM,BSEC->BSM", expert_out, routing, fallback=True) + ``` + +Implementation Details: + The einsum function works by: + 1. Parsing the einsum equation to identify the single batch dimension and contracting dimensions + 2. Validating that quantizer_sets length matches the quantizer dimension size + 3. Creating a vmapped version of TE's dense layer over the batch dimension + 4. Vmapping over quantizer_sets to provide per-batch (e.g., per-expert) quantization + 5. Leveraging dense's existing VJP for automatic differentiation + + This design reuses TE's well-tested dense layer infrastructure while enabling + per-expert quantization for MoE models with minimal code complexity. +""" + +from typing import Tuple, Optional, List +import jax +import jax.numpy as jnp + +from .dense import dense +from .quantize import ( + QuantizerSet, + noop_quantizer_set, +) + + +def _parse_einsum_input(equation: str, *operands) -> Tuple[str, List[str], str]: + """Parse einsum equation into input specs and output spec. + + Args: + equation: Einsum equation string (e.g., "ij,jk->ik" or "BNSM,BNSEC->EBNCM") + operands: Input tensors + + Returns: + Tuple of (equation, input_specs, output_spec) + + Raises: + ValueError: If number of operands doesn't match equation + """ + # Remove spaces + equation = equation.replace(" ", "") + + if "->" in equation: + inputs_str, output_str = equation.split("->") + input_specs = inputs_str.split(",") + else: + # Implicit output mode + inputs_str = equation + input_specs = inputs_str.split(",") + # Compute implicit output + all_indices = set() + for spec in input_specs: + all_indices.update(spec) + output_str = "".join(sorted(all_indices)) + + # Validate each operand's ndim matches its spec + for i, (operand, spec) in enumerate(zip(operands, input_specs)): + expected_ndim = len(spec) + actual_ndim = operand.ndim + if actual_ndim != expected_ndim: + raise ValueError( + f"Operand {i} has {actual_ndim} dimensions but equation '{equation}' " + f"expects {expected_ndim} dimensions (spec: '{spec}'). " + f"Operand shape: {operand.shape}" + ) + + return equation, input_specs, output_str + + +def _find_contracting_and_batch_dims(lhs_spec: str, rhs_spec: str, output_spec: str): + """Find contracting and batch dimensions for a GEMM operation. + + Args: + lhs_spec: Index specification for LHS (e.g., "BNSM") + rhs_spec: Index specification for RHS (e.g., "BNSEC") + output_spec: Index specification for output (e.g., "EBNCM") + + Returns: + Tuple of (lhs_contracting, rhs_contracting, lhs_batch, rhs_batch) + """ + # Contracting dimensions: indices in both lhs and rhs but not in output + lhs_set = set(lhs_spec) + rhs_set = set(rhs_spec) + output_set = set(output_spec) + + contracting_indices = (lhs_set & rhs_set) - output_set + + # Batch dimensions: indices in lhs, rhs, and output + batch_indices = lhs_set & rhs_set & output_set + + # Find positions + lhs_contracting = tuple(i for i, c in enumerate(lhs_spec) if c in contracting_indices) + rhs_contracting = tuple(i for i, c in enumerate(rhs_spec) if c in contracting_indices) + lhs_batch = tuple(i for i, c in enumerate(lhs_spec) if c in batch_indices) + rhs_batch = tuple(i for i, c in enumerate(rhs_spec) if c in batch_indices) + + return lhs_contracting, rhs_contracting, lhs_batch, rhs_batch + + +def _einsum_to_gemm_info(equation: str, *operands): + """Extract GEMM information from einsum equation. + + Args: + equation: Einsum equation + operands: Input tensors + + Returns: + Dict with keys: lhs_idx, rhs_idx, contracting_dims, batch_dims, output_spec + """ + equation, input_specs, output_spec = _parse_einsum_input(equation, *operands) + + if len(input_specs) != 2: + raise NotImplementedError(f"Einsum with {len(input_specs)} operands not yet supported") + + lhs_spec, rhs_spec = input_specs + + lhs_contracting, rhs_contracting, lhs_batch, rhs_batch = _find_contracting_and_batch_dims( + lhs_spec, rhs_spec, output_spec + ) + + return { + "lhs_idx": 0, + "rhs_idx": 1, + "lhs_spec": lhs_spec, + "rhs_spec": rhs_spec, + "output_spec": output_spec, + "contracting_dims": (lhs_contracting, rhs_contracting), + "batch_dims": (lhs_batch, rhs_batch), + } + + +def einsum( + equation: str, + *operands: jnp.ndarray, + quantizer_sets: Optional[List[QuantizerSet]] = None, + quantizer_dim: Optional[str] = None, + operand_axes: Optional[List[Tuple[str, ...]]] = None, + output_axes: Optional[Tuple[str, ...]] = None, + fallback: bool = False, +) -> jnp.ndarray: + """Perform einsum operation with optional FP8 quantization using vmap + dense. + + This function implements einsum by: + 1. Identifying batch dimensions + 2. Using vmap to vectorize over batch dimensions + 3. Calling the existing dense() function which has VJP already implemented + + Each batched GEMM can have its own quantizer_set, enabling per-expert + quantization in MoE models. + + Args: + equation: Einsum equation string (e.g., "ij,jk->ik", "BSM,BSEC->EBCM") + *operands: Input tensors + quantizer_sets: List or tuple of QuantizerSets. Length must match the size of + the dimension specified by quantizer_dim. If None, creates noop quantizers. + quantizer_dim: Index label indicating which dimension the quantizers correspond to. + For MoE, this is typically 'E' (expert dimension). If None and + quantizer_sets is provided, assumes first batch dimension at position 0. + operand_axes: List of logical axes tuples for sharding each operand + output_axes: Logical axes for sharding the output + fallback: Whether to fallback to jnp.einsum if the einsum operation is not supported. + When fallback=True, unsupported operations (e.g., non-NN layouts, routing + operations) will use jnp.einsum. Note: quantization will NOT be applied + when falling back. + + Returns: + Result of the einsum operation + + Examples: + # Simple matrix multiplication with FP8 + result = einsum("ij,jk->ik", A, B, quantizer_sets=my_quantizer_set) + + # MoE with per-expert quantizers (E experts) + expert_quantizers = [quantizer_e0, quantizer_e1, ..., quantizer_eN] + result = einsum("EBNCM,EMH->EBNCH", tokens, weights, + quantizer_sets=expert_quantizers) + + # With fallback for routing operations + result = einsum("BSM,BSEC->EBCM", tokens, routing, fallback=True) + # Falls back to jnp.einsum (no quantization) + """ + if operand_axes is None: + operand_axes = [None] * len(operands) + + if len(operands) != 2: + if fallback: + import warnings + + warnings.warn( + f"TE einsum only supports 2-operand einsum, got {len(operands)} operands. " + "Falling back to jnp.einsum (no quantization will be applied).", + stacklevel=2, + ) + return jnp.einsum(equation, *operands) + raise NotImplementedError("Only 2-operand einsum currently supported") + + # Parse einsum to get GEMM info + gemm_info = _einsum_to_gemm_info(equation, *operands) + contracting_dims = gemm_info["contracting_dims"] + batch_dims = gemm_info["batch_dims"] + lhs_spec = gemm_info["lhs_spec"] + rhs_spec = gemm_info["rhs_spec"] + + lhs, rhs = operands + + # Validate quantizer_dim is provided when quantizer_sets is given + if quantizer_sets is not None and quantizer_dim is None: + raise ValueError( + "quantizer_dim must be specified when quantizer_sets is provided. " + "This explicitly indicates which dimension the quantizers correspond to." + ) + + # Find quantizer dimension + quantizer_dim_lhs = None + quantizer_dim_rhs = None + + if quantizer_dim is not None: + # Find position of quantizer_dim in lhs and rhs specs + if quantizer_dim in lhs_spec: + quantizer_dim_lhs = lhs_spec.index(quantizer_dim) + if quantizer_dim in rhs_spec: + quantizer_dim_rhs = rhs_spec.index(quantizer_dim) + + if quantizer_dim_lhs is None and quantizer_dim_rhs is None: + raise ValueError(f"quantizer_dim '{quantizer_dim}' not found in equation '{equation}'") + + # Check if we have batch dimensions + has_batch_dims = bool(batch_dims[0] or batch_dims[1]) + + # Determine expected quantizer_sets length based on quantizer_dim + if quantizer_dim is not None: + if quantizer_dim_lhs is not None: + expected_length = lhs.shape[quantizer_dim_lhs] + else: + expected_length = rhs.shape[quantizer_dim_rhs] + else: + # No quantizer_dim: determine from batch dimension + if has_batch_dims: + expected_length = lhs.shape[batch_dims[0][0]] + else: + expected_length = 1 + + # Validate and initialize quantizer_sets + if quantizer_sets is None: + quantizer_sets = [noop_quantizer_set] * expected_length + elif not isinstance(quantizer_sets, (list, tuple)): + raise TypeError(f"quantizer_sets must be a list or tuple, got {type(quantizer_sets)}") + elif len(quantizer_sets) != expected_length: + raise ValueError( + f"quantizer_sets length ({len(quantizer_sets)}) must match " + f"{'dimension ' + repr(quantizer_dim) if quantizer_dim else 'batch dimension'} " + f"size ({expected_length})" + ) + + # Validate that this is NN layout (required by dense) + # For NN: lhs last dim must contract, rhs last dim must NOT contract + lhs_ndim = len(gemm_info["lhs_spec"]) + rhs_ndim = len(gemm_info["rhs_spec"]) + lhs_last_contracts = lhs_ndim - 1 in contracting_dims[0] + rhs_last_contracts = rhs_ndim - 1 in contracting_dims[1] + + if not lhs_last_contracts or rhs_last_contracts: + if fallback: + import warnings + + if quantizer_sets is not None and quantizer_sets != [noop_quantizer_set] * len( + quantizer_sets + ): + warnings.warn( + f"TE einsum only supports NN layout. Equation '{equation}' is not NN layout. " + "Falling back to jnp.einsum. WARNING: Quantization will NOT be applied!", + stacklevel=2, + ) + return jnp.einsum(equation, *operands) + raise ValueError( + "TE einsum only supports NN layout (non-transposed matrix multiplication). Equation" + f" '{equation}' is not NN layout:\n - LHS '{gemm_info['lhs_spec']}': last dimension" + f" must contract (got contracting_dims={contracting_dims[0]})\n - RHS" + f" '{gemm_info['rhs_spec']}': last dimension must NOT contract (got" + f" contracting_dims={contracting_dims[1]})\nFor non-NN layouts (e.g., routing" + " operations), use jnp.einsum instead." + ) + + # Create vmapped dense function for batch dimensions + has_batch_dims = bool(batch_dims[0] or batch_dims[1]) + + if has_batch_dims: + # Validate single batch dimension (MoE use case) + if len(batch_dims[0]) != 1 or len(batch_dims[1]) != 1: + if fallback: + import warnings + + if quantizer_sets is not None and quantizer_sets != [noop_quantizer_set] * len( + quantizer_sets + ): + warnings.warn( + "TE einsum only supports single batch dimension. Got" + f" {len(batch_dims[0])} batch dims in lhs and {len(batch_dims[1])} in rhs." + " Falling back to jnp.einsum. WARNING: Quantization will NOT be applied!", + stacklevel=2, + ) + return jnp.einsum(equation, *operands) + raise NotImplementedError( + "Only single batch dimension is currently supported. " + f"Got {len(batch_dims[0])} batch dims in lhs and {len(batch_dims[1])} in rhs. " + f"Equation: '{equation}'" + ) + + lhs_batch_dim = batch_dims[0][0] + rhs_batch_dim = batch_dims[1][0] + + # Adjust contracting dims for the unbatched shapes seen by Python code + # (primitives will see batched shapes, but Python validation sees unbatched) + adj_lhs_contracting = tuple( + dim - (1 if dim > lhs_batch_dim else 0) for dim in contracting_dims[0] + ) + adj_rhs_contracting = tuple( + dim - (1 if dim > rhs_batch_dim else 0) for dim in contracting_dims[1] + ) + adj_contracting_dims = (adj_lhs_contracting, adj_rhs_contracting) + + # Stack quantizers into a pytree structure that vmap can handle + # QuantizerSet is already a pytree, so we can stack them + # For BF16 without quantizer_dim, this will be a stack of noop_quantizer_sets + stacked_quantizers = jax.tree_util.tree_map(lambda *args: jnp.stack(args), *quantizer_sets) + + # Vmap over quantizers (or repeated noop quantizers for BF16) + def dense_with_quantizer(lhs_single, rhs_single, quantizer_set): + """Dense with explicit quantizer argument for vmapping.""" + return dense( + lhs_single, + rhs_single, + None, + contracting_dims=adj_contracting_dims, # Adjusted for unbatched shapes + transpose_batch_sequence=False, + input_axes=operand_axes[0], + kernel_axes=operand_axes[1], + output_axes=output_axes, + quantizer_set=quantizer_set, + ) + + vmapped_func = jax.vmap( + dense_with_quantizer, + in_axes=(lhs_batch_dim, rhs_batch_dim, 0), # vmap over stacked quantizers + out_axes=0, + ) + output = vmapped_func(lhs, rhs, stacked_quantizers) + else: + # No batch dimensions - direct dense call + # quantizer_set length already validated to be 1 + output = dense( + lhs, + rhs, + None, + contracting_dims=contracting_dims, + transpose_batch_sequence=False, + input_axes=operand_axes[0], + kernel_axes=operand_axes[1], + output_axes=output_axes, + quantizer_set=quantizer_sets[0], + ) + + return output diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 90f139c3dac..120bd05c13f 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -209,49 +209,63 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): flatten_axis: int has_rht_applied: bool - def __post_init__(self): - """Validates and adjusts the scale_inv shape after initialization. - - Ensures the scale_inv shape matches the expected shape based on the scaling mode - and quantization direction. Pads the scale_inv if necessary. - """ - assert self.flatten_axis > 0 - assert ( - 0 < self.flatten_axis < len(self.data.shape) - ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}" - - if self.scaling_mode == ScalingMode.NO_SCALING: - self.scale_inv = jnp.empty((0,), dtype=jnp.float32) - else: - unpadded_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, - data_layout=self.data_layout, - is_colwise=self.is_colwise, - is_padded=False, - # expect the flatten_axis wrt the N layout - flatten_axis=( - self.flatten_axis - if self.data_layout == "N" - else self.data.ndim - self.flatten_axis - ), - ) - unpadded_scale_shape_broadcast = self.scaling_mode.get_scale_shape( - self.data.shape, - data_layout=self.data_layout, - is_colwise=self.is_colwise, - is_padded=False, - # expect the flatten_axis wrt the N layout - flatten_axis=( - self.flatten_axis - if self.data_layout == "N" - else self.data.ndim - self.flatten_axis - ), - broadcast_2d_scale_shape_to_1d=True, - ) - assert self.scale_inv.shape in (unpadded_scale_shape, unpadded_scale_shape_broadcast), ( - f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or" - f" {unpadded_scale_shape_broadcast} but got {self.scale_inv.shape}." - ) + # def __post_init__(self): + # """Validates and adjusts the scale_inv shape after initialization. + # + # Ensures the scale_inv shape matches the expected shape based on the scaling mode + # and quantization direction. Pads the scale_inv if necessary. + # """ + # assert self.flatten_axis > 0 + # assert ( + # 0 < self.flatten_axis < len(self.data.shape) + # ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}" + # + # if self.scaling_mode == ScalingMode.NO_SCALING: + # self.scale_inv = jnp.empty((0,), dtype=jnp.float32) + # else: + # unpadded_scale_shape = self.scaling_mode.get_scale_shape( + # self.data.shape, + # data_layout=self.data_layout, + # is_colwise=self.is_colwise, + # is_padded=False, + # # expect the flatten_axis wrt the N layout + # flatten_axis=( + # self.flatten_axis + # if self.data_layout == "N" + # else self.data.ndim - self.flatten_axis + # ), + # ) + # unpadded_scale_shape_broadcast = self.scaling_mode.get_scale_shape( + # self.data.shape, + # data_layout=self.data_layout, + # is_colwise=self.is_colwise, + # is_padded=False, + # # expect the flatten_axis wrt the N layout + # flatten_axis=( + # self.flatten_axis + # if self.data_layout == "N" + # else self.data.ndim - self.flatten_axis + # ), + # broadcast_2d_scale_shape_to_1d=True, + # ) + # # Check shape, allowing for batch dimensions from vmap + # # If vmapped, shape will be (batch_size, *expected_shape) + # actual_shape = self.scale_inv.shape + # if actual_shape not in (unpadded_scale_shape, unpadded_scale_shape_broadcast): + # # Check if it's a batched version (extra leading dimensions) + # if len(actual_shape) > len(unpadded_scale_shape): + # # Batched: check that trailing dimensions match + # trailing_shape = actual_shape[-(len(unpadded_scale_shape)):] + # if trailing_shape not in (unpadded_scale_shape, unpadded_scale_shape_broadcast): + # raise AssertionError( + # f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or " + # f"{unpadded_scale_shape_broadcast} (possibly with batch dims) but got {self.scale_inv.shape}." + # ) + # else: + # raise AssertionError( + # f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or " + # f"{unpadded_scale_shape_broadcast} but got {self.scale_inv.shape}." + # ) def tree_flatten(self): """Flattens the tensor for JAX tree operations. @@ -431,10 +445,21 @@ def __post_init__(self): flatten_axis=self.flatten_axis, ) - assert self.scale_inv.shape == expected_scale_shape, ( - f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded" - f" scale_inv, got {self.scale_inv.shape}" - ) + # Check shape, allowing for batch dimensions from vmap + actual_shape = self.scale_inv.shape + if actual_shape != expected_scale_shape: + # Check if it's a batched version + if len(actual_shape) > len(expected_scale_shape): + trailing_shape = actual_shape[-(len(expected_scale_shape)) :] + assert trailing_shape == expected_scale_shape, ( + f"Unexpected scale_inv shape! Expected {expected_scale_shape} for padded " + f"scale_inv (possibly with batch dims), got {self.scale_inv.shape}" + ) + else: + raise AssertionError( + f"Unexpected scale_inv shape! Expected {expected_scale_shape} for padded " + f"scale_inv, got {self.scale_inv.shape}" + ) def tree_flatten(self): """Flattens the tensor for JAX tree operations.