10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
100755 → 100644
@@ -14,7 +14,7 @@ ci:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v6.0.0
hooks:
- id: check-merge-conflict
- id: check-added-large-files
@@ -23,24 +23,24 @@ repos:
- id: trailing-whitespace
files: .*.(c|cc|cxx|cpp|cu|cuh|h|hpp|py)$

- repo: https://github.com/psf/black
rev: 24.4.2
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 25.12.0
hooks:
- id: black
name: Format python code
args: [--line-length=100, --preview, --enable-unstable-feature=string_processing]
types: [python]

- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.6
rev: v21.1.8
hooks:
- id: clang-format
entry: clang-format -i
args: ["-style=file"]
files: ^transformer_engine.*\.(c|cc|cxx|cpp|cu|cuh|h|hpp)$

- repo: https://github.com/netromdk/vermin
rev: c75aca72f4e85c6e47252139e8695f1c8b5f9ae3
rev: v1.8.0
hooks:
- id: vermin
args: ['-t=3.10', '--violations']
1 change: 1 addition & 0 deletions build_tools/jax.py
@@ -3,6 +3,7 @@
# See LICENSE for license information.

"""JAX related extensions."""

import os
from pathlib import Path
from packaging import version
1 change: 1 addition & 0 deletions build_tools/pytorch.py
@@ -3,6 +3,7 @@
# See LICENSE for license information.

"""PyTorch related extensions."""

import os
from pathlib import Path

1 change: 1 addition & 0 deletions build_tools/te_version.py
@@ -3,6 +3,7 @@
# See LICENSE for license information.

"""Transformer Engine version string."""

import os
from pathlib import Path
import subprocess
1 change: 1 addition & 0 deletions examples/jax/collective_gemm/conftest.py
@@ -3,6 +3,7 @@
# See LICENSE for license information.

"""config for collective_gemm tests"""

import pytest


1 change: 1 addition & 0 deletions examples/jax/collective_gemm/test_dense_grad.py
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
"""Collective Dense Gradient test on multi-GPU with tensor parallelism"""

import argparse
import unittest
import os
1 change: 1 addition & 0 deletions examples/jax/collective_gemm/test_gemm.py
@@ -12,6 +12,7 @@
Example:
python test_gemm.py --coordinator-address localhost:1234 --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3
"""

import unittest
import os
from functools import partial
1 change: 1 addition & 0 deletions examples/jax/collective_gemm/test_layernorm_mlp_grad.py
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
"""Collective Dense Gradient test on multi-GPU with tensor parallelism"""

import argparse
import unittest
import os
1 change: 1 addition & 0 deletions examples/jax/encoder/common.py
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
"""Shared functions for the encoder tests"""

from functools import lru_cache
import os
import pathlib
1 change: 1 addition & 0 deletions examples/jax/encoder/conftest.py
@@ -3,6 +3,7 @@
# See LICENSE for license information.

"""config for test_multiprocessing_encoder"""

import pytest


1 change: 1 addition & 0 deletions examples/jax/encoder/test_model_parallel_encoder.py
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
"""Encoder training on multi-GPU with tesnor parallelism"""

import argparse
import unittest
from functools import partial
1 change: 1 addition & 0 deletions examples/jax/encoder/test_multigpu_encoder.py
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
"""Encoder training on multi-GPU with data parallelism"""

import argparse
import unittest
from functools import partial
5 changes: 3 additions & 2 deletions examples/jax/encoder/test_multiprocessing_encoder.py
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
"""Encoder training with multi-GPU, multiprocessing, and tensor parallelism"""

import argparse
import os
import unittest
@@ -104,7 +105,7 @@ def shard_array_wrapper(dataset, batch_size, mesh, pspec, enable_partition=False
inputs = jnp.asarray(dataset)
total_input_size = len(inputs)

(dp_size, tp_size) = mesh.device_ids.shape
dp_size, tp_size = mesh.device_ids.shape
valid_input_size, global_batch_size, num_steps, tp_group_id = valid_shard_size(
total_input_size, batch_size, dp_size, tp_size
)
@@ -156,7 +157,7 @@ def train_epoch(
"""Train for a single epoch."""

total_batch_size = len(train_ds["sentence"])
(dp_size, tp_size) = mesh.device_ids.shape
dp_size, tp_size = mesh.device_ids.shape
valid_size, _, num_steps, tp_group_id = valid_shard_size(
total_batch_size, batch_size, dp_size, tp_size
)
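The two hunks above only drop redundant parentheses around the unpacking targets; behavior is unchanged. A minimal sketch of the equivalence, using a NumPy array as a stand-in for mesh.device_ids (the 4x2 grid is illustrative, not from the test):

import numpy as np

# Stand-in for mesh.device_ids: a (dp, tp) grid of device ids.
device_ids = np.arange(8).reshape(4, 2)

# Parenthesized and bare tuple targets unpack identically.
(dp_size, tp_size) = device_ids.shape
dp_size, tp_size = device_ids.shape
assert (dp_size, tp_size) == (4, 2)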
1 change: 1 addition & 0 deletions examples/jax/encoder/test_single_gpu_encoder.py
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
"""Encoder training on single GPU"""

import argparse
import unittest
from functools import partial
1 change: 1 addition & 0 deletions examples/jax/mnist/test_single_gpu_mnist.py
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
"""MNIST training on single GPU"""

import argparse
import unittest
from functools import partial
1 change: 1 addition & 0 deletions tests/jax/conftest.py
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
"""conftest for tests/jax"""

import os
import jax
import pytest
1 change: 0 additions & 1 deletion tests/jax/test_distributed_dense.py
@@ -18,7 +18,6 @@
from transformer_engine.jax import autocast
from transformer_engine.jax.dense import dense


DTYPES = [jnp.bfloat16]

GEMM_INPUT_SHAPES = [[256, 128, 256]] # [batch, seq_len, hidden_in]
1 change: 0 additions & 1 deletion tests/jax/test_distributed_fused_attn.py
@@ -27,7 +27,6 @@
ReorderStrategy,
)


DTYPES = [jnp.bfloat16]

DISTRIBUTED_SELF_ATTN_DATA_SHAPES = {
1 change: 0 additions & 1 deletion tests/jax/test_distributed_layernorm.py
@@ -20,7 +20,6 @@
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available


DTYPES = [jnp.bfloat16, jnp.float32]

NORM_INPUT_SHAPES = {
1 change: 0 additions & 1 deletion tests/jax/test_distributed_layernorm_mlp.py
@@ -45,7 +45,6 @@
)
from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability


is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
1 change: 1 addition & 0 deletions tests/jax/test_fused_attn.py
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
"""Tests for fused attention"""

from enum import Enum, auto
from dataclasses import dataclass, field
from functools import partial
3 changes: 2 additions & 1 deletion tests/jax/test_layer.py
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
"""Test transformer_engine.jax.flax.TransformerLayer"""

import os
from functools import partial
from typing import Dict, Tuple, Optional
@@ -209,7 +210,7 @@ def enable_fused_attn():
},
# attrs20
{
_KEY_OF_MLP_ACTIVATIONS: (("relu", "relu")),
_KEY_OF_MLP_ACTIVATIONS: ("relu", "relu"),
},
# attrs21
{
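Note on the attrs20 hunk: the doubled parentheses never created a nested tuple, so removing them is purely cosmetic. A one-line illustrative check:

assert (("relu", "relu")) == ("relu", "relu")  # both expressions are the same 2-tuple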
1 change: 0 additions & 1 deletion tests/jax/test_multi_process_distributed_grouped_gemm.py
@@ -16,7 +16,6 @@

from utils import assert_allclose, dtype_tols


N_GROUP = 8
MESH_AXIS_NAME = "fsdp"

1 change: 0 additions & 1 deletion tests/jax/test_permutation.py
@@ -18,7 +18,6 @@
)
from utils import assert_allclose, pytest_parametrize_wrapper


# =============================================================================
# Test parameter definitions with L0 (fast) and L2 (comprehensive) levels
# =============================================================================
1 change: 1 addition & 0 deletions tests/jax/test_softmax.py
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
"""Tests for the softmax primitives"""

from contextlib import nullcontext
from dataclasses import dataclass
from functools import wraps
2 changes: 1 addition & 1 deletion tests/pytorch/attention/test_attention.py
@@ -2629,7 +2629,7 @@ def forward(
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
with torch.cuda.nvtx.range("_DPA"):
saved_tensors = ctx.saved_tensors
(q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_saved(
q, k, v, inp_fp8, qkv_weight_fp8, out = restore_from_saved(
ctx.tensor_objects, saved_tensors
)

1 change: 1 addition & 0 deletions tests/pytorch/attention/test_cp_utils.py
@@ -3,6 +3,7 @@
# See LICENSE for license information.

"""Unit tests for context parallel utils."""

import torch
import unittest
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
18 changes: 6 additions & 12 deletions tests/pytorch/debug/test_numerics.py
@@ -275,17 +275,15 @@ def _get_tensors():
"""


DISABLE_FP8_CONFIG = Template(
"""disable_fp8_config:
DISABLE_FP8_CONFIG = Template("""disable_fp8_config:
enabled: True
layers:
layer_types: [linear]
transformer_engine:
DisableFP8GEMM:
enabled: True
gemms: [$gemms]
"""
)
""")


@create_config_file
@@ -407,8 +405,7 @@ def test_per_tensor_scaling(
)


PER_TENSOR_SCALING_CONFIG = Template(
"""per_tensor_scaling_config:
PER_TENSOR_SCALING_CONFIG = Template("""per_tensor_scaling_config:
enabled: True
layers:
layer_types: [linear]
@@ -417,8 +414,7 @@ test_per_tensor_scaling(
enabled: True
gemms_struct:
$gemms
"""
)
""")


def _prepare_per_tensor_scaling_config(
@@ -670,8 +666,7 @@ def test_fake_quant_fp8(
)


FAKE_QUANT_CONFIG = Template(
"""fake_quant_config:
FAKE_QUANT_CONFIG = Template("""fake_quant_config:
enabled: True
layers:
layer_types: [linear]
@@ -680,8 +675,7 @@
enabled: True
gemms_struct:
$gemms
"""
)
""")


def fake_quant_fp8_create_config(
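The test_numerics.py hunks only reflow how the config Templates are written; the string contents are unchanged. A minimal sketch of how such a template is filled in (the substitute() call, the indentation, and the gemms value below are illustrative, not taken from the test code):

from string import Template

# Same shape as the debug configs above; layout here is illustrative.
CONFIG = Template("""disable_fp8_config:
  enabled: True
  transformer_engine:
    DisableFP8GEMM:
      enabled: True
      gemms: [$gemms]
""")

# $gemms is the only placeholder; substitute() fills it in.
config_yaml = CONFIG.substitute(gemms="fprop, dgrad")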
12 changes: 8 additions & 4 deletions tests/pytorch/debug/test_sanity.py
@@ -18,7 +18,8 @@

configs = {
"": "",
"log": """log:
"log": (
"""log:
layers:
layer_types: [linear]
enabled:
@@ -36,8 +37,10 @@
stats: [underflows, overflows]
start_step : 0
end_step: 1
""",
"fake_quant": """
"""
),
"fake_quant": (
"""
fake_quant_config:
enabled: True
layers:
@@ -47,7 +50,8 @@
enabled: True
gemms: [fprop, dgrad, wgrad]
quant_format: FP8E5M2
""",
"""
),
}


1 change: 0 additions & 1 deletion tests/pytorch/distributed/run_numerics_exact.py
@@ -26,7 +26,6 @@
from transformer_engine.pytorch.custom_recipes import utils
from run_layer_with_overlap import _compare_tensors


BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE = 128, 256, 128
WORLD_RANK, WORLD_SIZE = None, None
NCCL_WORLD = None
1 change: 0 additions & 1 deletion tests/pytorch/distributed/test_fusible_ops.py
@@ -36,7 +36,6 @@
sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, make_recipe, quantization_tols


# Check what quantization schemes are supported
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
@@ -34,7 +34,6 @@
Float8Tensor,
)


# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
1 change: 0 additions & 1 deletion tests/pytorch/distributed/test_torch_fsdp2.py
@@ -10,7 +10,6 @@

import torch


fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
NUM_PROCS: int = torch.cuda.device_count()