From 74ca30f1b53bfcc64263aacf587b4343146693ed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Jan 2026 21:12:54 +0000 Subject: [PATCH 1/2] [pre-commit.ci] pre-commit suggestions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/pre-commit/pre-commit-hooks: v4.6.0 → v6.0.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.6.0...v6.0.0) - https://github.com/psf/black → https://github.com/psf/black-pre-commit-mirror - [github.com/psf/black-pre-commit-mirror: 24.4.2 → 25.12.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...25.12.0) - [github.com/pre-commit/mirrors-clang-format: v18.1.6 → v21.1.8](https://github.com/pre-commit/mirrors-clang-format/compare/v18.1.6...v21.1.8) - [github.com/netromdk/vermin: c75aca72f4e85c6e47252139e8695f1c8b5f9ae3 → v1.8.0](https://github.com/netromdk/vermin/compare/c75aca72f4e85c6e47252139e8695f1c8b5f9ae3...v1.8.0) --- .pre-commit-config.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) mode change 100755 => 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml old mode 100755 new mode 100644 index 5043d6ea22..d7a4ef70b0 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v6.0.0 hooks: - id: check-merge-conflict - id: check-added-large-files @@ -23,8 +23,8 @@ repos: - id: trailing-whitespace files: .*.(c|cc|cxx|cpp|cu|cuh|h|hpp|py)$ - - repo: https://github.com/psf/black - rev: 24.4.2 + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 25.12.0 hooks: - id: black name: Format python code @@ -32,7 +32,7 @@ repos: types: [python] - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.6 + rev: v21.1.8 hooks: - id: clang-format entry: clang-format -i @@ -40,7 +40,7 @@ repos: files: ^transformer_engine.*\.(c|cc|cxx|cpp|cu|cuh|h|hpp)$ - repo: https://github.com/netromdk/vermin - rev: c75aca72f4e85c6e47252139e8695f1c8b5f9ae3 + rev: v1.8.0 hooks: - id: vermin args: ['-t=3.10', '--violations'] From 076b995497e210f2ef34b8ebedb82eac8863d635 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Jan 2026 21:14:36 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- build_tools/jax.py | 1 + build_tools/pytorch.py | 1 + build_tools/te_version.py | 1 + examples/jax/collective_gemm/conftest.py | 1 + .../jax/collective_gemm/test_dense_grad.py | 1 + examples/jax/collective_gemm/test_gemm.py | 1 + .../test_layernorm_mlp_grad.py | 1 + examples/jax/encoder/common.py | 1 + examples/jax/encoder/conftest.py | 1 + .../encoder/test_model_parallel_encoder.py | 1 + examples/jax/encoder/test_multigpu_encoder.py | 1 + .../encoder/test_multiprocessing_encoder.py | 5 +- .../jax/encoder/test_single_gpu_encoder.py | 1 + examples/jax/mnist/test_single_gpu_mnist.py | 1 + tests/jax/conftest.py | 1 + tests/jax/test_distributed_dense.py | 1 - tests/jax/test_distributed_fused_attn.py | 1 - tests/jax/test_distributed_layernorm.py | 1 - tests/jax/test_distributed_layernorm_mlp.py | 1 - tests/jax/test_fused_attn.py | 1 + tests/jax/test_layer.py | 3 +- ..._multi_process_distributed_grouped_gemm.py | 1 - tests/jax/test_permutation.py | 1 - tests/jax/test_softmax.py | 1 + tests/pytorch/attention/test_attention.py | 2 +- tests/pytorch/attention/test_cp_utils.py | 1 + tests/pytorch/debug/test_numerics.py | 18 +-- tests/pytorch/debug/test_sanity.py | 12 +- .../pytorch/distributed/run_numerics_exact.py | 1 - tests/pytorch/distributed/test_fusible_ops.py | 1 - .../test_fusible_ops_with_userbuffers.py | 1 - tests/pytorch/distributed/test_torch_fsdp2.py | 1 - tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 1 - .../pytorch/nvfp4/test_nvfp4_module_exact.py | 13 +- .../nvfp4/test_nvfp4_quantize_exact.py | 1 - .../test_float8_current_scaling_exact.py | 37 +++-- tests/pytorch/test_fused_rope.py | 4 +- tests/pytorch/test_multi_tensor.py | 1 - tests/pytorch/test_numerics.py | 1 - tests/pytorch/test_partial_cast.py | 1 - tests/pytorch/test_permutation.py | 8 +- transformer_engine/common/common.h | 136 +++++++++++++----- .../common/fused_router/utils.h | 28 +++- .../common/multi_tensor/l2norm.cu | 4 +- .../common/normalization/kernel_traits.h | 2 +- .../layernorm/ln_bwd_kernels.cuh | 6 +- .../layernorm/ln_bwd_semi_cuda_kernel.cu | 4 +- .../layernorm/ln_fwd_cuda_kernel.cu | 4 +- .../rmsnorm/rmsnorm_bwd_kernels.cuh | 6 +- .../rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu | 4 +- .../rmsnorm/rmsnorm_fwd_cuda_kernel.cu | 4 +- transformer_engine/common/recipe/__init__.py | 1 + .../common/swizzle/swizzle_block_scaling.cu | 4 +- .../common/triton/permutation.py | 1 - transformer_engine/common/utils.py | 1 + .../debug/features/log_fp8_tensor_stats.py | 1 - .../debug/features/utils/stats_buffer.py | 1 - transformer_engine/jax/__init__.py | 1 - transformer_engine/jax/activation.py | 2 +- transformer_engine/jax/attention.py | 1 + transformer_engine/jax/checkpoint_policies.py | 1 - .../jax/cpp_extensions/__init__.py | 1 + .../jax/cpp_extensions/activation.py | 8 +- transformer_engine/jax/cpp_extensions/amax.py | 2 +- .../jax/cpp_extensions/attention.py | 10 +- transformer_engine/jax/cpp_extensions/base.py | 1 + transformer_engine/jax/cpp_extensions/gemm.py | 15 +- .../jax/cpp_extensions/normalization.py | 2 +- .../jax/cpp_extensions/quantization.py | 2 +- .../jax/cpp_extensions/softmax.py | 2 +- transformer_engine/jax/flax/__init__.py | 1 + transformer_engine/jax/flax/module.py | 1 + transformer_engine/jax/flax/transformer.py | 1 + transformer_engine/jax/quantize/__init__.py | 1 + .../jax/quantize/dequantizer.py | 2 +- transformer_engine/jax/quantize/hadamard.py | 1 + transformer_engine/jax/quantize/metadata.py | 2 +- transformer_engine/jax/quantize/misc.py | 1 + transformer_engine/jax/quantize/quantizer.py | 1 + .../jax/quantize/scaling_modes.py | 1 - transformer_engine/jax/quantize/tensor.py | 1 + transformer_engine/jax/sharding.py | 1 + transformer_engine/jax/softmax.py | 3 +- .../jax/triton_extensions/permutation.py | 1 - .../jax/triton_extensions/utils.py | 1 - .../dot_product_attention/backends.py | 1 + .../dot_product_attention/context_parallel.py | 3 +- .../dot_product_attention.py | 2 +- .../dot_product_attention/softmax.py | 2 +- .../attention/dot_product_attention/utils.py | 6 +- .../pytorch/attention/inference.py | 1 + .../pytorch/attention/multi_head_attention.py | 1 + transformer_engine/pytorch/attention/rope.py | 2 +- transformer_engine/pytorch/constants.py | 2 +- .../pytorch/cpp_extensions/__init__.py | 1 + .../pytorch/cpp_extensions/fused_attn.py | 2 +- .../pytorch/cpp_extensions/gemm.py | 1 - transformer_engine/pytorch/cpu_offload.py | 1 - transformer_engine/pytorch/cpu_offload_v1.py | 1 + .../pytorch/custom_recipes/utils.py | 1 - transformer_engine/pytorch/distributed.py | 2 +- transformer_engine/pytorch/export.py | 1 - transformer_engine/pytorch/fp8.py | 1 - transformer_engine/pytorch/graph.py | 1 + transformer_engine/pytorch/jit.py | 1 + transformer_engine/pytorch/module/__init__.py | 1 + transformer_engine/pytorch/module/base.py | 1 + .../pytorch/module/fp8_padding.py | 3 +- .../pytorch/module/fp8_unpadding.py | 3 +- .../pytorch/module/grouped_linear.py | 1 + .../pytorch/module/layernorm.py | 1 + .../pytorch/module/layernorm_linear.py | 3 +- .../pytorch/module/layernorm_mlp.py | 12 +- transformer_engine/pytorch/module/linear.py | 3 +- transformer_engine/pytorch/module/rmsnorm.py | 1 + transformer_engine/pytorch/numerics_debug.py | 1 + .../pytorch/ops/basic/basic_linear.py | 2 +- .../pytorch/ops/fused/backward_linear_add.py | 2 +- .../ops/fused/backward_linear_scale.py | 2 +- .../ops/fused/userbuffers_backward_linear.py | 2 +- transformer_engine/pytorch/ops/linear.py | 2 +- .../pytorch/optimizers/__init__.py | 1 + .../pytorch/optimizers/fused_adam.py | 1 + .../pytorch/optimizers/fused_sgd.py | 1 + .../pytorch/optimizers/multi_tensor_apply.py | 1 + transformer_engine/pytorch/permutation.py | 1 + transformer_engine/pytorch/quantization.py | 2 +- transformer_engine/pytorch/router.py | 1 + transformer_engine/pytorch/setup.py | 1 - .../pytorch/tensor/float8_blockwise_tensor.py | 1 + .../pytorch/tensor/float8_tensor.py | 3 +- .../pytorch/tensor/mxfp8_tensor.py | 1 + .../pytorch/tensor/nvfp4_tensor.py | 1 + transformer_engine/pytorch/torch_version.py | 1 + transformer_engine/pytorch/transformer.py | 2 +- transformer_engine/pytorch/utils.py | 2 +- 136 files changed, 312 insertions(+), 189 deletions(-) diff --git a/build_tools/jax.py b/build_tools/jax.py index 276c9943d6..e95ece0b97 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """JAX related extensions.""" + import os from pathlib import Path from packaging import version diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index b4815a0942..f977eef54f 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """PyTorch related extensions.""" + import os from pathlib import Path diff --git a/build_tools/te_version.py b/build_tools/te_version.py index f4a1a587ed..e80f75c315 100644 --- a/build_tools/te_version.py +++ b/build_tools/te_version.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Transformer Engine version string.""" + import os from pathlib import Path import subprocess diff --git a/examples/jax/collective_gemm/conftest.py b/examples/jax/collective_gemm/conftest.py index 5be5709ba7..b700309d7b 100644 --- a/examples/jax/collective_gemm/conftest.py +++ b/examples/jax/collective_gemm/conftest.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """config for collective_gemm tests""" + import pytest diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index 94c7dc5b66..7bb36eb592 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """Collective Dense Gradient test on multi-GPU with tensor parallelism""" + import argparse import unittest import os diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index d2994723bb..8fae5ef227 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -12,6 +12,7 @@ Example: python test_gemm.py --coordinator-address localhost:1234 --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3 """ + import unittest import os from functools import partial diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index 61c960a7aa..eb400d15a7 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """Collective Dense Gradient test on multi-GPU with tensor parallelism""" + import argparse import unittest import os diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py index 7906d44aec..67c646719e 100644 --- a/examples/jax/encoder/common.py +++ b/examples/jax/encoder/common.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """Shared functions for the encoder tests""" + from functools import lru_cache import os import pathlib diff --git a/examples/jax/encoder/conftest.py b/examples/jax/encoder/conftest.py index 083c1b4dce..a493d7e5a3 100644 --- a/examples/jax/encoder/conftest.py +++ b/examples/jax/encoder/conftest.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """config for test_multiprocessing_encoder""" + import pytest diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index b534db8576..b845ef62f9 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """Encoder training on multi-GPU with tesnor parallelism""" + import argparse import unittest from functools import partial diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 98184ccd75..d37be34aa8 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """Encoder training on multi-GPU with data parallelism""" + import argparse import unittest from functools import partial diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 327540521c..0b5949905a 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """Encoder training with multi-GPU, multiprocessing, and tensor parallelism""" + import argparse import os import unittest @@ -104,7 +105,7 @@ def shard_array_wrapper(dataset, batch_size, mesh, pspec, enable_partition=False inputs = jnp.asarray(dataset) total_input_size = len(inputs) - (dp_size, tp_size) = mesh.device_ids.shape + dp_size, tp_size = mesh.device_ids.shape valid_input_size, global_batch_size, num_steps, tp_group_id = valid_shard_size( total_input_size, batch_size, dp_size, tp_size ) @@ -156,7 +157,7 @@ def train_epoch( """Train for a single epoch.""" total_batch_size = len(train_ds["sentence"]) - (dp_size, tp_size) = mesh.device_ids.shape + dp_size, tp_size = mesh.device_ids.shape valid_size, _, num_steps, tp_group_id = valid_shard_size( total_batch_size, batch_size, dp_size, tp_size ) diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 82c7fed38e..147cc88bfd 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """Encoder training on single GPU""" + import argparse import unittest from functools import partial diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index 0c76d51c37..f6022cd0a5 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """MNIST training on single GPU""" + import argparse import unittest from functools import partial diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index 6b7520d147..8c6421e0f0 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """conftest for tests/jax""" + import os import jax import pytest diff --git a/tests/jax/test_distributed_dense.py b/tests/jax/test_distributed_dense.py index b8caf188d4..025181a87f 100644 --- a/tests/jax/test_distributed_dense.py +++ b/tests/jax/test_distributed_dense.py @@ -18,7 +18,6 @@ from transformer_engine.jax import autocast from transformer_engine.jax.dense import dense - DTYPES = [jnp.bfloat16] GEMM_INPUT_SHAPES = [[256, 128, 256]] # [batch, seq_len, hidden_in] diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index d0018543d1..bd6c91a9b7 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -27,7 +27,6 @@ ReorderStrategy, ) - DTYPES = [jnp.bfloat16] DISTRIBUTED_SELF_ATTN_DATA_SHAPES = { diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index e9a2fa49e2..dfed3dad0c 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -20,7 +20,6 @@ from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available - DTYPES = [jnp.bfloat16, jnp.float32] NORM_INPUT_SHAPES = { diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index d214597cb3..447eea074e 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -45,7 +45,6 @@ ) from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability - is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) is_mxfp8_supported, reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) is_nvfp4_supported, reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index ac1b7c3505..85509f2c9d 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """Tests for fused attention""" + from enum import Enum, auto from dataclasses import dataclass, field from functools import partial diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 8c16d162ed..18747a5c83 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """Test transformer_engine.jax.flax.TransformerLayer""" + import os from functools import partial from typing import Dict, Tuple, Optional @@ -209,7 +210,7 @@ def enable_fused_attn(): }, # attrs20 { - _KEY_OF_MLP_ACTIVATIONS: (("relu", "relu")), + _KEY_OF_MLP_ACTIVATIONS: ("relu", "relu"), }, # attrs21 { diff --git a/tests/jax/test_multi_process_distributed_grouped_gemm.py b/tests/jax/test_multi_process_distributed_grouped_gemm.py index 94fed0859f..3efc0f4edd 100644 --- a/tests/jax/test_multi_process_distributed_grouped_gemm.py +++ b/tests/jax/test_multi_process_distributed_grouped_gemm.py @@ -16,7 +16,6 @@ from utils import assert_allclose, dtype_tols - N_GROUP = 8 MESH_AXIS_NAME = "fsdp" diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py index 43f2553eed..c9faa1d3f5 100644 --- a/tests/jax/test_permutation.py +++ b/tests/jax/test_permutation.py @@ -18,7 +18,6 @@ ) from utils import assert_allclose, pytest_parametrize_wrapper - # ============================================================================= # Test parameter definitions with L0 (fast) and L2 (comprehensive) levels # ============================================================================= diff --git a/tests/jax/test_softmax.py b/tests/jax/test_softmax.py index 7af9613538..6f5ffea672 100644 --- a/tests/jax/test_softmax.py +++ b/tests/jax/test_softmax.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """Tests for the softmax primitives""" + from contextlib import nullcontext from dataclasses import dataclass from functools import wraps diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index eb7905bcd5..fd08fe8487 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2629,7 +2629,7 @@ def forward( def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: with torch.cuda.nvtx.range("_DPA"): saved_tensors = ctx.saved_tensors - (q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_saved( + q, k, v, inp_fp8, qkv_weight_fp8, out = restore_from_saved( ctx.tensor_objects, saved_tensors ) diff --git a/tests/pytorch/attention/test_cp_utils.py b/tests/pytorch/attention/test_cp_utils.py index e5051aab36..1d5da55bf5 100644 --- a/tests/pytorch/attention/test_cp_utils.py +++ b/tests/pytorch/attention/test_cp_utils.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Unit tests for context parallel utils.""" + import torch import unittest from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( diff --git a/tests/pytorch/debug/test_numerics.py b/tests/pytorch/debug/test_numerics.py index ab9a2d054a..70f73c612a 100644 --- a/tests/pytorch/debug/test_numerics.py +++ b/tests/pytorch/debug/test_numerics.py @@ -275,8 +275,7 @@ def _get_tensors(): """ -DISABLE_FP8_CONFIG = Template( - """disable_fp8_config: +DISABLE_FP8_CONFIG = Template("""disable_fp8_config: enabled: True layers: layer_types: [linear] @@ -284,8 +283,7 @@ def _get_tensors(): DisableFP8GEMM: enabled: True gemms: [$gemms] -""" -) +""") @create_config_file @@ -407,8 +405,7 @@ def test_per_tensor_scaling( ) -PER_TENSOR_SCALING_CONFIG = Template( - """per_tensor_scaling_config: +PER_TENSOR_SCALING_CONFIG = Template("""per_tensor_scaling_config: enabled: True layers: layer_types: [linear] @@ -417,8 +414,7 @@ def test_per_tensor_scaling( enabled: True gemms_struct: $gemms -""" -) +""") def _prepare_per_tensor_scaling_config( @@ -670,8 +666,7 @@ def test_fake_quant_fp8( ) -FAKE_QUANT_CONFIG = Template( - """fake_quant_config: +FAKE_QUANT_CONFIG = Template("""fake_quant_config: enabled: True layers: layer_types: [linear] @@ -680,8 +675,7 @@ def test_fake_quant_fp8( enabled: True gemms_struct: $gemms -""" -) +""") def fake_quant_fp8_create_config( diff --git a/tests/pytorch/debug/test_sanity.py b/tests/pytorch/debug/test_sanity.py index aee5474e76..74063a6c22 100644 --- a/tests/pytorch/debug/test_sanity.py +++ b/tests/pytorch/debug/test_sanity.py @@ -18,7 +18,8 @@ configs = { "": "", - "log": """log: + "log": ( + """log: layers: layer_types: [linear] enabled: @@ -36,8 +37,10 @@ stats: [underflows, overflows] start_step : 0 end_step: 1 -""", - "fake_quant": """ +""" + ), + "fake_quant": ( + """ fake_quant_config: enabled: True layers: @@ -47,7 +50,8 @@ enabled: True gemms: [fprop, dgrad, wgrad] quant_format: FP8E5M2 -""", +""" + ), } diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py index 0f3d2cbbf0..59265a6504 100644 --- a/tests/pytorch/distributed/run_numerics_exact.py +++ b/tests/pytorch/distributed/run_numerics_exact.py @@ -26,7 +26,6 @@ from transformer_engine.pytorch.custom_recipes import utils from run_layer_with_overlap import _compare_tensors - BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE = 128, 256, 128 WORLD_RANK, WORLD_SIZE = None, None NCCL_WORLD = None diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index c484038938..f1e6234536 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -36,7 +36,6 @@ sys.path.append(str(_current_file.parent.parent)) from utils import dtype_tols, make_recipe, quantization_tols - # Check what quantization schemes are supported fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 603433e0da..07e3d200a4 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -34,7 +34,6 @@ Float8Tensor, ) - # Import utility functions _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py index e328e57758..dc5d96bd66 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2.py +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -10,7 +10,6 @@ import torch - fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) NUM_PROCS: int = torch.cuda.device_count() diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 911b7660dc..cae6fcc247 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -11,7 +11,6 @@ from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils - recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) diff --git a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py index a96fea3af0..8a2daeab52 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py @@ -9,7 +9,6 @@ from transformer_engine.pytorch.custom_recipes import quantization_nvfp4 from transformer_engine.pytorch.custom_recipes import utils - recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) @@ -244,9 +243,7 @@ def check_nvfp4_module_versus_reference( native_outputs.append( { "output": y_native.detach().clone(), - "input_grad": ( - x_native.grad.detach().clone() if x_native.grad is not None else None - ), + "input_grad": x_native.grad.detach().clone() if x_native.grad is not None else None, "weight_grad": ( native_module.weight.grad.detach().clone() if native_module.weight.grad is not None @@ -263,7 +260,7 @@ def check_nvfp4_module_versus_reference( ref_outputs.append( { "output": y_ref.detach().clone(), - "input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None), + "input_grad": x_ref.grad.detach().clone() if x_ref.grad is not None else None, "weight_grad": ( ref_module.weight.grad.detach().clone() if ref_module.weight.grad is not None @@ -467,9 +464,7 @@ def check_nvfp4_layernorm_linear_versus_reference( { "output": y_native.detach().clone(), "ln_out": ln_out_native.detach().clone(), - "input_grad": ( - x_native.grad.detach().clone() if x_native.grad is not None else None - ), + "input_grad": x_native.grad.detach().clone() if x_native.grad is not None else None, "weight_grad": ( native_module.weight.grad.detach().clone() if native_module.weight.grad is not None @@ -486,7 +481,7 @@ def check_nvfp4_layernorm_linear_versus_reference( { "output": y_ref.detach().clone(), "ln_out": ln_out_ref.detach().clone(), - "input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None), + "input_grad": x_ref.grad.detach().clone() if x_ref.grad is not None else None, "weight_grad": ( ref_module.weight.grad.detach().clone() if ref_module.weight.grad is not None diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 80ccb2f23d..198c762702 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -12,7 +12,6 @@ from transformer_engine.common.recipe import NVFP4BlockScaling from transformer_engine.pytorch.constants import TE_DType - recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index 99ab9c4984..154dab8d96 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -18,7 +18,6 @@ CurrentScalingQuantizerRef, ) - # read env variable NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps" tensor_dump_dir_env = os.getenv("NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR") @@ -110,10 +109,18 @@ def _check_golden_tensor_dumps(dump_dir, get_recipe, dims, input_dtype, use_bias current_seed = torch.initial_seed() # Get the current seed expected_tensor_names = { - "y": f"y_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", - "dgrad": f"dgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", - "wgrad": f"wgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", - "bgrad": f"bgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "y": ( + f"y_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt" + ), + "dgrad": ( + f"dgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt" + ), + "wgrad": ( + f"wgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt" + ), + "bgrad": ( + f"bgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt" + ), } if not use_bias: @@ -462,11 +469,21 @@ def _check_golden_tensor_dumps( current_seed = torch.initial_seed() # Get the current seed expected_tensor_names = { - "y": f"y_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", - "ln_out": f"ln_out_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", - "dgrad": f"dgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", - "wgrad": f"wgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", - "bgrad": f"bgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "y": ( + f"y_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt" + ), + "ln_out": ( + f"ln_out_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt" + ), + "dgrad": ( + f"dgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt" + ), + "wgrad": ( + f"wgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt" + ), + "bgrad": ( + f"bgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt" + ), } if not use_bias: diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 50624df9e0..941f52bcab 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -418,9 +418,7 @@ def test_fused_qkv_rope( # for more accurate comparison t_clone = t.clone() - (query, key, value) = torch.split( - t_clone, [hidden_size * 4, hidden_size, hidden_size], dim=3 - ) + query, key, value = torch.split(t_clone, [hidden_size * 4, hidden_size, hidden_size], dim=3) query = query.reshape(query.shape[0], query.shape[1], head_num * 4, hidden_size) query_unfused = apply_rotary_pos_emb( diff --git a/tests/pytorch/test_multi_tensor.py b/tests/pytorch/test_multi_tensor.py index b7caa094ae..ff0f6c2885 100644 --- a/tests/pytorch/test_multi_tensor.py +++ b/tests/pytorch/test_multi_tensor.py @@ -12,7 +12,6 @@ from references.quantize_scale_calc import scale_from_amax_tensor - input_size_pairs = [ (7777 * 77, 555 * 555), (777, 555), diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index abe2806e66..3b9c92e6d1 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -51,7 +51,6 @@ import transformer_engine_torch as tex from utils import ModelConfig, reset_rng_states - # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) diff --git a/tests/pytorch/test_partial_cast.py b/tests/pytorch/test_partial_cast.py index bbb18503b1..70161b7eef 100644 --- a/tests/pytorch/test_partial_cast.py +++ b/tests/pytorch/test_partial_cast.py @@ -11,7 +11,6 @@ from transformer_engine.pytorch import is_mxfp8_available from transformer_engine.pytorch.optimizers.multi_tensor_apply import multi_tensor_applier - mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index be1ff30472..13481afdd3 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -970,8 +970,8 @@ def _test_permutation_and_padding_with_merging_probs( num_out_tokens = num_tokens * topK print( - "permutation and padding with merging probs:" - f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} align_size:{align_size} {te_dtype}" + f"permutation and padding with merging probs: token:{num_tokens} hidden_size:{hidden_size}" + f" expert:{num_expert} topK:{topK} align_size:{align_size} {te_dtype}" ) # Convert TE dtypes to PyTorch dtypes @@ -1293,8 +1293,8 @@ def _test_moe_chunk_sort( BENCHMARK=False, ): print( - "chunk permute:" - f" token:{num_tokens} hidden_size:{hidden_size} num_expert:{num_expert} tp_size:{tp_size} {te_dtype}" + f"chunk permute: token:{num_tokens} hidden_size:{hidden_size} num_expert:{num_expert}" + f" tp_size:{tp_size} {te_dtype}" ) # Convert TE dtypes to PyTorch dtypes diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 0bc9536844..37c0ccca34 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -574,7 +574,9 @@ struct TypeInfo { #define SWITCH_FP4_TYPE_HANDLE(type, ...) \ case DType::kFloat4E2M1: { \ using type = fp4e2m1; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; #else #define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing @@ -585,43 +587,63 @@ struct TypeInfo { using namespace transformer_engine; \ case DType::kByte: { \ using type = unsigned char; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kInt16: { \ using type = int16_t; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kInt32: { \ using type = int32_t; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kInt64: { \ using type = int64_t; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat32: { \ using type = float; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat16: { \ using type = fp16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kBFloat16: { \ using type = bf16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E4M3: { \ using type = fp8e4m3; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E5M2: { \ using type = fp8e5m2; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E8M0: { \ using type = byte; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ default: \ @@ -633,23 +655,33 @@ struct TypeInfo { using namespace transformer_engine; \ case DType::kFloat32: { \ using type = float; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat16: { \ using type = fp16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kBFloat16: { \ using type = bf16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E4M3: { \ using type = fp8e4m3; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E5M2: { \ using type = fp8e5m2; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ default: \ NVTE_ERROR("Invalid type."); \ @@ -660,23 +692,33 @@ struct TypeInfo { using namespace transformer_engine; \ case DType::kFloat32: { \ using type = float; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat16: { \ using type = fp16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kBFloat16: { \ using type = bf16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E5M2: { \ using type = fp8e5m2; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E4M3: { \ using type = fp8e4m3; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ default: \ NVTE_ERROR("Invalid type."); \ @@ -687,15 +729,21 @@ struct TypeInfo { using namespace transformer_engine; \ case DType::kFloat32: { \ using type = float; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat16: { \ using type = fp16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kBFloat16: { \ using type = bf16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ default: \ NVTE_ERROR("Invalid type."); \ @@ -707,7 +755,9 @@ struct TypeInfo { using namespace transformer_engine; \ case DType::kFloat4E2M1: { \ using type = __nv_fp4x2_storage_t; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ default: \ NVTE_ERROR("Invalid type."); \ @@ -718,11 +768,15 @@ struct TypeInfo { using namespace transformer_engine; \ case DType::kFloat8E5M2: { \ using type = fp8e5m2; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E4M3: { \ using type = fp8e4m3; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ default: \ NVTE_ERROR("Invalid type."); \ @@ -733,15 +787,21 @@ struct TypeInfo { using namespace transformer_engine; \ case DType::kFloat32: { \ using type = float; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat16: { \ using type = fp16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kBFloat16: { \ using type = bf16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E5M2: \ case DType::kFloat8E4M3: { \ @@ -775,11 +835,15 @@ struct TypeInfo { switch (SCALE_DIM) { \ case 1: { \ constexpr size_t DIM = 1; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case 32: { \ constexpr size_t DIM = 32; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ default: { \ NVTE_ERROR("Invalid size of the MX scaling factor."); \ @@ -789,10 +853,14 @@ struct TypeInfo { #define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) \ if (CONDITION) { \ constexpr bool FLAG = true; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } else { \ constexpr bool FLAG = false; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 4ae0b467b5..af46e9e442 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -221,15 +221,21 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i using namespace transformer_engine; \ case DType::kFloat32: { \ using type = float; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat16: { \ using type = fp16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kBFloat16: { \ using type = bf16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ default: \ NVTE_ERROR("Invalid type."); \ @@ -240,19 +246,27 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i using namespace transformer_engine; \ case DType::kInt32: { \ using type = int32_t; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kInt64: { \ using type = int64_t; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kBFloat16: { \ using type = bf16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat32: { \ using type = float; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ default: \ NVTE_ERROR("Invalid type."); \ diff --git a/transformer_engine/common/multi_tensor/l2norm.cu b/transformer_engine/common/multi_tensor/l2norm.cu index 8a7f265d40..8c43c3e27c 100644 --- a/transformer_engine/common/multi_tensor/l2norm.cu +++ b/transformer_engine/common/multi_tensor/l2norm.cu @@ -54,7 +54,7 @@ reduce_block_into_lanes(T *x, T val, int lanes = 1, final = x[tid] + x[tid + 32]; else final = val; - // __SYNCWARP(); + // __SYNCWARP(); #pragma unroll for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i); @@ -95,7 +95,7 @@ reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1, final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); else final = val; - // __SYNCWARP(); + // __SYNCWARP(); #pragma unroll for (int i = 16; i >= lanes; i >>= 1) diff --git a/transformer_engine/common/normalization/kernel_traits.h b/transformer_engine/common/normalization/kernel_traits.h index 12fc095c38..01dc71a071 100644 --- a/transformer_engine/common/normalization/kernel_traits.h +++ b/transformer_engine/common/normalization/kernel_traits.h @@ -92,7 +92,7 @@ struct Kernel_traits : public Base { enum { BYTES_PER_ROW = COLS * sizeof(input_t) }; enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG }; // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed - enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA* COLS * sizeof(compute_t) }; + enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) }; static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1); using reduce_t = typename transformer_engine::TypeToVec2::Type; diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh index c4b00b87c3..cb60e6173b 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh @@ -504,9 +504,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kerne template -__global__ -__launch_bounds__(WARPS_M *WARPS_N *THREADS_PER_WARP) void ln_bwd_finalize_general_kernel( - BackwardKernelParams params) { +__global__ __launch_bounds__( + WARPS_M * WARPS_N * + THREADS_PER_WARP) void ln_bwd_finalize_general_kernel(BackwardKernelParams params) { enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) }; using Wvec = Vec; using Cvec = Vec; diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu index 68aa0942c1..2579673a4e 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -136,8 +136,8 @@ void launch_ln_bwd_general_(LaunchParams &launch_params, OTYPE, CTYPE, ...) \ namespace { \ void \ - norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ launch_ln_bwd_##LAUNCH_TYPE##_(launch_params, configure_params); \ } \ diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu index 464df8d276..eda24141de 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu @@ -105,8 +105,8 @@ void launch_ln_fwd_general_(LaunchParams &launch_params, OTYPE, CTYPE, ...) \ namespace { \ void \ - norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ launch_ln_fwd_##LAUNCH_TYPE##_(launch_params, configure_params); \ } \ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh index d620ee5260..7446a4c7c3 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh @@ -469,9 +469,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_ template -__global__ -__launch_bounds__(WARPS_M *WARPS_N *THREADS_PER_WARP) void rmsnorm_bwd_finalize_general_kernel( - BackwardKernelParams params) { +__global__ __launch_bounds__( + WARPS_M * WARPS_N * + THREADS_PER_WARP) void rmsnorm_bwd_finalize_general_kernel(BackwardKernelParams params) { enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) }; using Wvec = Vec; using Cvec = Vec; diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu index 60238f256d..023071fe4a 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -137,8 +137,8 @@ void launch_rmsnorm_bwd_general_(LaunchParams &launch_para OTYPE, CTYPE, ...) \ namespace { \ void \ - norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ launch_rmsnorm_bwd_##LAUNCH_TYPE##_(launch_params, configure_params); \ } \ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu index 5522fd5c6b..b8ef52ba56 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu @@ -106,8 +106,8 @@ void launch_rmsnorm_fwd_general_(LaunchParams &launch_param OTYPE, CTYPE, ...) \ namespace { \ void \ - norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ launch_rmsnorm_fwd_##LAUNCH_TYPE##_(launch_params, configure_params); \ } \ diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 64ee2a5a16..ca38f62bab 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """This module provides predefined FP8 recipes.""" + from __future__ import annotations import os from enum import Enum diff --git a/transformer_engine/common/swizzle/swizzle_block_scaling.cu b/transformer_engine/common/swizzle/swizzle_block_scaling.cu index c5ad1aed43..1c596f9a44 100644 --- a/transformer_engine/common/swizzle/swizzle_block_scaling.cu +++ b/transformer_engine/common/swizzle/swizzle_block_scaling.cu @@ -68,7 +68,7 @@ struct no_oob_tag_t {}; constexpr no_oob_tag_t NO_OOB_TAG; template -void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) +void __global__ __launch_bounds__(WARPS_X_PER_TB * WARPS_Y_PER_TB * WARP_SIZE) swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel( const void* __restrict__ const in, void* __restrict__ const out, const uint32_t tiles_x, const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride, @@ -167,7 +167,7 @@ namespace swizzle_kernel_2d { constexpr uint32_t WARPS_X_PER_TB = 2; // configurable constexpr uint32_t WARPS_Y_PER_TB = 2; // configurable -void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) +void __global__ __launch_bounds__(WARPS_X_PER_TB * WARPS_Y_PER_TB * WARP_SIZE) swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel( const void* __restrict__ const in, void* __restrict__ const out, const uint32_t tiles_x, const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride) { diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py index e53b2a9455..d72d4750cc 100644 --- a/transformer_engine/common/triton/permutation.py +++ b/transformer_engine/common/triton/permutation.py @@ -11,7 +11,6 @@ from triton.language.standard import _log2 from packaging import version - # The following three argsort related kernels are adapted from # the issue https://github.com/triton-lang/triton/issues/3698 diff --git a/transformer_engine/common/utils.py b/transformer_engine/common/utils.py index acbb1ca5fb..90d7df9e86 100644 --- a/transformer_engine/common/utils.py +++ b/transformer_engine/common/utils.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """The utilities for Transformer Engine""" + import inspect import warnings from enum import Enum diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index ffcc6b1ad4..3ef39eba8f 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -24,7 +24,6 @@ from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter - ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"] diff --git a/transformer_engine/debug/features/utils/stats_buffer.py b/transformer_engine/debug/features/utils/stats_buffer.py index 9ce56dd76d..af98ea23e1 100644 --- a/transformer_engine/debug/features/utils/stats_buffer.py +++ b/transformer_engine/debug/features/utils/stats_buffer.py @@ -8,7 +8,6 @@ When log() is called, they gather stats from all nodes, compute combined final stats and log them. """ - from collections import defaultdict from typing import Dict import torch diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index d0afc1ff25..be85c3f521 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -42,7 +42,6 @@ from ..common.utils import deprecate_wrapper from ..common.utils import DeprecatedEnum - __all__ = [ "NVTE_FP8_COLLECTION_NAME", "autocast", diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index b2b90a10c9..cc0a48e52d 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -96,7 +96,7 @@ def _activation_bwd_rule(activation_type, act_params, ctx, g): Returns: Gradient with respect to input """ - (x, _) = ctx + x, _ = ctx assert x.dtype == g.dtype dx = tex.dact_lu(g, x, activation_type, act_params=act_params) # No quantization is used in this VJP backward, so the output should diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 21db296c34..434a209cd8 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """JAX multi-head attention modules""" + from __future__ import annotations from enum import Enum from functools import partial diff --git a/transformer_engine/jax/checkpoint_policies.py b/transformer_engine/jax/checkpoint_policies.py index 7312eefb11..e0463c3578 100644 --- a/transformer_engine/jax/checkpoint_policies.py +++ b/transformer_engine/jax/checkpoint_policies.py @@ -9,7 +9,6 @@ import jax from .cpp_extensions.gemm import GemmPrimitive, GroupedGemmPrimitive - __all__ = [ "te_gemms_saveable", "dots_and_te_gemms_with_no_batch_dims", diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index 6a2f9b7378..193c23800b 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """Python interface for c++ extensions""" + from .activation import * from .amax import * from .attention import * diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 573603ef3a..c4d83912cb 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """JAX/TE custom ops for activation""" + from typing import Sequence, Union, Callable, Optional, Tuple import operator from functools import reduce, partial @@ -37,7 +38,6 @@ QuantizeLayout, ) - __all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"] @@ -726,7 +726,7 @@ def outer_abstract(*args, **kwargs): """ te_dact_dbias_quantize_p outer abstract """ - (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = ( + out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _ = ( BaseDActLuDBiasQuantizePrimitive.abstract(*args, **kwargs) ) return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias @@ -808,7 +808,7 @@ def impl( """ del is_outer assert BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None - (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = ( + out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _ = ( BaseDActLuDBiasQuantizePrimitive.inner_primitive.bind( dz, x, @@ -1058,7 +1058,7 @@ def partition( ) def sharded_impl(dz, x, scale, amax): - (out, colwise_out, scale_inv, colwise_scale_inv, local_updated_amax, local_dbias) = ( + out, colwise_out, scale_inv, colwise_scale_inv, local_updated_amax, local_dbias = ( BaseDActLuDBiasQuantizePrimitive.impl( dz, x, diff --git a/transformer_engine/jax/cpp_extensions/amax.py b/transformer_engine/jax/cpp_extensions/amax.py index 700ba9061c..6023d320d3 100644 --- a/transformer_engine/jax/cpp_extensions/amax.py +++ b/transformer_engine/jax/cpp_extensions/amax.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """JAX/TE custom ops for amax calculation""" + from enum import Enum @@ -25,7 +26,6 @@ get_sign_from_vector, ) - __all__ = ["AmaxScope", "calculate_amax", "calculate_post_rht_amax"] diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 0cdfcebf38..c1e7775de6 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """JAX/TE custom ops for attention""" + import operator import os import warnings @@ -49,7 +50,6 @@ with_sharding_constraint, ) - __all__ = [ "FusedAttnHelper", "fused_attn_fwd", @@ -2685,7 +2685,7 @@ def correction(output, softmax_aux, output_per_step, softmax_aux_per_step): else: for i in range(0, cp_size): carry = scan_kv_block(i, carry) - (kv, output, softmax_aux) = carry + kv, output, softmax_aux = carry output = output.astype(q.dtype) return output, softmax_aux, rng_state @@ -2909,7 +2909,7 @@ def jax_cond_wrap(): else: for i in range(0, cp_size): carry = scan_kv_block(i, carry) - (kv, dq, dk_dv, dbias) = carry + kv, dq, dk_dv, dbias = carry # Final permute to put gradients back to their final resting place. dk_dv = helper.permute_kv(dk_dv, cp_perm) @@ -3133,7 +3133,7 @@ def correction(output, softmax_aux, output_per_step, softmax_aux_per_step): else: for i in range(0, cp_size): carry = scan_kv_block(i, carry) - (_, _, _, output, softmax_aux) = carry + _, _, _, output, softmax_aux = carry return output.astype(q.dtype), softmax_aux, rng_state @@ -3267,7 +3267,7 @@ def compute(config): else: for idx in range(cp_size): carry = scan_kv_block(idx, carry) - (_, _, _, dq, dkv, dbias) = carry + _, _, _, dq, dkv, dbias = carry # Final permute to put gradients back to their final resting place. dkv = helper.permute_kv(dkv, cp_perm) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index b26e01c0c7..ba8ebf1473 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """JAX/TE base custom ops""" + import os import re import warnings diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 71f133bfc4..b9f1570ef5 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -53,7 +53,6 @@ dp_or_fsdp_axis_size, ) - __all__ = [ "CollectiveOp", "CollectiveOpSet", @@ -698,7 +697,7 @@ def impl( reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) lhs = reordered.reshape(original_shape) - (output, bias_grad, pre_gelu_out, _) = GemmPrimitive.inner_primitive.bind( + output, bias_grad, pre_gelu_out, _ = GemmPrimitive.inner_primitive.bind( lhs, lhs_scale_inv, rhs, @@ -1015,10 +1014,8 @@ def infer_sharding_from_operands( sequence_dim, ) - (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( - GemmPrimitive._parse_operand_output_specs( - arg_infos, contracting_dims, transpose_batch_sequence, collective_op - ) + _, (out_specs, dbias_specs, pre_gelu_specs), *_ = GemmPrimitive._parse_operand_output_specs( + arg_infos, contracting_dims, transpose_batch_sequence, collective_op ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) @@ -1187,7 +1184,7 @@ def _generate_operand_rules(name, ndim, cdims): lhs, _, rhs, *_ = operand_types operand_ndims = (len(lhs.shape), len(rhs.shape)) - (lhs_cdims, rhs_cdims) = map(sanitize_dims, operand_ndims, contracting_dims) + lhs_cdims, rhs_cdims = map(sanitize_dims, operand_ndims, contracting_dims) lhs_specs, rhs_specs = map( _generate_operand_rules, ("lhs", "rhs"), @@ -1502,7 +1499,7 @@ def abstract( @staticmethod def outer_abstract(*args, **kwargs): - (out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs) + out_aval, _ = GroupedGemmPrimitive.abstract(*args, **kwargs) return (out_aval,) @staticmethod @@ -1556,7 +1553,7 @@ def impl( use_async_d2h_group_sizes, ): assert GroupedGemmPrimitive.inner_primitive is not None - (out, _) = GroupedGemmPrimitive.inner_primitive.bind( + out, _ = GroupedGemmPrimitive.inner_primitive.bind( lhs_data, lhs_scale_inv, rhs_data, diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 70fdf4c474..e272d6a316 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """JAX/TE custom ops for normalization""" + import os import warnings import operator @@ -40,7 +41,6 @@ QuantizeLayout, ) - __all__ = [ "layernorm_fwd", "layernorm_bwd", diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 1fcecb0e96..0d015561cc 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """JAX/TE custom ops for quantization""" + import operator from functools import reduce from typing import Tuple, Optional, Union @@ -47,7 +48,6 @@ QuantizeLayout, ) - __all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index ff30c9bba3..444f1b69e6 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """JAX/TE custom ops for softmax""" + from abc import abstractmethod from functools import partial, reduce import operator @@ -17,7 +18,6 @@ from .misc import get_padded_spec, check_valid_batch_dims from ..softmax import SoftmaxFusionType - __all__ = [ "scaled_softmax_fwd", "scaled_softmax_bwd", diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index dd7d2a47ba..5981a5b02d 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """Transformer Engine bindings for JAX""" + from .module import DenseGeneral, LayerNorm from .module import LayerNormDenseGeneral, LayerNormMLP from .module import wrap_function_in_te_state_module, make_dot_general_cls diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 3d82d8f0b4..8edcad2821 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -4,6 +4,7 @@ """ Wrapper module for Transformer related layers with FP8 support. """ + from functools import reduce import operator from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType, Optional diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index ad5a60e4c2..8760e3a219 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -4,6 +4,7 @@ """ Wrapper module for Transformer related layers with FP8 support. """ + import functools from enum import Enum from math import sqrt diff --git a/transformer_engine/jax/quantize/__init__.py b/transformer_engine/jax/quantize/__init__.py index 4505611a48..bfea1d5989 100644 --- a/transformer_engine/jax/quantize/__init__.py +++ b/transformer_engine/jax/quantize/__init__.py @@ -9,6 +9,7 @@ It exports all the necessary classes and functions from the underlying implementation modules. """ + from .tensor import * from .quantizer import * from .dequantizer import * diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 74787b9308..9b52afad4d 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -7,6 +7,7 @@ This module provides utilities for dequantizing tensors that have been quantized using various scaling modes, including delayed scaling and block scaling. """ + import math from dataclasses import dataclass from abc import ABC, abstractmethod @@ -17,7 +18,6 @@ from .scaling_modes import ScalingMode from .hadamard import apply_rht - __all__ = ["ScalingModeToDequantizerMap"] diff --git a/transformer_engine/jax/quantize/hadamard.py b/transformer_engine/jax/quantize/hadamard.py index 1bad6be101..efe975f3df 100644 --- a/transformer_engine/jax/quantize/hadamard.py +++ b/transformer_engine/jax/quantize/hadamard.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """Randomized Hadamard Transform (RHT) utilities for JAX.""" + import jax.numpy as jnp diff --git a/transformer_engine/jax/quantize/metadata.py b/transformer_engine/jax/quantize/metadata.py index 52367216c4..77c0414ce8 100644 --- a/transformer_engine/jax/quantize/metadata.py +++ b/transformer_engine/jax/quantize/metadata.py @@ -8,8 +8,8 @@ This module provides classes for managing quantization metadata, including scale factors and amax history for different tensor types. """ -from dataclasses import dataclass +from dataclasses import dataclass __all__ = ["QuantizeMeta", "QuantizeMetaSet"] diff --git a/transformer_engine/jax/quantize/misc.py b/transformer_engine/jax/quantize/misc.py index b7841bfa4e..c9899da15d 100644 --- a/transformer_engine/jax/quantize/misc.py +++ b/transformer_engine/jax/quantize/misc.py @@ -4,6 +4,7 @@ """ This module provides additional enum and utilities for quantizing tensors in JAX. """ + from dataclasses import dataclass from enum import Enum diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index f5ca6aeaed..b6ed38a2ee 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -6,6 +6,7 @@ This module provides classes and utilities for quantizing tensors in JAX. """ + from abc import ABC, abstractmethod from dataclasses import dataclass, field from functools import partial diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 61c3af178c..7e20d7915e 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -25,7 +25,6 @@ from .misc import QuantizeLayout from .device_utils import is_fp8_gemm_with_all_layouts_supported - __all__ = [ "QuantizeShardyRules", "ScalingMode", diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index c26cb8a531..d86a151f8e 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -8,6 +8,7 @@ both single-scale (1x) and double-scale (2x) quantization schemes. It supports rowwise and colwise quantization modes with proper scaling and dequantization. """ + from dataclasses import dataclass from typing import Callable, Tuple from abc import ABC, abstractmethod diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 9b13412c14..e05347ea22 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -9,6 +9,7 @@ parallelism (FSDP). It includes functions for sharding constraints, mesh management, and collective operations. """ + from contextlib import contextmanager from dataclasses import dataclass from typing import Callable, Optional diff --git a/transformer_engine/jax/softmax.py b/transformer_engine/jax/softmax.py index 8302e7ccee..ab1b060aa6 100644 --- a/transformer_engine/jax/softmax.py +++ b/transformer_engine/jax/softmax.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """JAX softmax modules""" + from enum import Enum from functools import partial from typing import Optional @@ -53,7 +54,7 @@ def _softmax_fwd_rule(logits, mask, scale_factor, softmax_fusion_type): def _softmax_bwd_rule(scale_factor, softmax_fusion_type, ctx, dz): - (softmax_output, logits, mask) = ctx + softmax_output, logits, mask = ctx if softmax_fusion_type is SoftmaxFusionType.SCALED_MASKED: dgrad = tex.scaled_masked_softmax_bwd(dz, softmax_output, logits, mask, scale_factor) diff --git a/transformer_engine/jax/triton_extensions/permutation.py b/transformer_engine/jax/triton_extensions/permutation.py index 849673fe31..9d22c9b97a 100644 --- a/transformer_engine/jax/triton_extensions/permutation.py +++ b/transformer_engine/jax/triton_extensions/permutation.py @@ -23,7 +23,6 @@ ) from .utils import triton_call_lowering - __all__ = [ "make_row_id_map", "permute_with_mask_map", diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 064b2843c6..0ae57222f7 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -16,7 +16,6 @@ import jax import jax.numpy as jnp - try: from jax._src.lib import gpu_triton from triton.compiler import compiler as tc diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index c726ed8849..b0394cab84 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Attention Backends.""" + from contextlib import nullcontext from importlib.metadata import version as get_pkg_version from importlib.metadata import PackageNotFoundError diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 75b360e485..57a6172d7b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Context Parallelism.""" + import os import itertools from typing import List, Union, Tuple @@ -3046,7 +3047,7 @@ def backward(ctx, dout, *_args): rank = get_distributed_rank(ctx.cp_group) (*saved_tensors,) = ctx.saved_tensors - (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5] + q, k, v, cu_seqlens_q, cu_seqlens_q_padded = saved_tensors[:5] cu_seqlens_kv_per_step = saved_tensors[5:7] out_per_step = saved_tensors[7:9] softmax_lse_per_step = saved_tensors[9:11] diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 6e5a12a103..26d9cac9a8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Attention.""" + from contextlib import nullcontext import math import os @@ -61,7 +62,6 @@ FlashAttention, ) - # Setup Attention Logging attn_log.setup_logging() diff --git a/transformer_engine/pytorch/attention/dot_product_attention/softmax.py b/transformer_engine/pytorch/attention/dot_product_attention/softmax.py index 74d9583ce5..8aa6954be9 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/softmax.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/softmax.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Fused scaled masked softmax functions""" + import os from typing import Callable, Tuple, Union, Optional import torch @@ -10,7 +11,6 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.export import is_in_onnx_export_mode - THREADS_PER_WARP = 32 THREADS_PER_BLOCK = 128 diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index bf19388d7e..7bcc4dc257 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -5,6 +5,7 @@ """ Utils/Helper classes and methods for attention """ + import math import os from typing import Any, Dict, List, Optional, Tuple, Union @@ -352,8 +353,9 @@ def get_attention_backend( cudnn_version = get_cudnn_version() run_config = { "transformer_engine_version": te.__version__, - "compute_capability": "sm" - + str(10 * device_compute_capability[0] + device_compute_capability[1]), + "compute_capability": ( + "sm" + str(10 * device_compute_capability[0] + device_compute_capability[1]) + ), "flash_attn_version": ( str(FlashAttentionUtils.version) if FlashAttentionUtils.is_installed diff --git a/transformer_engine/pytorch/attention/inference.py b/transformer_engine/pytorch/attention/inference.py index 08e50aad8b..70b9c04ab5 100644 --- a/transformer_engine/pytorch/attention/inference.py +++ b/transformer_engine/pytorch/attention/inference.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Inference""" + import logging from collections import OrderedDict, defaultdict from typing import Optional, List diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index f875fd1e0a..881128af1b 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Multi-head Attention.""" + import os import collections from typing import Callable, List, Optional, Tuple, Union diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index 77ad57ed8f..e2d4cb914e 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -5,13 +5,13 @@ """ Rotary Position Embedding implementation of different types along with helper functions """ + from typing import Optional, Tuple, Union, List import torch import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat - __all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"] diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index 3cce4600d9..4e726622d1 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -3,11 +3,11 @@ # See LICENSE for license information. """Enums for e2e transformer""" + import torch import torch.distributed import transformer_engine_torch as tex - """ This is a map: torch.dtype -> int Used for passing dtypes into cuda diff --git a/transformer_engine/pytorch/cpp_extensions/__init__.py b/transformer_engine/pytorch/cpp_extensions/__init__.py index bb6e921132..995ff2e3dc 100644 --- a/transformer_engine/pytorch/cpp_extensions/__init__.py +++ b/transformer_engine/pytorch/cpp_extensions/__init__.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Python interface for c++ extensions""" + from transformer_engine_torch import * from .fused_attn import * diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index e226ef32d4..2205d52677 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Python interface for fused attention extensions""" + import math from typing import Tuple, List, Union, Optional import torch @@ -17,7 +18,6 @@ ) from ..quantized_tensor import Quantizer - __all__ = [ "fused_attn_fwd", "fused_attn_bwd", diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 2a97e2ac71..60e58ced41 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -18,7 +18,6 @@ from ..custom_recipes.gemm import custom_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer - __all__ = [ "general_gemm", "general_grouped_gemm", diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index d0b8d3474e..24da8ac3ee 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -21,7 +21,6 @@ prepare_for_saving, ) - __all__ = ["get_cpu_offload_context", "mark_not_offload", "start_offload"] NVTE_CPU_OFFLOAD_V1 = os.environ.get("NVTE_CPU_OFFLOAD_V1", "0") == "1" diff --git a/transformer_engine/pytorch/cpu_offload_v1.py b/transformer_engine/pytorch/cpu_offload_v1.py index f92c436941..a81fd1f374 100644 --- a/transformer_engine/pytorch/cpu_offload_v1.py +++ b/transformer_engine/pytorch/cpu_offload_v1.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Functionality for CPU offloading of tensors saved for backward pass.""" + from __future__ import annotations from contextlib import nullcontext from typing import Any, Dict, Optional diff --git a/transformer_engine/pytorch/custom_recipes/utils.py b/transformer_engine/pytorch/custom_recipes/utils.py index 3e23661f14..fe5a6ba499 100644 --- a/transformer_engine/pytorch/custom_recipes/utils.py +++ b/transformer_engine/pytorch/custom_recipes/utils.py @@ -8,7 +8,6 @@ import torch - HIGH_PRECISION_FLOAT_DTYPES = ( torch.float, torch.float16, diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 5497ee7967..6a39717ffd 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Methods needed for distributed training (DP/TP).""" + from __future__ import annotations from collections.abc import Iterable @@ -50,7 +51,6 @@ from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer - __all__ = ["checkpoint", "CudaRNGStatesTracker"] diff --git a/transformer_engine/pytorch/export.py b/transformer_engine/pytorch/export.py index 89306fbe1e..c50dfe1676 100644 --- a/transformer_engine/pytorch/export.py +++ b/transformer_engine/pytorch/export.py @@ -8,7 +8,6 @@ from typing import Generator import torch - _IN_ONNX_EXPORT_MODE = False TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 6bcf2d53c7..af3734c91f 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -33,7 +33,6 @@ CustomRecipe, ) - # Importing each function instead of 'import *' allows us specify '__all__' in # quantize.py and also makes any newer additions to quantize.py invisible via # fp8.py so that we don't reinforce importing internal TE functions. diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index f587ca9946..1689e35a44 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Functions for CUDA Graphs support in FP8""" + from collections.abc import Iterable import contextlib import gc diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index 5884188b7e..048d514aa5 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """NVFuser functions and JIT utilities""" + import os from functools import wraps from typing import Callable, Optional, Tuple diff --git a/transformer_engine/pytorch/module/__init__.py b/transformer_engine/pytorch/module/__init__.py index 3cf15efc11..80031bd6f5 100644 --- a/transformer_engine/pytorch/module/__init__.py +++ b/transformer_engine/pytorch/module/__init__.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Module level PyTorch APIs""" + from .layernorm_linear import LayerNormLinear from .linear import Linear from .grouped_linear import GroupedLinear diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index ad5cd04341..960dd3d8ab 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Base modules and utilities for TransformerEngine PyTorch API""" + import io import math import os diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py index 8ac49c9bae..e6ad5395da 100644 --- a/transformer_engine/pytorch/module/fp8_padding.py +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -13,7 +13,6 @@ from ..quantization import FP8GlobalStateManager, get_align_size_for_quantization from ..jit import no_torch_dynamo - __all__ = ["Fp8Padding"] @@ -30,7 +29,7 @@ def forward( # Reduce number of arguments to autograd function in order # to reduce CPU overhead due to pytorch arg checking. - (m_splits, padded_m_splits, is_grad_enabled) = non_tensor_args + m_splits, padded_m_splits, is_grad_enabled = non_tensor_args # Make sure input dimensions are compatible in_features = inp.shape[-1] diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py index c5d396837f..6df36a41dc 100644 --- a/transformer_engine/pytorch/module/fp8_unpadding.py +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -13,7 +13,6 @@ from ..quantization import FP8GlobalStateManager, get_align_size_for_quantization from ..jit import no_torch_dynamo - __all__ = ["Fp8Unpadding"] @@ -30,7 +29,7 @@ def forward( # Reduce number of arguments to autograd function in order # to reduce CPU overhead due to pytorch arg checking. - (m_splits, padded_m_splits, is_grad_enabled) = non_tensor_args + m_splits, padded_m_splits, is_grad_enabled = non_tensor_args in_features = inp.shape[-1] diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index d0a5618afb..58aba05405 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """GroupedLinear API""" + from typing import Union, Optional, Callable, Tuple, List from itertools import chain import warnings diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index d4f0a78ba2..9f54d020d4 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """LayerNorm API""" + import warnings from typing import Iterable, Optional, Union diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 13b94f2327..8faaf8167b 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """LayerNormLinear API""" + import os import warnings from typing import Callable, Dict, Optional, Tuple, Union, List @@ -873,7 +874,7 @@ def backward( ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not ctx.fp8) else None), + "bias": bias if (grad_bias is None and not ctx.fp8) else None, "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index ddb33f303c..da528e3e7d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """LayerNormMLP API""" + import os import warnings from typing import Callable, Optional, Tuple, Union, List @@ -2296,13 +2297,16 @@ def _clamped_swiglu(x, limit, alpha): "gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), "geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), - "qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh") - * x.chunk(2, -1)[1], + "qgeglu": ( + lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh") + * x.chunk(2, -1)[1] + ), "relu": torch.nn.functional.relu, "reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], "srelu": lambda x: torch.nn.functional.relu(x) ** 2, - "sreglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) ** 2 - * x.chunk(2, -1)[1], + "sreglu": ( + lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) ** 2 * x.chunk(2, -1)[1] + ), "silu": torch.nn.functional.silu, "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], "clamped_swiglu": lambda x: _clamped_swiglu( diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f3220d5860..7bf02fed7a 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Linear API""" + from typing import Callable, Dict, Optional, Tuple, Union, List from functools import reduce from operator import mul as multiply_op @@ -858,7 +859,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not ctx.fp8) else None), + "bias": bias if (grad_bias is None and not ctx.fp8) else None, "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index ace4be31de..55d9d83c19 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """RMSNorm API""" + import warnings from typing import Iterable, Optional, Union diff --git a/transformer_engine/pytorch/numerics_debug.py b/transformer_engine/pytorch/numerics_debug.py index 45d9aacde3..4042ca6211 100644 --- a/transformer_engine/pytorch/numerics_debug.py +++ b/transformer_engine/pytorch/numerics_debug.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Utilities for debugging numerical issues with FP8""" + from typing import Tuple import torch from transformer_engine.common import recipe diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 2714d718fe..214a0dba42 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -1008,7 +1008,7 @@ def op_backward( ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]: # Saved tensors from forward pass - (x_local, w) = ctx.saved_tensors + x_local, w = ctx.saved_tensors # Megatron-LM wgrad fusion # Note: Get grad tensor from param so we can accumulate diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 5e7339db85..b175fe916c 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -48,7 +48,7 @@ def fuser_backward( linear_op_ctx = basic_op_ctxs[0] # Saved tensors from forward pass - (x_local, w) = linear_op_ctx.saved_tensors + x_local, w = linear_op_ctx.saved_tensors # Megatron-LM wgrad fusion # Note: Get grad tensor from param so we can accumulate diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py index f7f59e65c9..4de93bd4db 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py @@ -49,7 +49,7 @@ def fuser_backward( scale_op = self.basic_ops[1] # Saved tensors from forward pass - (x_local, w) = linear_op_ctx.saved_tensors + x_local, w = linear_op_ctx.saved_tensors # Megatron-LM wgrad fusion # Note: Get grad tensor from param so we can accumulate diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 4943ffb1bd..8bd29e54fe 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -509,7 +509,7 @@ def fuser_backward( bias_op = self.basic_ops[idx] # Saved tensors from forward pass - (x_local, w) = linear_op_ctx.saved_tensors + x_local, w = linear_op_ctx.saved_tensors # Megatron-LM wgrad fusion # Note: Get grad tensor from param so we can accumulate diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py index d5829b0c50..077cb79448 100644 --- a/transformer_engine/pytorch/ops/linear.py +++ b/transformer_engine/pytorch/ops/linear.py @@ -109,7 +109,7 @@ def __init__( "size": out_features, "device": device, "dtype": dtype, - "tensor_parallel": (tensor_parallel_mode is not None), + "tensor_parallel": tensor_parallel_mode is not None, "tensor_parallel_group": tensor_parallel_group, } if tensor_parallel_mode == "row": diff --git a/transformer_engine/pytorch/optimizers/__init__.py b/transformer_engine/pytorch/optimizers/__init__.py index 792eab094a..8ffb016a87 100644 --- a/transformer_engine/pytorch/optimizers/__init__.py +++ b/transformer_engine/pytorch/optimizers/__init__.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Fused optimizers and multi-tensor kernels.""" + from transformer_engine_torch import ( multi_tensor_scale, multi_tensor_l2norm, diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 1995655c33..c6549c1846 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Fused Adam optimizer.""" + from __future__ import annotations from collections.abc import Iterable from copy import deepcopy diff --git a/transformer_engine/pytorch/optimizers/fused_sgd.py b/transformer_engine/pytorch/optimizers/fused_sgd.py index 08e465e951..af76ae6b97 100644 --- a/transformer_engine/pytorch/optimizers/fused_sgd.py +++ b/transformer_engine/pytorch/optimizers/fused_sgd.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Fused SGD optimizer.""" + from __future__ import annotations from collections.abc import Iterable from typing import Any, Optional diff --git a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py index a5cbd27337..a21fe1d75e 100644 --- a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py +++ b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Multi-tensor apply entry.""" + from torch.distributed._tensor import DTensor diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 0c16b35e11..9f6f922cf9 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """MoE Permutation API""" + import warnings from typing import Optional, Tuple import torch diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index eba547afb0..81839ce9ba 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Quantization utilities for TransformerEngine""" + from __future__ import annotations import abc @@ -31,7 +32,6 @@ from .utils import get_device_compute_capability from .jit import jit_fuser - __all__ = [ "autocast", "quantized_model_init", diff --git a/transformer_engine/pytorch/router.py b/transformer_engine/pytorch/router.py index 52d1d9d6ca..7544694bb5 100644 --- a/transformer_engine/pytorch/router.py +++ b/transformer_engine/pytorch/router.py @@ -4,6 +4,7 @@ """ Fused functions used in the MoE router """ + import torch import transformer_engine_torch as tex diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index 99f6a99efa..1db179b904 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -53,7 +53,6 @@ test_requirements, ) - os.environ["NVTE_PROJECT_BUILDING"] = "1" CMakeBuildExtension = get_build_ext(BuildExtension, True) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 03c16ebbed..ac460201b1 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Tensor class with FP8 data quantized with NxN tiles""" + from __future__ import annotations from typing import Optional, Tuple, Iterable, Union diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 43cbdcf9e6..8b4ca38705 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Tensor class with FP8 data""" + from __future__ import annotations from typing import Any, Optional, Tuple, Iterable, Union import warnings @@ -877,7 +878,7 @@ def fsdp_post_all_gather( """ (data,) = all_gather_outputs - (fp8_scale_inv, rowwise_usage, columnwise_usage, fp8_dtype) = metadata + fp8_scale_inv, rowwise_usage, columnwise_usage, fp8_dtype = metadata orig_shape = data.size() # Quantizer has only columnwise usage set for backward pass # In Blackwell+ architectures, transpose is not needed at all, diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 88081f51bf..8389b3645b 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Tensor class with MXFP8 data""" + from __future__ import annotations from collections.abc import Iterable import math diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 8b707af3b2..f568b3591a 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Tensor class with NVFP4 data""" + from __future__ import annotations from collections.abc import Iterable import math diff --git a/transformer_engine/pytorch/torch_version.py b/transformer_engine/pytorch/torch_version.py index 3e299af1fd..fec1576fd9 100644 --- a/transformer_engine/pytorch/torch_version.py +++ b/transformer_engine/pytorch/torch_version.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """PyTorch version utilities""" + from __future__ import annotations import functools import torch diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 9b9ccc5185..2078462ad3 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Transformer.""" + import os import warnings from contextlib import nullcontext @@ -36,7 +37,6 @@ from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.module.base import TransformerEngineBaseModule - warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 47af9fabe1..c748fb98d6 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Utility functions for Transformer Engine modules""" + from __future__ import annotations import functools import math @@ -16,7 +17,6 @@ from .torch_version import torch_version from ..debug.pytorch.debug_quantization import DebugQuantizedTensor - __all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"]