diff --git a/.gitignore b/.gitignore index 74acd6ad7f5..04828a77862 100644 --- a/.gitignore +++ b/.gitignore @@ -41,3 +41,4 @@ compile_commands.json .nfs tensor_dumps/ artifacts/ +.DS_Store \ No newline at end of file diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 18abe0f2c24..f47a31b6866 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -143,6 +143,70 @@ Tensor saving and restoring functions .. autoapifunction:: transformer_engine.pytorch.restore_from_saved +Operation fuser +--------------- + +.. autoapiclass:: transformer_engine.pytorch.ops.Sequential + :members: forward + +.. autoapiclass:: transformer_engine.pytorch.ops.FusibleOperation + :members: fuser_forward, fuser_backward + +.. autoapiclass:: transformer_engine.pytorch.ops.Linear + +.. autoapiclass:: transformer_engine.pytorch.ops.AddExtraInput + +.. autoapiclass:: transformer_engine.pytorch.ops.AllGather + +.. autoapiclass:: transformer_engine.pytorch.ops.AllReduce + +.. autoapiclass:: transformer_engine.pytorch.ops.BasicLinear + :members: _functional_forward, _functional_backward + +.. autoapiclass:: transformer_engine.pytorch.ops.Bias + +.. autoapiclass:: transformer_engine.pytorch.ops.ClampedSwiGLU + +.. autoapiclass:: transformer_engine.pytorch.ops.ConstantScale + +.. autoapiclass:: transformer_engine.pytorch.ops.Dropout + +.. autoapiclass:: transformer_engine.pytorch.ops.GEGLU + +.. autoapiclass:: transformer_engine.pytorch.ops.GELU + +.. autoapiclass:: transformer_engine.pytorch.ops.Identity + +.. autoapiclass:: transformer_engine.pytorch.ops.L2Normalization + +.. autoapiclass:: transformer_engine.pytorch.ops.LayerNorm + +.. autoapiclass:: transformer_engine.pytorch.ops.MakeExtraOutput + +.. autoapiclass:: transformer_engine.pytorch.ops.QGELU + +.. autoapiclass:: transformer_engine.pytorch.ops.QGEGLU + +.. autoapiclass:: transformer_engine.pytorch.ops.Quantize + +.. autoapiclass:: transformer_engine.pytorch.ops.ReGLU + +.. autoapiclass:: transformer_engine.pytorch.ops.ReLU + +.. autoapiclass:: transformer_engine.pytorch.ops.ReduceScatter + +.. autoapiclass:: transformer_engine.pytorch.ops.Reshape + +.. autoapiclass:: transformer_engine.pytorch.ops.RMSNorm + +.. autoapiclass:: transformer_engine.pytorch.ops.SReGLU + +.. autoapiclass:: transformer_engine.pytorch.ops.SReLU + +.. autoapiclass:: transformer_engine.pytorch.ops.SiLU + +.. autoapiclass:: transformer_engine.pytorch.ops.SwiGLU + Deprecated functions -------------------- diff --git a/docs/examples/op_fuser/fp8_layernorm_linear.png b/docs/examples/op_fuser/fp8_layernorm_linear.png new file mode 100644 index 00000000000..b5916a61528 Binary files /dev/null and b/docs/examples/op_fuser/fp8_layernorm_linear.png differ diff --git a/docs/examples/op_fuser/layernorm_mlp.png b/docs/examples/op_fuser/layernorm_mlp.png new file mode 100644 index 00000000000..f388c88fa9d Binary files /dev/null and b/docs/examples/op_fuser/layernorm_mlp.png differ diff --git a/docs/examples/op_fuser/op_fuser.rst b/docs/examples/op_fuser/op_fuser.rst new file mode 100644 index 00000000000..18cae65a720 --- /dev/null +++ b/docs/examples/op_fuser/op_fuser.rst @@ -0,0 +1,251 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Operation fuser API +=================== + +Motivation +---------- + +Transformer Engine relies heavily on operation fusion to achieve high +performance. 
A typical training workload involves many memory-bound +operations such as activation functions and normalization, so +replacing them with fused kernels can deliver a significant +performance benefit. This is especially true for low-precision +training (e.g. FP8 and FP4) because it involves extra cast operations. + +Managing these fusions can be challenging because they differ based on +operation types, communication patterns, data types, and GPU +architectures. The most straightforward solution is to provide +monolithic modules like ``Linear``, ``LayerNormLinear``, or +``TransformerLayer``. These conform to the interface of a standard +PyTorch module, but can perform arbitrary fusions internally. These +hand-tuned implementations can achieve maximum performance, but they +tend to be complicated and difficult to modify. + +As an alternative to this "top-down" design, TE exposes a "bottom-up" +operation-based API. The user constructs individual operations and +passes them into a fuser, resulting in the same fused kernels as the +monolithic modules. This approach is more flexible, making it easier +to support new model architectures or to experiment with fusions. + +Description and usage +--------------------- + +Basic usage +^^^^^^^^^^^ + +At the most basic level, the operation fuser API involves two classes +in the ``transformer_engine.pytorch.ops`` submodule: + +- ``FusibleOperation``: An abstract base class for tensor operations. + Examples include ``Linear``, ``LayerNorm``, and ``AllReduce``. It is + a subclass of ``torch.nn.Module``, so it can hold trainable + parameters and can be called to perform the operation's forward + pass. +- ``Sequential``: A container of modules in sequential order. It has a + very similar interface to ``torch.nn.Sequential``. If it contains + any ``FusibleOperation`` s, then it may attempt to fuse them in the + forward and backward passes. + +Thus, using the operation fuser simply involves constructing +``FusibleOperation`` s and passing them into a ``Sequential``. + +.. code-block:: python + + import torch + import transformer_engine.pytorch as te + + # Options + hidden_size = 4096 + ffn_size = 28672 + batch_size = 16384 + + # Construct operations and fuse + mlp = te.ops.Sequential( + te.ops.LayerNorm(hidden_size), + te.ops.Linear(hidden_size, ffn_size), + te.ops.SwiGLU(), + te.ops.Linear(ffn_size // 2, hidden_size), + ) + + # Forward pass + x = torch.randn(batch_size, hidden_size, device="cuda") + y = mlp(x) + +.. figure:: ./layernorm_mlp.png + :align: center + + Operations that match ``LayerNormMLP`` module. Note that different + fusions have been applied in the forward and backward passes. + +Quantization +^^^^^^^^^^^^ + +The operation fuser respects TE's APIs for low-precision ("quantized") +data formats like FP8 and FP4. Constructing operations within a +``quantized_model_init`` context will enable quantized weights, and +performing the forward pass within an ``autocast`` context will enable +quantized compute. + +.. code-block:: python + + import torch + import transformer_engine.pytorch as te + + # Construct layer with quantized weights + with te.quantized_model_init(): + fc1 = te.ops.Sequential( + te.ops.LayerNorm(4096), + te.ops.Linear(4096, 28672), + ) + + # Forward pass within autocast context + x = torch.randn(16384, 4096, device="cuda") + with te.autocast(): + y = fc1(x) + + # Backward pass outside of autocast context + y.sum().backward() + +.. 
figure:: ./fp8_layernorm_linear.png + :align: center + + Operations that match ``LayerNormLinear`` module with FP8 + quantization. + +Internally, each operation that supports quantized compute holds one +or more ``Quantizer`` s, which are builder classes for converting +high-precision tensors (e.g. in FP32 or BF16) to quantized tensors. In +order to enable fused quantization kernels, operations can access the +quantizers of neighboring operations and quantize eagerly. In some +situations, like when operations are split across multiple +``Sequential`` s, it may be helpful to encourage the fuser by manually +adding ``Quantize`` operations. + +.. code-block:: python + + import torch + import transformer_engine.pytorch as te + + # Construct layer with quantized weights + with te.quantized_model_init(): + norm = te.ops.Sequential( + te.ops.LayerNorm(4096), + te.ops.Quantize(), + ) + fc1 = te.ops.Sequential( + te.ops.Linear(4096, 28672), + ) + + # Forward pass + x = torch.randn(16384, 4096, device="cuda") + with te.autocast(): + y = norm(x) # y is a QuantizedTensor + z = fc1(y) + +.. warning:: + + This is an expert technique. Quantizer configurations can be quite + complicated, so the ``Quantize`` operation's quantizers may be + suboptimal. + +Branching operations +^^^^^^^^^^^^^^^^^^^^ + +The operation fuser supports very limited branching behavior. While +the operations must be in sequential order, some operations can accept +extra inputs or produce extra outputs. For example, ``AddExtraInput`` +will add an extra input tensor to the intermediate tensor and +``MakeExtraOutput`` will return the intermediate tensor as an extra +output. When calling a ``Sequential`` that contains any of these +branching operations, the extra inputs should be passed in as +arguments and the extra outputs will be returned. + +.. code-block:: python + + import torch + import transformer_engine.pytorch as te + + # Construct MLP with residual connection + fc1 = te.ops.Sequential( + te.ops.LayerNorm(4096), + te.ops.MakeExtraOutput(), # Output residual + te.ops.Linear(4096, 28672), + te.ops.SwiGLU(), + ) + fc2 = te.ops.Sequential( + te.ops.Linear(14336, 4096), + te.ops.AddExtraInput(), # Add residual + ) + + # Forward pass + x = torch.randn(16384, 4096, device="cuda") + y, residual = fc1(x) + y = fc2(y, residual) + +.. figure:: ./residual_layernorm_mlp.png + :align: center + + Operations for an MLP block with a residual connection. Note that + the block has been split into two sections, each with one branching + operation. + +Implementation details +^^^^^^^^^^^^^^^^^^^^^^ + +In addition to ``FusibleOperation`` and ``Sequential``, the fuser +infrastructure relies on the following classes: + +- ``BasicOperation``: The most basic type of ``FusibleOperation``. + Examples include ``BasicLinear``, ``Bias``, and ``ReLU``. It holds + parameters and state, and it implements both a forward and backward + pass. The ``op_forward`` and ``op_backward`` functions have an + interface reminiscent of ``torch.autograd.Function``, e.g. they + accept a context object that caches state from the forward pass to + the backward pass. +- ``FusedOperation``: A ``FusibleOperation`` that can replace one or + more ``BasicOperation`` s. Examples include + ``ForwardLinearBiasActivation`` and ``BackwardActivationBias``. Its + forward and backward passes (the ``fuser_forward`` and + ``fuser_backward`` functions) must produce equivalent results as its + corresponding ``BasicOperation`` s. 
This also means that the + ``FusedOperation`` is stateless since it can access parameters and + state from the ``BasicOperation`` s. Note that different fusions may + be applied in the forward and backward pass, so a ``FusedOperation`` + may be missing its forward and/or backward implementation. +- ``OperationFuser``: This is the class that manages the operation + fusions. It launches the forward and backward passes within a + ``torch.autograd.Function``. + +The first time that a ``Sequential`` is called, it will group adjacent +``FusibleOperation`` s together into ``OperationFuser`` s. The first +time an ``OperationFuser`` is called, it will attempt to fuse +operations for the forward pass and backward pass. Subsequent calls +will reuse the same state unless it has been invalidated, e.g. by +changing the quantization recipe. + +Misconceptions +-------------- + +- **The op fuser is not a general kernel compiler**: The op fuser API + is simply an alternative way to access TE fused kernels, most of + which are targeted toward common Transformer architectures. For + generic kernel compilation, consider tools like + `nvFuser `_, + `CuTe DSL `_, + `torch.compile `_, + `Triton `_, + or `Pallas `_. +- **The op fuser is not a graph compiler**: The op fuser only supports + operations in a sequential order, with very limited support for + branching operations. Support for general graphs is not planned + since it would massively increase complexity. +- **The op fuser is not interchangeable with the monolithic TE + modules**: Modules like ``Linear``, ``LayerNormLinear``, and + ``TransformerLayer`` support a wide range of features and advanced + workflows, which makes them challenging to decompose into simple + operations that work with the fuser. They are also carefully + hand-tuned to achieve maximum performance. diff --git a/docs/examples/op_fuser/residual_layernorm_mlp.png b/docs/examples/op_fuser/residual_layernorm_mlp.png new file mode 100644 index 00000000000..a47af7af618 Binary files /dev/null and b/docs/examples/op_fuser/residual_layernorm_mlp.png differ diff --git a/docs/index.rst b/docs/index.rst index 7a3ab9f6fd5..8aa5a3dbf8d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -49,6 +49,7 @@ Transformer Engine documentation examples/te_gemma/tutorial_generation_gemma_with_te.ipynb examples/onnx/onnx_export.ipynb examples/te_jax_integration.ipynb + examples/op_fuser/op_fuser.rst .. toctree:: :hidden: diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index a444facd0a6..feb57c8eb08 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -153,7 +153,7 @@ class GELU(_ActivationOperation): \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) - See `Gaussian Error Linear Units (GELUs)`__. + See `Gaussian Error Linear Units (GELUs) `__. """ @@ -188,7 +188,7 @@ class GEGLU(_ActivationOperation): the first half of the input tensor, while PyTorch applies it to the second half. - See `GLU Variants Improve Transformer`__. + See `GLU Variants Improve Transformer `__. """ @@ -202,8 +202,8 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: class QGELU(_ActivationOperation): r"""Quick Gaussian Error Linear Unit - Quick GELU from `HuggingFace`__ - and `paper`__. + Quick GELU from `HuggingFace `__ + and `paper `__. .. 
math:: @@ -285,7 +285,7 @@ class ReGLU(_ActivationOperation): the first half of the input tensor, while PyTorch applies it to the second half. - See `GLU Variants Improve Transformer`__. + See `GLU Variants Improve Transformer `__. """ @@ -303,7 +303,7 @@ class SReLU(_ActivationOperation): \text{SReLU}(x) = \max(x^2,0) - See `Primer: Searching for Efficient Transformers for Language Modeling`__. + See `Primer: Searching for Efficient Transformers for Language Modeling `__. """ @@ -383,8 +383,8 @@ class SwiGLU(_ActivationOperation): The Sigmoid Linear Unit (SiLU) gating function is also known as the swish function. See - `GLU Variants Improve Transformer`__ - and `Gaussian Error Linear Units (GELUs)`__. + `GLU Variants Improve Transformer `__ + and `Gaussian Error Linear Units (GELUs) `__ . """ @@ -397,14 +397,18 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: class ClampedSwiGLU(_ActivationOperation): r"""GPT-OSS - Implementation based on `GPT-OSS`__. + Implementation based on `GPT-OSS `__. This activation has two differences compared to the original SwiGLU 1. Both gate and pre-activations are clipped based on parameter limit. 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. - .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is differnt - from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor. + .. warning:: + + The input tensor is chunked along the last dimension to get + gates/pre-activations, which is different from the GPT-OSS + implementation where the gates/pre-activations are assumed to + be interleaved in the input tensor. Parameters ---------- @@ -414,6 +418,7 @@ class ClampedSwiGLU(_ActivationOperation): The scaling factor for the sigmoid function used in the activation. cache_quantized_input : bool, default = False Quantize input tensor when caching for use in the backward pass. + """ def __init__( diff --git a/transformer_engine/pytorch/ops/basic/add_extra_input.py b/transformer_engine/pytorch/ops/basic/add_extra_input.py index 1fcfa0466ad..2da77369593 100644 --- a/transformer_engine/pytorch/ops/basic/add_extra_input.py +++ b/transformer_engine/pytorch/ops/basic/add_extra_input.py @@ -30,7 +30,7 @@ class AddExtraInput(BasicOperation): feature and most users are discouraged from it. In-place operations break some autograd assumptions and they can result in subtle, esoteric bugs. - Compare to `MakeExtraOutput`, which does a similar operation in + Compare to ``MakeExtraOutput``, which does a similar operation in the backward pass. """ diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 9f09e6634be..e5009b9ce90 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -48,8 +48,8 @@ def _wait_async(handle: Optional[Any]) -> None: class BasicLinear(BasicOperation): """Apply linear transformation: :math:`y = x A^T` - This is a drop-in replacement for `torch.nn.Linear` with - `bias=False`. + This is a drop-in replacement for ``torch.nn.Linear`` with + ``bias=False``. 
Parameters ---------- @@ -61,27 +61,27 @@ class BasicLinear(BasicOperation): Tensor device dtype : torch.dtype, default = default dtype Tensor datatype - tensor_parallel_mode : {`None`, "column", "row"}, default = `None` + tensor_parallel_mode : {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group : torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel : bool, default = `False` + sequence_parallel : bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) rng_state_tracker_function : callable - Function that returns `CudaRNGStatesTracker`, which is used + Function that returns ``CudaRNGStatesTracker``, which is used for model-parallel weight initialization - accumulate_into_main_grad : bool, default = `False` + accumulate_into_main_grad : bool, default = False Whether to directly accumulate weight gradients into the - weight's `main_grad` attribute instead of relying on PyTorch - autograd. The weight's `main_grad` must be set externally and - there is no guarantee that `grad` will be set or be + weight's ``main_grad`` attribute instead of relying on PyTorch + autograd. The weight's ``main_grad`` must be set externally + and there is no guarantee that ``grad`` will be set or be meaningful. This is primarily intented to integrate with Megatron-LM. This argument along with weight tensor having - attribute 'overwrite_main_grad' set to True will overwrite - `main_grad` instead of accumulating. + attribute ``overwrite_main_grad`` set to ``True`` will + overwrite ``main_grad`` instead of accumulating. userbuffers_options, dict, optional Options for overlapping tensor-parallel communication with compute using Userbuffers. This feature is highly @@ -184,7 +184,7 @@ def _canonicalize_tensor_parallelism( Parameters ---------- - mode: {`None`, "column", "row"} + mode: {None, "column", "row"} Mode for tensor parallelism process_group: torch.distributed.ProcessGroup Process group for tensor parallelism @@ -200,7 +200,7 @@ def _canonicalize_tensor_parallelism( Returns ------- - mode: {`None`, "column", "row"} + mode: {None, "column", "row"} Mode for tensor parallelism process_group: torch.distributed.ProcessGroup Process group for tensor parallelism @@ -440,18 +440,18 @@ def _functional_forward( Output tensor beta: float, optional Scaling factor applied to original value of out when accumulating into it - accumulate_into_out: bool, default = `False` + accumulate_into_out: bool, default = False Add result to output tensor instead of overwriting - tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + tensor_parallel_mode: {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group: torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel: bool, default = `False` + sequence_parallel: bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_quantized_compute: bool, default = `False` + with_quantized_compute: bool, default = False Whether to perform compute with quantized data. 
input_quantizer: Quantizer, optional Builder class for quantized input tensor. @@ -459,10 +459,10 @@ def _functional_forward( Builder class for quantized weight tensor. output_quantizer: Quantizer, optional Builder class for quantized output tensor. - input_requires_grad: bool, default = `True` + input_requires_grad: bool, default = True Whether the loss gradient w.r.t. the input tensor is required in the backward pass. - weight_requires_grad: bool, default = `True` + weight_requires_grad: bool, default = True Whether the loss gradient w.r.t. the weight tensor is required in the backward pass. @@ -471,11 +471,11 @@ def _functional_forward( torch.Tensor Output tensor torch.Tensor, optional - Input tensor, ready for use in backward pass. `None` is + Input tensor, ready for use in backward pass. ``None`` is returned if loss gradient w.r.t. the weight tensor is not required. torch.Tensor, optional - Weight tensor, ready for use in backward pass. `None` is + Weight tensor, ready for use in backward pass. ``None`` is returned if loss gradient w.r.t. the input tensor is not required. @@ -676,24 +676,24 @@ def _functional_backward( Loss gradient w.r.t. weight tensor grad_weight_beta: float, optional Scaling factor applied to original value of grad_weight when accumulating into it - accumulate_into_grad_weight: bool, default = `False` + accumulate_into_grad_weight: bool, default = False Add result to weight grad instead of overwriting grad_input: torch.Tensor, optional Loss gradient w.r.t. input tensor grad_input_beta: float, optional Scaling factor applied to original value of grad_input when accumulating into it - accumulate_into_grad_input: bool, default = `False` + accumulate_into_grad_input: bool, default = False Add result to input grad instead of overwriting - tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + tensor_parallel_mode: {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group: torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel: bool, default = `False` + sequence_parallel: bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_quantized_compute: bool, default = `False` + with_quantized_compute: bool, default = False Whether to perform compute with quantized data. input_quantizer: Quantizer, optional Builder class for quantized input tensor. diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index 69101638253..6c3c0538c11 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -18,7 +18,7 @@ class Bias(BasicOperation): """Apply additive bias - This is equivalent to the additive bias in `torch.nn.Linear`. + This is equivalent to the additive bias in ``torch.nn.Linear``. 
Parameters ---------- @@ -28,7 +28,7 @@ class Bias(BasicOperation): Tensor device dtype : torch.dtype, default = default dtype Tensor datatype - tensor_parallel : bool, default = `False` + tensor_parallel : bool, default = False Whether to distribute input tensor and bias tensors along inner dimension tensor_parallel_group : torch.distributed.ProcessGroup, default = world group diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 3922f85cad4..45aed4f4cf4 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -31,7 +31,7 @@ class LayerNorm(BasicOperation): r"""Layer Normalization Applies Layer Normalization over a mini-batch of inputs as described in - the paper `Layer Normalization `__ + the paper `Layer Normalization `__ . .. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta @@ -51,9 +51,9 @@ class LayerNorm(BasicOperation): Tensor device dtype : torch.dtype, default = default dtype Tensor datatype - zero_centered_gamma : bool, default = 'False' - If `True`, the :math:`\gamma` parameter is initialized to zero - and the calculation changes to + zero_centered_gamma : bool, default = False + If ``True``, the :math:`\gamma` parameter is initialized to + zero and the calculation changes to .. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta diff --git a/transformer_engine/pytorch/ops/basic/make_extra_output.py b/transformer_engine/pytorch/ops/basic/make_extra_output.py index 34228affc75..7272ebc1894 100644 --- a/transformer_engine/pytorch/ops/basic/make_extra_output.py +++ b/transformer_engine/pytorch/ops/basic/make_extra_output.py @@ -35,7 +35,7 @@ class MakeExtraOutput(BasicOperation): operations break some autograd assumptions and they can result in subtle, esoteric bugs. - Compare to `AddExtraInput`, which does a similar operation in the + Compare to ``AddExtraInput``, which does a similar operation in the backward pass. """ diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index 1278701a9bb..c2c46e40199 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -18,14 +18,14 @@ class Quantize(BasicOperation): """Quantize tensor data - Uses recipe from `autocast` context. When called outside - of an `autocast` context, this is an identity operation. + Uses recipe from ``autocast`` context. When called outside + of an ``autocast`` context, this is an identity operation. Parameters ---------- - forward : bool, default = `True` + forward : bool, default = True Perform quantization in forward pass - backward : bool, default = `False` + backward : bool, default = False Perform quantization in backward pass """ diff --git a/transformer_engine/pytorch/ops/basic/reshape.py b/transformer_engine/pytorch/ops/basic/reshape.py index fcdb3b0bbec..ee8cf4464af 100644 --- a/transformer_engine/pytorch/ops/basic/reshape.py +++ b/transformer_engine/pytorch/ops/basic/reshape.py @@ -20,7 +20,7 @@ class Reshape(BasicOperation): """Reshape tensor - See `torch.reshape`. + See ``torch.reshape``. 
Parameters ---------- diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 316c292c537..59a52183858 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -32,7 +32,7 @@ class RMSNorm(BasicOperation): Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in the paper - `Root Mean Square Layer Normalization `__ + `Root Mean Square Layer Normalization `__ . .. math:: y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma @@ -50,8 +50,8 @@ class RMSNorm(BasicOperation): Tensor device dtype : torch.dtype, default = default dtype Tensor datatype - zero_centered_gamma : bool, default = 'False' - If `True`, the :math:`\gamma` parameter is initialized to zero + zero_centered_gamma : bool, default = False + If ``True``, the :math:`\gamma` parameter is initialized to zero and the calculation changes to .. math:: diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 5149aa1ffb3..0d2fae2048e 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -125,18 +125,18 @@ def _functional_backward( Tensor datatype grad_weight: torch.Tensor, optional Loss gradient w.r.t. weight tensor - accumulate_into_grad_weight: bool, default = `False` + accumulate_into_grad_weight: bool, default = False Add result to weight grad instead of overwriting - tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + tensor_parallel_mode: {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group: torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel: bool, default = `False` + sequence_parallel: bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_quantized_compute: bool, default = `False` + with_quantized_compute: bool, default = False Whether to perform compute with quantized data. input_quantizer: Quantizer, optional Builder class for quantized input tensor. diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 517632d6514..fbcfd4b46af 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -115,16 +115,16 @@ def _functional_forward( Tensor device dtype: torch.dtype Tensor datatype - tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + tensor_parallel_mode: {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group: torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel: bool, default = `False` + sequence_parallel: bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. 
distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_quantized_compute: bool, default = `False` + with_quantized_compute: bool, default = False Whether to perform compute with quantized data. input_quantizer: Quantizer, optional Builder class for quantized input tensor. @@ -132,10 +132,10 @@ def _functional_forward( Builder class for quantized weight tensor. output_quantizer: Quantizer, optional Builder class for quantized output tensor. - input_requires_grad: bool, default = `True` + input_requires_grad: bool, default = True Whether the loss gradient w.r.t. the input tensor is required in the backward pass. - weight_requires_grad: bool, default = `True` + weight_requires_grad: bool, default = True Whether the loss gradient w.r.t. the weight tensor is required in the backward pass. ub_comm_name: str diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index fecf28f0a9d..9e865ca42d4 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -44,7 +44,7 @@ def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]: def _is_graph_capturing() -> bool: - """Whether function is called within `make_graphed_callables` + """Whether function is called within ``make_graphed_callables`` Avoid circular import with lazy import. diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py index d1e63822911..d28c03f163e 100644 --- a/transformer_engine/pytorch/ops/linear.py +++ b/transformer_engine/pytorch/ops/linear.py @@ -23,7 +23,7 @@ class Linear(FusedOperation): """Apply linear transformation: :math:`y = x A^T + b` - This is a drop-in replacement for `torch.nn.Linear`. + This is a drop-in replacement for ``torch.nn.Linear``. Parameters ---------- @@ -31,17 +31,17 @@ class Linear(FusedOperation): Inner dimension of input tensor out_features : int Inner dimension of output tensor - bias : bool, default = `True` + bias : bool, default = True Apply additive bias device : torch.device, default = default CUDA device Tensor device dtype : torch.dtype, default = default dtype Tensor datatype - tensor_parallel_mode : {`None`, "column", "row"}, default = `None` + tensor_parallel_mode : {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group : torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel : bool, default = `False` + sequence_parallel : bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing @@ -49,11 +49,11 @@ class Linear(FusedOperation): rng_state_tracker_function : callable Function that returns CudaRNGStatesTracker, which is used for model-parallel weight initialization - accumulate_into_main_grad : bool, default = `False` + accumulate_into_main_grad : bool, default = False Whether to directly accumulate weight gradients into the - weight's `main_grad` attribute instead of relying on PyTorch - autograd. The weight's `main_grad` must be set externally and - there is no guarantee that `grad` will be set or be + weight's ``main_grad`` attribute instead of relying on PyTorch + autograd. The weight's ``main_grad`` must be set externally and + there is no guarantee that ``grad`` will be set or be meaningful. 
This is primarily intented to integrate with Megatron-LM. diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 421c92b8235..20227c6cf41 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -94,7 +94,7 @@ def fuser_forward( several of this function's arguments are lists of arguments to forward functions of corresponding basic ops. - Called by `OperationFuser`. + Called by ``OperationFuser``. Parameters ---------- @@ -141,7 +141,7 @@ def fuser_backward( several of this function's arguments are lists of arguments to backward functions of corresponding basic ops. - Called by `OperationFuser`. + Called by ``OperationFuser``. Parameters ---------- diff --git a/transformer_engine/pytorch/ops/sequential.py b/transformer_engine/pytorch/ops/sequential.py index 2afda58e47d..de545419e5b 100644 --- a/transformer_engine/pytorch/ops/sequential.py +++ b/transformer_engine/pytorch/ops/sequential.py @@ -15,10 +15,10 @@ class Sequential(torch.nn.Module): - """Sequential container for fusible operations + """Sequential container for fusible operations. - This is a drop-in replacement for `torch.nn.Sequential`, with - support for fusing `FusibleOperation`s. + This is a drop-in replacement for ``torch.nn.Sequential`` with + support for fusing ``FusibleOperation`` s. Parameters ----------