From fbb0702939387d7766dd0b7359511a38eef18d89 Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Mon, 5 Jan 2026 11:10:11 -0800 Subject: [PATCH 1/8] Update THD sink attention logic for newer cudnn versions THD Sink attention is supported in 9.18.0 Signed-off-by: Chen Cui --- .../pytorch/attention/dot_product_attention/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index bf19388d7e..3749d40e37 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -716,10 +716,11 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ) use_unfused_attention = False if qkv_format == "thd": - logger.debug( - "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type - ) - use_fused_attention = False + if cudnn_version < (9, 18, 0): + logger.debug( + "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type + ) + use_fused_attention = False logger.debug( "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd", softmax_type, From 01848c0687b89ec87d586d5e1070772d7cf68ea8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Jan 2026 19:11:33 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/attention/dot_product_attention/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 3749d40e37..fce04bfa2d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -718,7 +718,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if qkv_format == "thd": if cudnn_version < (9, 18, 0): logger.debug( - "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type + "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", + softmax_type, ) use_fused_attention = False logger.debug( From ab13ba0aff7ab94669cd8939c35411a094f02e4c Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Tue, 6 Jan 2026 15:33:24 -0800 Subject: [PATCH 3/8] update thd sink attention logic for cp>1 Signed-off-by: Chen Cui --- .../dot_product_attention/context_parallel.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 75b360e485..a12ac9ae1a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4026,28 +4026,29 @@ def attn_forward_func_with_cp( assert not sliding_window_attn or cp_comm_type in [ "a2a", "all_gather", - ], "Context parallelism does not support sliding window attention with {cp_comm_type=}!" + ], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!" enable_mla = k.shape[-1] != v.shape[-1] assert not enable_mla or cp_comm_type in [ "p2p", "a2a+p2p", - ], "Context parallelism does not support MLA with {cp_comm_type=}!" + ], f"Context parallelism does not support MLA with {cp_comm_type=}!" if fp8 and fp8_meta is not None: if fp8_meta["recipe"].fp8_dpa: assert ( softmax_type == "vanilla" - ), "Context parallelism does not support {softmax_type=} with FP8 attention!" + ), f"Context parallelism does not support {softmax_type=} with FP8 attention!" assert ( softmax_type == "vanilla" or use_fused_attention - ), "Context parallelism only supports {softmax_type=} with FusedAttention backend!" + ), f"Context parallelism only supports {softmax_type=} with FusedAttention backend!" assert ( softmax_type == "vanilla" or cp_comm_type == "a2a" - ), "Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!" - assert ( - softmax_type == "vanilla" or qkv_format != "thd" - ), "Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!" + ), f"Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!" + if get_cudnn_version() < (9, 18, 0): + assert ( + softmax_type == "vanilla" or qkv_format != "thd" + ), f"Before cuDNN 9.18.0, context parallelism does not support {softmax_type=} with qkv_format = 'thd'!" args = [ is_training, From c2c4341ef27e4d1ee3c75fa008a9ee071fa6ac9a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Jan 2026 23:55:03 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../attention/dot_product_attention/context_parallel.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index a12ac9ae1a..a5931188dc 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4046,9 +4046,10 @@ def attn_forward_func_with_cp( softmax_type == "vanilla" or cp_comm_type == "a2a" ), f"Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!" if get_cudnn_version() < (9, 18, 0): - assert ( - softmax_type == "vanilla" or qkv_format != "thd" - ), f"Before cuDNN 9.18.0, context parallelism does not support {softmax_type=} with qkv_format = 'thd'!" + assert softmax_type == "vanilla" or qkv_format != "thd", ( + f"Before cuDNN 9.18.0, context parallelism does not support {softmax_type=} with" + " qkv_format = 'thd'!" + ) args = [ is_training, From 392a0334b84679dcb53ac0966460eef9e3f75392 Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Wed, 7 Jan 2026 14:54:44 -0800 Subject: [PATCH 5/8] add unit test for thd + sink attention Signed-off-by: Chen Cui --- tests/pytorch/attention/test_attention.py | 11 +++++++++++ .../pytorch/attention/dot_product_attention/utils.py | 5 ----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index eb7905bcd5..a81ee34ab4 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -419,6 +419,17 @@ def test_dpa_softmax(dtype, model_configs, model): ) +@pytest.mark.skipif(get_cudnn_version() < (9, 18, 0), reason="cuDNN 9.18.0+ is required.") +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("model_configs", [model_configs_softmax]) +@pytest.mark.parametrize("model", model_configs_softmax.keys()) +def test_dpa_softmax_thd(dtype, model_configs, model): + """Test DotProductAttention module with different softmax types""" + test_dot_product_attention( + dtype, model_configs, model, True, True, "thd_thd_thd", False, False + ) + + model_configs_mla = { # test: ModelConfig(b, sq, hq, dqk) "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index fce04bfa2d..8f8f4f4621 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -722,11 +722,6 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt softmax_type, ) use_fused_attention = False - logger.debug( - "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd", - softmax_type, - ) - use_unfused_attention = False if context_parallel: logger.debug( "Disabling UnfusedDotProductAttention for context parallelism with softmax_type" From a75af446b7f2e59be2cbd0d4a712ef84e5b6d067 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Jan 2026 22:55:28 +0000 Subject: [PATCH 6/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index a81ee34ab4..de6f983e6f 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -425,9 +425,7 @@ def test_dpa_softmax(dtype, model_configs, model): @pytest.mark.parametrize("model", model_configs_softmax.keys()) def test_dpa_softmax_thd(dtype, model_configs, model): """Test DotProductAttention module with different softmax types""" - test_dot_product_attention( - dtype, model_configs, model, True, True, "thd_thd_thd", False, False - ) + test_dot_product_attention(dtype, model_configs, model, True, True, "thd_thd_thd", False, False) model_configs_mla = { From 200632f6ab17b987832c93d9fb61f2a72f7f3366 Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Thu, 8 Jan 2026 15:42:15 -0800 Subject: [PATCH 7/8] address comments Signed-off-by: Chen Cui --- .../pytorch/attention/dot_product_attention/utils.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 8f8f4f4621..ac0d2bb400 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -718,17 +718,11 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if qkv_format == "thd": if cudnn_version < (9, 18, 0): logger.debug( - "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", + "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN version < 9.18", softmax_type, ) use_fused_attention = False if context_parallel: - logger.debug( - "Disabling UnfusedDotProductAttention for context parallelism with softmax_type" - " = %s", - softmax_type, - ) - use_unfused_attention = False if cp_comm_type != "a2a": logger.debug( "Disabling FusedAttention for context parallelism with softmax_type = %s and" From 7f4333abc3ee23e330222703cbc0d127f8afdaf2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jan 2026 23:43:10 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/attention/dot_product_attention/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index ac0d2bb400..097a3b60e5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -718,7 +718,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if qkv_format == "thd": if cudnn_version < (9, 18, 0): logger.debug( - "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN version < 9.18", + "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN" + " version < 9.18", softmax_type, ) use_fused_attention = False