diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index d546118ffb5..47e0fbddd78 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -190,13 +190,13 @@ def test(): return available_backends, fused_attention_backend backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} - with logging_context(): - for i in range(3): - os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) - _attention_backends["backend_selection_requires_update"] = True - available_backends, fused_attention_backend = test() - if fused_attention_backend == FusedAttnBackend[backends[i]]: - fused_attn_backends.append(fused_attention_backend) + # with logging_context(): + for i in range(3): + os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) + _attention_backends["backend_selection_requires_update"] = True + available_backends, fused_attention_backend = test() + if fused_attention_backend == FusedAttnBackend[backends[i]]: + fused_attn_backends.append(fused_attention_backend) return available_backends, fused_attn_backends diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 5d3e1d60979..7f77ddadf5b 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -205,9 +205,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (window_size_right == -1 || window_size_right == 0)) || // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} (cudnn_runtime_version >= 90200 && - ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + ((window_size_left == -1 && window_size_right == -1 && + attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || + ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && max_seqlen_q == max_seqlen_kv)) && max_seqlen_q <= max_seqlen_kv && dropout == 0.0 && @@ -217,7 +219,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} (cudnn_runtime_version >= 90600 && ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && + ((window_size_left >= 0 || window_size_left == -1) && + (window_size_right >= 0 || window_size_right == -1) && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && @@ -273,7 +276,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream) { + bool bottom_right_diagonal, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; @@ -324,9 +328,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, #if (CUDNN_VERSION 
>= 8900) fused_attn_arbitrary_seqlen_fwd_qkvpacked( b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O, - Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, - stream, handle); + attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_QKV, + input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, + input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -334,7 +338,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) fused_attn_fp8_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_QKV, input_output_S, output_O, + bias_type, attn_mask_type, window_size_left, window_size_right, + bottom_right_diagonal, input_QKV, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else @@ -345,15 +350,13 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } } // NVTE fused attention BWD with packed QKV -void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, - const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, size_t max_seqlen, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - bool deterministic, NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_bwd_qkvpacked( + const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, + NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, + const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, size_t max_seqlen, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; @@ -414,9 +417,9 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con } fused_attn_arbitrary_seqlen_bwd_qkvpacked( b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO, - input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded, - input_rng_state, wkspace, stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, deterministic, input_QKV, + input_O, input_dO, input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, + input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " @@ -429,9 +432,10 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); const Tensor 
*input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); fused_attn_fp8_bwd_qkvpacked(b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, input_QKV, input_O, input_dO, input_M, input_ZInv, - input_S, input_output_dP, output_dQKV, input_cu_seqlens, - input_rng_state, wkspace, stream, handle); + attn_mask_type, window_size_left, window_size_right, + bottom_right_diagonal, input_QKV, input_O, input_dO, input_M, + input_ZInv, input_S, input_output_dP, output_dQKV, + input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif @@ -440,16 +444,14 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con } } // NVTE fused attention FWD with packed KV -void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, - NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd_kvpacked( + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); @@ -507,10 +509,10 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const #if (CUDNN_VERSION >= 8903) fused_attn_arbitrary_seqlen_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, - input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, - handle); + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, + bottom_right_diagonal, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, + input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. 
\n"); @@ -519,8 +521,9 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const #if (CUDNN_VERSION >= 8900) fused_attn_fp8_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_Q, input_KV, input_output_S, output_O, Aux_CTX_Tensors, - input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, + input_Q, input_KV, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif @@ -536,8 +539,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); @@ -607,10 +610,11 @@ void nvte_fused_attn_bwd_kvpacked( } fused_attn_arbitrary_seqlen_bwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, - input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); + bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, input_Q, input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, + output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, + handle); #else const char *err_msg = "cuDNN 8.9.3 is required for BF16/FP16 fused attention " @@ -623,9 +627,10 @@ void nvte_fused_attn_bwd_kvpacked( const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); fused_attn_fp8_bwd_kvpacked(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, - output_dKV, input_cu_seqlens_q, input_cu_seqlens_kv, + qkv_layout, bias_type, attn_mask_type, window_size_left, + window_size_right, bottom_right_diagonal, input_Q, input_KV, + input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, + output_dQ, output_dKV, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. 
\n"); @@ -643,8 +648,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); @@ -696,9 +701,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, - input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, - wkspace, stream, handle); + bottom_right_diagonal, input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, + input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -706,7 +711,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, + dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, + window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else @@ -726,8 +732,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream) { + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); @@ -791,10 +797,11 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, - input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, - output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO, + input_Bias, 
output_S, output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, + wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " @@ -807,10 +814,11 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, - output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, wkspace, stream, handle); + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, + bottom_right_diagonal, input_Q, input_K, input_V, input_O, input_dO, input_M, + input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, + input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, + handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 20467af663f..97c8d99d383 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -53,23 +53,29 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, - void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, - void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, - void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, - size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ, void *devPtrK, + void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO, + void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, + cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); - bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); - if (is_bottom_right && s_q == s_kv) { + bool is_causal_bottom_right = + ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); + if (is_causal_bottom_right && s_q == s_kv) { is_causal = true; - 
is_bottom_right = false; + is_causal_bottom_right = false; + } + if (is_causal || is_causal_bottom_right) { + window_size_right = 0; } + bottom_right_diagonal = is_causal_bottom_right && !is_causal && is_causal_bottom_right; bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); @@ -107,6 +113,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( mask_type, window_size_left, window_size_right, + bottom_right_diagonal, true, tensorType, tensorType}; @@ -222,12 +229,17 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options = fe::graph::SDPA_attributes() .set_name("flash_attention") .set_is_inference(false) - .set_causal_mask(is_causal) - .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); + fe::DiagonalAlignment_t const &diagonal_alignment = + bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_options.set_diagonal_alignment(diagonal_alignment); if (cudnn_runtime_version >= 90200 && window_size_left != -1) { - sdpa_options.set_sliding_window_length(window_size_left + 1); + sdpa_options.set_left_bound(window_size_left + 1); + } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_options.set_right_bound(window_size_right); } sdpa_options.set_alibi_mask(is_alibi); @@ -432,10 +444,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose, - void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, - void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, - void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, void *devPtrQ, + void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, + void *devPtrBias, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, + void *devPtrdBias, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { @@ -445,12 +457,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl( bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); - bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); - if (is_bottom_right && s_q == s_kv) { + bool is_causal_bottom_right = + ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); + if (is_causal_bottom_right && s_q == s_kv) { is_causal = true; - is_bottom_right = false; + is_causal_bottom_right = false; + } + if (is_causal || is_causal_bottom_right) { + window_size_right = 0; } + bottom_right_diagonal = is_causal_bottom_right && !is_causal && is_causal_bottom_right; bool 
is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); @@ -492,6 +509,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( mask_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, tensorType, tensorType}; @@ -657,8 +675,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl( fe::graph::SDPA_backward_attributes sdpa_backward_options; sdpa_backward_options = fe::graph::SDPA_backward_attributes() .set_name("flash_attention_backward") - .set_causal_mask(is_causal) - .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); if (is_ragged && cudnn_runtime_version >= 90600) { @@ -666,8 +682,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl( sdpa_backward_options.set_max_total_seq_len_kv(s_kv); } + fe::DiagonalAlignment_t const &diagonal_alignment = + bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); if (cudnn_runtime_version >= 90200 && window_size_left != -1) { - sdpa_backward_options.set_sliding_window_length(window_size_left + 1); + sdpa_backward_options.set_left_bound(window_size_left + 1); + } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_backward_options.set_right_bound(window_size_right); } if (cudnn_runtime_version >= 90000) { @@ -890,9 +913,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_QKV, + const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_QKV->data.dtype; @@ -990,9 +1014,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, max_batch_size, max_tokens, max_tokens, bias_b, bias_h, is_training, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, - devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, + qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, + devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { @@ -1014,10 +1038,10 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, float attn_scale, 
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, - const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + bool bottom_right_diagonal, bool deterministic, const Tensor *input_QKV, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, + Tensor *output_dBias, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_QKV->data.dtype; @@ -1074,11 +1098,11 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, devPtrK, - devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, - devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, - devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, + devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1099,11 +1123,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { + int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, + const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1206,10 +1230,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, 
attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, - devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, + bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1231,9 +1256,9 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, + Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { @@ -1297,11 +1322,11 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, - devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, - devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, + devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1323,11 +1348,12 @@ void fused_attn_arbitrary_seqlen_fwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, 
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1420,10 +1446,11 @@ void fused_attn_arbitrary_seqlen_fwd( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, - devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, + bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1445,10 +1472,11 @@ void fused_attn_arbitrary_seqlen_bwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, + Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1499,11 +1527,11 @@ void fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, - devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, 
devPtrdV, - devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, - devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, + devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 687928d0806..0924cd96d54 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -22,38 +22,39 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_QKV, + const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, - const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + bool bottom_right_diagonal, bool deterministic, const Tensor *input_QKV, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, + Tensor *output_dBias, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor 
*cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); + int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, + const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, + Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); @@ -63,21 +64,23 @@ void fused_attn_arbitrary_seqlen_fwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - const 
Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, + Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 0044a94b2fe..d82619b85e4 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1652,12 +1652,14 @@ void fused_attn_fp8_bwd_impl( void fused_attn_fp8_fwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, - void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, - void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type, - void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, + void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, + void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, + void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, + void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, + cudnn_frontend::DataType_t fwd_tensor_type, void* workspace, size_t* workspace_size, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); @@ -1689,8 +1691,9 @@ void fused_attn_fp8_fwd_impl_v1( layout, bias_type, mask_type, - 0, - 0, + window_size_left, + window_size_right, + bottom_right_diagonal, true, fwd_tensor_type, fwd_tensor_type}; @@ -1952,7 +1955,8 @@ void fused_attn_fp8_fwd_impl_v1( void fused_attn_fp8_bwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, @@ -1993,8 +1997,9 @@ void 
fused_attn_fp8_bwd_impl_v1( layout, bias_type, mask_type, - 0, - 0, + window_size_left, + window_size_right, + bottom_right_diagonal, false, fwd_tensor_type, bwd_tensor_type}; @@ -2349,14 +2354,13 @@ void fused_attn_fp8_bwd_impl_v1( #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with packed QKV -void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen, - size_t head_dim, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor* input_QKV, Tensor* input_output_S, Tensor* output_O, - NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens, - const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, - cudnnHandle_t handle) { +void fused_attn_fp8_fwd_qkvpacked( + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, const Tensor* input_QKV, Tensor* input_output_S, Tensor* output_O, + NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens, const Tensor* rng_state, + Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_QKV->data.dtype; void* devPtrQKV = input_QKV->data.dptr; @@ -2422,11 +2426,12 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, - devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, - devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, + devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, + devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, devPtrDropoutSeed, + devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout, @@ -2454,6 +2459,7 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma void fused_attn_fp8_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQKV, const Tensor* cu_seqlens, const Tensor* rng_state, Tensor* workspace, @@ -2514,13 +2520,14 @@ void 
fused_attn_fp8_bwd_qkvpacked( if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, - devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, - devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, devPtrcuSeqlens, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, + bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, + devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, + devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, + devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, devPtrcuSeqlens, devPtrDropoutSeed, + devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(dQKV_type), + workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, @@ -2547,15 +2554,14 @@ void fused_attn_fp8_bwd_qkvpacked( } } // fused attention FWD FP8 with packed KV -void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor* input_Q, - const Tensor* input_KV, Tensor* input_output_S, Tensor* output_O, - NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, - const Tensor* cu_seqlens_kv, const Tensor* rng_state, - Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { +void fused_attn_fp8_fwd_kvpacked( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + const Tensor* input_Q, const Tensor* input_KV, Tensor* input_output_S, Tensor* output_O, + NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, + const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_Q->data.dtype; void* devPtrQ = input_Q->data.dptr; @@ -2623,9 +2629,10 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, - 
devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, - devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, + devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, + devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { @@ -2657,6 +2664,7 @@ void fused_attn_fp8_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ, const Tensor* output_dKV, const Tensor* cu_seqlens_q, @@ -2720,13 +2728,14 @@ void fused_attn_fp8_bwd_kvpacked( if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, - devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, - devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, + bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, + devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, + devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, + devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, + devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(dQKV_type), + workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, @@ -2757,11 +2766,13 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor* input_Q, const Tensor* input_K, - const Tensor* input_V, Tensor* input_output_S, Tensor* output_O, - NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, - const Tensor* 
cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, - cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, + const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, + Tensor* input_output_S, Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, + const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, + const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, + cudnnHandle_t handle) { using namespace transformer_engine; void* devPtrQ = input_Q->data.dptr; void* devPtrK = input_K->data.dptr; @@ -2821,9 +2832,10 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, - devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, - devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, + devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, + devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { @@ -2854,14 +2866,15 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor* input_Q, - const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, - const Tensor* input_dO, const Tensor* input_M, const Tensor* input_ZInv, - const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ, - const Tensor* output_dK, const Tensor* output_dV, - const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, - const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, - cudnnHandle_t handle) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_K, + const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, + const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, + Tensor* input_output_dP, const Tensor* output_dQ, const Tensor* output_dK, + const Tensor* output_dV, const Tensor* cu_seqlens_q, + const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; void* devPtrQ = input_Q->data.dptr; void* devPtrK = input_K->data.dptr; @@ -2911,13 +2924,14 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, 
max_seqlen_kv, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, - devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, - devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, + bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, + devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, + devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, + devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, + devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(dQKV_type), + workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 3daf45d1623..0a0bbd11e7d 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -14,40 +14,40 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with packed QKV -void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen, - size_t head_dim, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); +void fused_attn_fp8_fwd_qkvpacked( + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, const Tensor *input_QKV, Tensor *input_output_S, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); // fused attention BWD FP8 with packed QKV void fused_attn_fp8_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQKV, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); // fused 
attention FWD FP8 with packed KV -void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, - const Tensor *input_KV, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); +void fused_attn_fp8_fwd_kvpacked( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + const Tensor *input_Q, const Tensor *input_KV, Tensor *input_output_S, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); // fused attention BWD FP8 with packed KV void fused_attn_fp8_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ, const Tensor *output_dKV, const Tensor *cu_seqlens_q, @@ -59,23 +59,26 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); + NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + Tensor *input_output_S, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, + cudnnHandle_t handle); // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_M, const Tensor *input_ZInv, - const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ, - const Tensor *output_dK, const Tensor *output_dV, - const Tensor 
*cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, + Tensor *input_output_dP, const Tensor *output_dQ, const Tensor *output_dK, + const Tensor *output_dV, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); #endif // end of CUDNN>=8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index ed498049c3c..5feb56fc95f 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -103,6 +103,7 @@ struct FADescriptor_v1 { NVTE_Mask_Type mask_type; std::int64_t window_size_left; std::int64_t window_size_right; + bool bottom_right_diagonal; bool deterministic; cudnn_frontend::DataType_t fwd_tensor_type; cudnn_frontend::DataType_t bwd_tensor_type; @@ -110,11 +111,13 @@ struct FADescriptor_v1 { bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, bias_b, bias_h, attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left, window_size_right, - deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < + bottom_right_diagonal, deterministic, bias_type, fwd_tensor_type, + bwd_tensor_type) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, - rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, - rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type); + rhs.mask_type, rhs.window_size_left, rhs.window_size_right, + rhs.bottom_right_diagonal, rhs.deterministic, rhs.bias_type, + rhs.fwd_tensor_type, rhs.bwd_tensor_type); } }; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index b9c8db1598f..eda98d85122 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -205,6 +205,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] attn_mask_type Attention mask type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ @@ -215,7 +216,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream); + bool bottom_right_diagonal, NVTETensor workspace, + cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. 
* @@ -259,19 +261,18 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * \param[in] attn_mask_type Attention mask type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, - const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, size_t max_seqlen, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - bool deterministic, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_bwd_qkvpacked( + const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, + NVTETensor dP, const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, + const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, size_t max_seqlen, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with packed KV input. * @@ -325,20 +326,19 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] attn_mask_type Attention mask type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. 
*/ -void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, - NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd_kvpacked( + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, + NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. * @@ -388,6 +388,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const * \param[in] attn_mask_type Attention mask type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. @@ -399,8 +400,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream); + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with separate Q, K and V. * @@ -458,6 +459,7 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] attn_mask_type Attention mask type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. 
*/ @@ -469,8 +471,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream); + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -525,6 +527,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[in] attn_mask_type Attention mask type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. @@ -538,8 +541,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream); + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + NVTETensor workspace, cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index dc857aa22c9..3145b6c0048 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -159,15 +159,15 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( &aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, query_workspace_tensor.data(), nullptr); + window_size_right, true, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), - nullptr); + bias_type, mask_type, window_size_left, window_size_right, true, + query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { nvte_fused_attn_fwd( q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), @@ -175,7 +175,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, window_size_left, - window_size_right, query_workspace_tensor.data(), nullptr); + window_size_right, true, query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported QKVLayout."); } @@ -260,7 +260,7 @@ static void FusedAttnForwardImpl( q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, workspace_tensor.data(), stream); + window_size_right, true, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; @@ -271,7 +271,8 @@ static void FusedAttnForwardImpl( &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); + bias_type, mask_type, window_size_left, window_size_right, true, workspace_tensor.data(), + stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; @@ -285,7 +286,7 @@ static void FusedAttnForwardImpl( q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, workspace_tensor.data(), stream); + window_size_left, window_size_right, true, workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } @@ -463,7 +464,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, + bias_type, mask_type, window_size_left, window_size_right, true, deterministic, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_bwd_kvpacked( @@ -474,7 +475,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, deterministic, query_workspace_tensor.data(), + window_size_left, window_size_right, true, deterministic, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), @@ -486,7 +487,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, deterministic, + window_size_left, window_size_right, true, deterministic,
query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported qkv_layout."); } @@ -543,7 +544,7 @@ static void FusedAttnBackwardImpl( &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, + bias_type, mask_type, window_size_left, window_size_right, true, deterministic, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; @@ -564,7 +565,7 @@ static void FusedAttnBackwardImpl( q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - deterministic, workspace_tensor.data(), stream); + true, deterministic, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; @@ -589,7 +590,7 @@ static void FusedAttnBackwardImpl( kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, deterministic, workspace_tensor.data(), stream); + window_size_right, true, deterministic, workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 9f08f67304c..ecccee2e87f 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -256,6 +256,9 @@ class AttentionParams: Attention bias shape, {`1hss`, `b1ss`, `bhss`}. core_attention_bias_requires_grad: bool, default = `True` Whether attention bias requires gradient. + bottom_right_diagonal: bool, default = `True` + Whether to align sliding window and ALiBi diagonal to the bottom right corner + of the softmax matrix. pad_between_seqs: bool, default = `False` Whether there is padding between sequences in a batch. This only applies to `qkv_format=thd`.
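To make the new flag concrete: for a square softmax matrix the two alignments coincide, but for cross-attention (seqlen_q != seqlen_kv) they differ by a seqlen_kv - seqlen_q offset on the query index. The snippet below is illustration only (not part of this change); it builds a sliding-window mask from the same membership rule quoted in the window_size docstrings, i.e. query i attends to keys j in [i + offset - window_size[0], i + offset + window_size[1]], where offset is seqlen_kv - seqlen_q for bottom-right alignment and 0 for top-left.

```python
# Illustration only -- not Transformer Engine code.
import torch

def sliding_window_mask(seqlen_q, seqlen_kv, window_size, bottom_right_diagonal=True):
    """Boolean mask, True = position is masked out."""
    left, right = window_size
    i = torch.arange(seqlen_q).unsqueeze(1)   # query positions, [seqlen_q, 1]
    j = torch.arange(seqlen_kv).unsqueeze(0)  # key positions,   [1, seqlen_kv]
    offset = seqlen_kv - seqlen_q if bottom_right_diagonal else 0
    lo = i + offset - (left if left >= 0 else seqlen_kv)    # left == -1: no left limit
    hi = i + offset + (right if right >= 0 else seqlen_kv)  # right == -1: no right limit
    return (j < lo) | (j > hi)

# Causal-style sliding window (2, 0) for cross-attention with seqlen_q=3, seqlen_kv=5:
print(sliding_window_mask(3, 5, (2, 0), bottom_right_diagonal=True).int())
print(sliding_window_mask(3, 5, (2, 0), bottom_right_diagonal=False).int())
```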
@@ -289,6 +292,7 @@ class AttentionParams: core_attention_bias_type: str = "no_bias" core_attention_bias_shape: str = "1hss" core_attention_bias_requires_grad: bool = True + bottom_right_diagonal: bool = True pad_between_seqs: bool = False attention_dropout: float = 0.0 context_parallel: bool = False @@ -303,7 +307,10 @@ class AttentionParams: "_alibi_slopes": None, "_max_seqlen_q": None, "_max_seqlen_kv": None, - "_bottom_right_alignment": True, + "_bias_dtype": None, + "_actual_seqlens_q": None, + "_actual_seqlens_kv": None, + "_bottom_right_diagonal": True, "_alibi_bias": None, "_alibi_slopes_require_update": False, "_alibi_bias_require_update": False, @@ -358,6 +365,7 @@ def get_attention_backend( core_attention_bias_type = attention_params.core_attention_bias_type core_attention_bias_shape = attention_params.core_attention_bias_shape core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad + bottom_right_diagonal = attention_params.bottom_right_diagonal pad_between_seqs = attention_params.pad_between_seqs attention_dropout = attention_params.attention_dropout context_parallel = attention_params.context_parallel @@ -685,12 +693,12 @@ def get_attention_backend( _use_flash_attn_3 = False # Filter: Sliding window attention - # backend | window_size | diagonal alignment + # backend | window_size (left, right) | diagonal alignment # --------------------------------------------------------------------------------- - # FlashAttention | (-1, -1) or (>=0, >=0) | bottom right - # FusedAttention | (-1, 0) or (>=0, 0) | top left - # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both; - # | | converts window_size to an 'arbitrary' mask + # FlashAttention | (-1 or >=0, -1 or >=0) | bottom right + # FusedAttention | (-1 or >=0, -1 or >=0) | top left and bottom right + # UnfusedDotProductAttention | (-1 or >=0, -1 or >=0) | top left and bottom right; + # | | converts window_size to an 'arbitrary' mask if window_size is None: window_size = check_set_window_size(attn_mask_type, window_size) else: @@ -701,10 +709,10 @@ def get_attention_backend( " for FP8" ) use_fused_attention = False - elif window_size[1] != 0 or attention_dropout != 0.0: + elif attention_dropout != 0.0: logger.debug( "Disabling FusedAttention as it only supports sliding window attention " - "with (left, 0) and no dropout" + "with no dropout" ) use_fused_attention = False elif max_seqlen_q > max_seqlen_kv: @@ -716,9 +724,11 @@ def get_attention_backend( if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): if _use_flash_attn_3: logger.debug( - "Disabling FlashAttention 3 as it does not support sliding window attention" + "Disabling FlashAttention 3 as it only supports sliding window with bottom" + " right diagonal alignment for cross-attention" ) _use_flash_attn_3 = False + if not _use_flash_attn_3: if not _flash_attn_is_installed: _flash_attn_version_required = PkgVersion("2.3") elif not _flash_attn_2_3_plus: @@ -726,6 +736,12 @@ def get_attention_backend( "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" ) use_flash_attention = False + elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FlashAttention as it only supports sliding window with bottom right" + " diagonal alignment for cross-attention" + ) + use_flash_attention = False # Filter: Attention bias # backend | bias types | ALiBi diagonal alignment @@ -745,6 +761,12 @@ def get_attention_backend( elif not _flash_attn_2_4_plus: 
logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") use_flash_attention = False + elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FlashAttention as it only supports ALiBi with bottom right diagonal" + " alignment for cross-attention" + ) + use_flash_attention = False if use_flash_attention and ( core_attention_bias_type not in ["no_bias", "alibi"] @@ -1177,7 +1199,7 @@ def get_alibi( actual_seqlens_kv: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, bias_dtype: Optional[torch.dtype] = None, - bottom_right_alignment: bool = True, + bottom_right_diagonal: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Parameters @@ -1196,7 +1218,7 @@ def get_alibi( Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads]. bias_dtype: Optional[torch.dtype], default = `None` Dtype of the generated ALiBi bias. If None, use torch.float32. - bottom_right_alignment: bool, default = `True` + bottom_right_diagonal: bool, default = `True` Whether to align the diagonal of the ALiBi bias to the bottom right corner of the matrix (`True`) or top left (`False`). @@ -1245,12 +1267,12 @@ def get_alibi( 1, 1, 1, max_seqlen_kv ) if actual_seqlens_q is None and actual_seqlens_kv is None: - if bottom_right_alignment: + if bottom_right_diagonal: bias = bias + max_seqlen_kv - max_seqlen_q elif actual_seqlens_q is not None and actual_seqlens_kv is not None: batch_size = actual_seqlens_q.shape[0] bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) - if bottom_right_alignment: + if bottom_right_diagonal: bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) else: assert ( @@ -1259,8 +1281,13 @@ def get_alibi( bias = bias.abs().mul(-1) bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape) _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv - _alibi_cache["_bottom_right_alignment"] = bottom_right_alignment + _alibi_cache["_bottom_right_diagonal"] = bottom_right_diagonal bias_dtype = torch.float32 if bias_dtype is None else bias_dtype + _alibi_cache["_bias_dtype"] = bias_dtype + _alibi_cache["_actual_seqlens_q"], _alibi_cache["_actual_seqlens_kv"] = ( + actual_seqlens_q, + actual_seqlens_kv, + ) _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda") _alibi_cache["_alibi_bias_require_update"] = False @@ -4824,6 +4851,7 @@ def forward( core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + bottom_right_diagonal: Optional[bool] = None, ) -> torch.Tensor: """Unfused attention fprop""" assert ( @@ -4925,7 +4953,7 @@ def forward( actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None, actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None, alibi_slopes=alibi_slopes, - bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], + bottom_right_diagonal=bottom_right_diagonal, ) matmul_result = torch.baddbmm( matmul_result, @@ -6469,6 +6497,7 @@ def forward( attn_bias_type, attn_mask_type, window_size, + bottom_right_diagonal, rng_gen, fused_attention_backend, use_FAv2_bwd, @@ -6557,6 +6586,7 @@ def forward( attn_bias_type, attn_mask_type, window_size, + bottom_right_diagonal, rng_gen, ) if is_output_fp8: @@ -6688,6 +6718,7 @@ def forward( attn_bias_type, attn_mask_type, window_size, + bottom_right_diagonal, rng_gen, ) out_save = out_ret @@ -6733,6 +6764,7 
@@ def forward( ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type ctx.window_size = window_size + ctx.bottom_right_diagonal = bottom_right_diagonal ctx.fused_attention_backend = ( fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] ) @@ -6852,6 +6884,7 @@ def backward(ctx, d_out): ctx.attn_bias_type, ctx.attn_mask_type, ctx.window_size, + ctx.bottom_right_diagonal, ctx.deterministic, ) @@ -6977,6 +7010,7 @@ def backward(ctx, d_out): ctx.attn_bias_type, ctx.attn_mask_type, ctx.window_size, + ctx.bottom_right_diagonal, ctx.deterministic, ) @@ -7010,6 +7044,7 @@ def backward(ctx, d_out): None, None, None, + None, ) # else, return (dqkv, dbias) return ( @@ -7040,6 +7075,7 @@ def backward(ctx, d_out): None, None, None, + None, ) @@ -7132,6 +7168,7 @@ def forward( fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, + bottom_right_diagonal: Optional[bool] = None, fast_zero_fill: bool = True, cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, @@ -7299,6 +7336,7 @@ def forward( core_attention_bias_type, attn_mask_type, window_size, + bottom_right_diagonal, None, # rng_gen fused_attention_backend, use_FAv2_bwd, @@ -7386,6 +7424,11 @@ class DotProductAttention(TransformerEngineBaseModule): map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can be overridden by :attr:`window_size` in `forward` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. attention_type: str, default = `self` type of attention, either "`self`" and "`cross`". layer_number: int, default = `None` @@ -7448,6 +7491,7 @@ def __init__( qkv_format: str = "sbhd", attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, sequence_parallel: bool = False, tp_size: int = 1, get_rng_state_tracker: Optional[Callable] = None, @@ -7472,6 +7516,7 @@ def __init__( attn_mask_type = "padding_causal" self.attn_mask_type = attn_mask_type self.window_size = check_set_window_size(attn_mask_type, window_size) + self.bottom_right_diagonal = bottom_right_diagonal if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -7689,6 +7734,7 @@ def forward( core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + bottom_right_diagonal: Optional[bool] = None, fast_zero_fill: bool = True, inference_params: Optional[InferenceParams] = None, is_first_microbatch: Optional[bool] = None, @@ -7849,6 +7895,11 @@ def forward( ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) to the attention score of query i and key j. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. 
+ If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. fast_zero_fill: bool, default = `True` Whether to use the fast path to set output tensors to 0 or not. inference_params: Optional[InferenceParams], default = `None` @@ -7940,6 +7991,15 @@ def forward( if window_size is None: window_size = self.window_size window_size = check_set_window_size(attn_mask_type, window_size) + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True if self.rng_states_tracker is not None and is_graph_capturing(): assert isinstance( @@ -8111,19 +8171,27 @@ def forward( if self.layer_number == 1: _alibi_cache["_alibi_slopes_require_update"] = True _alibi_cache["_alibi_bias_require_update"] = True - bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],) if core_attention_bias_type == "alibi": assert ( core_attention_bias is None ), "core_attention_bias must be None when core_attention_bias_type is alibi!" if ( _alibi_cache["_num_heads"] != query_layer.shape[-2] - or _alibi_cache["_max_seqlen_q"] != max_seqlen_q - or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv - or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment or _alibi_cache["_alibi_slopes"] is None ): _alibi_cache["_alibi_slopes_require_update"] = True + actual_seqlens_q, actual_seqlens_kv = None, None + if "padding" in attn_mask_type: + actual_seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + actual_seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + alibi_dict = {} + alibi_dict["_max_seqlen_q"] = max_seqlen_q + alibi_dict["_max_seqlen_kv"] = max_seqlen_kv + alibi_dict["_bias_dtype"] = query_layer.dtype + alibi_dict["_bottom_right_diagonal"] = bottom_right_diagonal + alibi_dict["_actual_seqlens_q"] = actual_seqlens_q + alibi_dict["_actual_seqlens_kv"] = actual_seqlens_kv + if any(y != _alibi_cache[x] for x, y in alibi_dict.items()): _alibi_cache["_alibi_bias_require_update"] = True core_attention_bias_shape = None @@ -8176,6 +8244,7 @@ def forward( core_attention_bias_requires_grad=( core_attention_bias.requires_grad if core_attention_bias is not None else False ), + bottom_right_diagonal=bottom_right_diagonal, pad_between_seqs=pad_between_seqs, attention_dropout=self.attention_dropout, context_parallel=context_parallel, @@ -8260,7 +8329,7 @@ def forward( max_seqlen_kv, alibi_slopes=alibi_slopes, bias_dtype=query_layer.dtype, - bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], + bottom_right_diagonal=bottom_right_diagonal, ) if checkpoint_core_attention: return self._checkpointed_attention_forward( @@ -8281,6 +8350,7 @@ def forward( fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, core_attention_bias=fu_core_attention_bias, + bottom_right_diagonal=bottom_right_diagonal, fast_zero_fill=fast_zero_fill, cp_group=self.cp_group, cp_global_ranks=self.cp_global_ranks, @@ -8306,6 +8376,7 @@ def forward( fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, core_attention_bias=fu_core_attention_bias, + bottom_right_diagonal=bottom_right_diagonal, fast_zero_fill=fast_zero_fill, cp_group=self.cp_group, cp_global_ranks=self.cp_global_ranks, @@ -8322,7 +8393,6 @@ 
def forward( "Attention activation Offloading is only implemented" "with Flash Attention and Fused Attention!" ) - if use_unfused_attention: if checkpoint_core_attention: return self._checkpointed_attention_forward( @@ -8339,6 +8409,7 @@ def forward( core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + bottom_right_diagonal=bottom_right_diagonal, ) return self.unfused_attention( query_layer, @@ -8353,6 +8424,7 @@ def forward( core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + bottom_right_diagonal=bottom_right_diagonal, ) raise ValueError("No dot product attention support for the provided inputs!") @@ -8409,6 +8481,11 @@ class MultiheadAttention(torch.nn.Module): map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can be overridden by :attr:`window_size` in `forward` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. num_gqa_groups : int, default = `None` number of GQA groups in the transformer layer. Grouped Query Attention is described in @@ -8509,6 +8586,7 @@ def __init__( layer_number: Optional[int] = None, attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, num_gqa_groups: Optional[int] = None, @@ -8539,6 +8617,7 @@ def __init__( self.qkv_format = qkv_format self.attn_mask_type = attn_mask_type self.window_size = check_set_window_size(attn_mask_type, window_size) + self.bottom_right_diagonal = bottom_right_diagonal self.layer_number = layer_number self.input_layernorm = input_layernorm self.attention_type = attention_type @@ -8804,6 +8883,7 @@ def forward( core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + bottom_right_diagonal: Optional[bool] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, @@ -8873,6 +8953,11 @@ def forward( ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) to the attention score of query i and key j. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. cu_seqlens_q: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`, with shape [batch_size + 1] and dtype torch.int32. 
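At the module level the flag is threaded through both DotProductAttention and MultiheadAttention, as the constructor and forward() signatures above show. A hedged usage sketch follows (it assumes a CUDA build with this change applied; the shapes, head counts, window size, and dtype are made up for illustration):

```python
import torch
from transformer_engine.pytorch import DotProductAttention

attn = DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attn_mask_type="no_mask",
    window_size=(128, 0),          # causal-style sliding window (left, 0)
    bottom_right_diagonal=False,   # new: pin the window's diagonal to the top-left corner
)

# [seq, batch, heads, head_dim]; cross-attention, so the alignment choice matters
q = torch.randn(512, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
kv = torch.randn(1024, 2, 16, 64, dtype=torch.bfloat16, device="cuda")

out = attn(q, kv, kv)
# The flag can also be overridden per call, mirroring window_size / attn_mask_type:
out = attn(q, kv, kv, bottom_right_diagonal=True)
```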
@@ -8895,6 +8980,15 @@ def forward( if window_size is None: window_size = self.window_size window_size = check_set_window_size(attn_mask_type, window_size) + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True if "padding" in attn_mask_type and attention_mask is not None: for mask in attention_mask: @@ -9156,6 +9250,7 @@ def forward( core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + bottom_right_diagonal=bottom_right_diagonal, fast_zero_fill=fast_zero_fill, inference_params=inference_params, ) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 332b4e52eec..bd9020fad27 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -117,6 +117,7 @@ def fused_attn_fwd_qkvpacked( attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = True, rng_gen: torch.Generator = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for packed QKV input. @@ -186,6 +187,9 @@ def fused_attn_fwd_qkvpacked( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = True + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. rng_gen: torch.Generator, default = None random number generator; if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen @@ -271,6 +275,7 @@ def fused_attn_fwd_qkvpacked( AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], window_size, + bottom_right_diagonal, cu_seqlens, qkv, qkv_dtype, @@ -324,6 +329,7 @@ def fused_attn_bwd_qkvpacked( attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = True, deterministic: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention BWD for packed QKV input. @@ -394,6 +400,9 @@ def fused_attn_bwd_qkvpacked( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = True + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. deterministic: bool, default = False whether to execute the backward pass with deterministic behaviours. @@ -444,6 +453,7 @@ def fused_attn_bwd_qkvpacked( AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], window_size, + bottom_right_diagonal, deterministic, cu_seqlens, qkv, @@ -500,6 +510,7 @@ def fused_attn_fwd_kvpacked( attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = True, rng_gen: torch.Generator = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for packed KV input. 
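On the ALiBi side the flag has the same meaning: the zero of the distance penalty sits on the top-left or bottom-right diagonal of the softmax matrix, matching the (i + seqlen_k - seqlen_q - j) convention quoted in the alibi_slopes docstrings. A standalone sketch (illustration only, not the get_alibi implementation):

```python
import torch

def alibi_bias(slopes, seqlen_q, seqlen_kv, bottom_right_diagonal=True):
    """ALiBi-style distance penalty with shape [num_heads, seqlen_q, seqlen_kv]."""
    i = torch.arange(seqlen_q).view(seqlen_q, 1)
    j = torch.arange(seqlen_kv).view(1, seqlen_kv)
    offset = seqlen_kv - seqlen_q if bottom_right_diagonal else 0
    distance = (i + offset - j).abs()          # zero on the chosen diagonal
    return -distance.float() * slopes.view(-1, 1, 1)

slopes = torch.tensor([0.5, 0.25])             # one slope per head
print(alibi_bias(slopes, 3, 5, bottom_right_diagonal=True)[0])
print(alibi_bias(slopes, 3, 5, bottom_right_diagonal=False)[0])
```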
@@ -579,6 +590,9 @@ def fused_attn_fwd_kvpacked( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = True + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. rng_gen: torch.Generator, default = None random number generator; if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen @@ -665,6 +679,7 @@ def fused_attn_fwd_kvpacked( AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], window_size, + bottom_right_diagonal, cu_seqlens_q, cu_seqlens_kv, q, @@ -725,6 +740,7 @@ def fused_attn_bwd_kvpacked( attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = True, deterministic: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention BWD for packed KV input. @@ -806,6 +822,9 @@ def fused_attn_bwd_kvpacked( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = True + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. deterministic: bool, default = False whether to execute the backward pass with deterministic behaviours. @@ -859,6 +878,7 @@ def fused_attn_bwd_kvpacked( AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], window_size, + bottom_right_diagonal, deterministic, cu_seqlens_q, cu_seqlens_kv, @@ -919,6 +939,7 @@ def fused_attn_fwd( attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = True, rng_gen: torch.Generator = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for separate QKV input. @@ -1004,6 +1025,9 @@ def fused_attn_fwd( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = True + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. rng_gen: torch.Generator, default = None random number generator; if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen @@ -1090,6 +1114,7 @@ def fused_attn_fwd( AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], window_size, + bottom_right_diagonal, cu_seqlens_q, cu_seqlens_kv, q, @@ -1152,6 +1177,7 @@ def fused_attn_bwd( attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = True, deterministic: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention BWD for packed KV input. @@ -1238,6 +1264,9 @@ def fused_attn_bwd( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = True + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. 
deterministic: bool, default = False whether to execute the backward pass with deterministic behaviours. @@ -1293,6 +1322,7 @@ def fused_attn_bwd( AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], window_size, + bottom_right_diagonal, deterministic, cu_seqlens_q, cu_seqlens_kv, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 3abcac5bf72..d21747631e8 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -48,11 +48,11 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q std::vector fused_attn_fwd_qkvpacked( size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, - const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, + const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens, + const at::Tensor QKV, const transformer_engine::DType qkv_type, + const c10::optional cu_seqlens_padded, const c10::optional descale_QKV, + const int descale_QKV_offset, const c10::optional descale_S, + const int descale_S_offset, const c10::optional scale_S, const int scale_S_offset, const c10::optional scale_O, const int scale_O_offset, c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, const int amax_O_offset, const c10::optional Bias, @@ -61,9 +61,10 @@ std::vector fused_attn_fwd_qkvpacked( std::vector fused_attn_bwd_qkvpacked( size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - bool deterministic, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O, - const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, + bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens, + const at::Tensor QKV, const at::Tensor O, const at::Tensor dO, + const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, + const std::vector Aux_CTX_Tensors, const c10::optional cu_seqlens_padded, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, const c10::optional descale_dO, const c10::optional descale_dP, @@ -75,8 +76,8 @@ std::vector fused_attn_fwd_kvpacked( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, - const at::Tensor KV, const transformer_engine::DType qkv_type, + bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, + const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, const c10::optional descale_QKV, const int descale_QKV_offset, @@ -90,10 +91,11 @@ std::vector fused_attn_fwd_kvpacked( 
std::vector fused_attn_bwd_kvpacked( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const at::Tensor O, - const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, + const std::vector window_size, bool bottom_right_diagonal, bool deterministic, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, + const at::Tensor KV, const at::Tensor O, const at::Tensor dO, + const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, + const std::vector Aux_CTX_Tensors, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, const c10::optional descale_QKV, const c10::optional descale_S, @@ -106,9 +108,9 @@ std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, - const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, - const c10::optional cu_seqlens_q_padded, + bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, + const at::Tensor Q, const at::Tensor K, const at::Tensor V, + const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, const c10::optional descale_QKV, const int descale_QKV_offset, const c10::optional descale_S, const int descale_S_offset, @@ -121,10 +123,11 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V, - const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, + const std::vector window_size, bool bottom_right_diagonal, bool deterministic, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, + const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO, + const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, + const std::vector Aux_CTX_Tensors, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, const c10::optional descale_QKV, const c10::optional descale_S, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 50da91a1a17..3f95319d2f3 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -84,11 +84,11 @@ at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_pe std::vector fused_attn_fwd_qkvpacked( size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool 
set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, - const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, + const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens, + const at::Tensor QKV, const transformer_engine::DType qkv_type, + const c10::optional cu_seqlens_padded, const c10::optional descale_QKV, + const int descale_QKV_offset, const c10::optional descale_S, + const int descale_S_offset, const c10::optional scale_S, const int scale_S_offset, const c10::optional scale_O, const int scale_O_offset, c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, const int amax_O_offset, const c10::optional Bias, @@ -200,7 +200,7 @@ std::vector fused_attn_fwd_qkvpacked( te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens.data(), te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); + window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); // allocate memory for workspace and auxiliary output tensors auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); @@ -240,7 +240,7 @@ std::vector fused_attn_fwd_qkvpacked( te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens.data(), te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); + window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -253,9 +253,10 @@ std::vector fused_attn_fwd_qkvpacked( std::vector fused_attn_bwd_qkvpacked( size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - bool deterministic, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O, - const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, + bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens, + const at::Tensor QKV, const at::Tensor O, const at::Tensor dO, + const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, + const std::vector Aux_CTX_Tensors, const c10::optional cu_seqlens_padded, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, const c10::optional descale_dO, const c10::optional descale_dP, @@ -392,11 +393,12 @@ std::vector fused_attn_bwd_qkvpacked( TensorWrapper workspace; // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, - te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_cu_seqlens_padded.data(), - max_seqlen, 
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), + &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(), + te_cu_seqlens.data(), te_cu_seqlens_padded.data(), max_seqlen, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + window_size[0], window_size[1], bottom_right_diagonal, + deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); // allocate memory for workspace auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); @@ -404,11 +406,12 @@ std::vector fused_attn_bwd_qkvpacked( makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // execute kernel - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, - te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_cu_seqlens_padded.data(), - max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), + &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(), + te_cu_seqlens.data(), te_cu_seqlens_padded.data(), max_seqlen, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + window_size[0], window_size[1], bottom_right_diagonal, + deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -421,8 +424,8 @@ std::vector fused_attn_fwd_kvpacked( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, - const at::Tensor KV, const transformer_engine::DType qkv_type, + bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, + const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, const c10::optional descale_QKV, const int descale_QKV_offset, @@ -538,7 +541,7 @@ std::vector fused_attn_fwd_kvpacked( te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], - workspace.data(), at::cuda::getCurrentCUDAStream()); + bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); // allocate memory for workspace and auxiliary output tensors auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); @@ -579,7 +582,7 @@ std::vector fused_attn_fwd_kvpacked( te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], - workspace.data(), at::cuda::getCurrentCUDAStream()); + bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); 
// destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -592,10 +595,11 @@ std::vector fused_attn_fwd_kvpacked( std::vector fused_attn_bwd_kvpacked( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const at::Tensor O, - const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, + const std::vector window_size, bool bottom_right_diagonal, bool deterministic, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, + const at::Tensor KV, const at::Tensor O, const at::Tensor dO, + const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, + const std::vector Aux_CTX_Tensors, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, const c10::optional descale_QKV, const c10::optional descale_S, @@ -747,13 +751,13 @@ std::vector fused_attn_bwd_kvpacked( TensorWrapper workspace; // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), - te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), - te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), - max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, - bias_type, attn_mask_type, window_size[0], window_size[1], - deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd_kvpacked( + te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), + &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), + max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + window_size[0], window_size[1], bottom_right_diagonal, deterministic, workspace.data(), + at::cuda::getCurrentCUDAStream()); // allocate memory for workspace auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); @@ -761,13 +765,13 @@ std::vector fused_attn_bwd_kvpacked( makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // execute kernel - nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), - te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), - te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), - max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, - bias_type, attn_mask_type, window_size[0], window_size[1], - deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd_kvpacked( + te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), + &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), + max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + window_size[0], 
window_size[1], bottom_right_diagonal, deterministic, workspace.data(), + at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -780,9 +784,9 @@ std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, - const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, - const c10::optional cu_seqlens_q_padded, + bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, + const at::Tensor Q, const at::Tensor K, const at::Tensor V, + const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, const c10::optional descale_QKV, const int descale_QKV_offset, const c10::optional descale_S, const int descale_S_offset, @@ -904,8 +908,8 @@ std::vector fused_attn_fwd( te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type, - attn_mask_type, window_size[0], window_size[1], workspace.data(), - at::cuda::getCurrentCUDAStream()); + attn_mask_type, window_size[0], window_size[1], bottom_right_diagonal, + workspace.data(), at::cuda::getCurrentCUDAStream()); // allocate memory for workspace and auxiliary output tensors auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); @@ -946,8 +950,8 @@ std::vector fused_attn_fwd( te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type, - attn_mask_type, window_size[0], window_size[1], workspace.data(), - at::cuda::getCurrentCUDAStream()); + attn_mask_type, window_size[0], window_size[1], bottom_right_diagonal, + workspace.data(), at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -960,10 +964,11 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V, - const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, + const std::vector window_size, bool bottom_right_diagonal, bool deterministic, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, + const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO, + const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, + const std::vector Aux_CTX_Tensors, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, const c10::optional descale_QKV, const c10::optional descale_S, @@ -1199,8 +1204,8 @@ std::vector fused_attn_bwd( te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), 
te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - window_size[0], window_size[1], deterministic, workspace.data(), - at::cuda::getCurrentCUDAStream()); + window_size[0], window_size[1], bottom_right_diagonal, deterministic, + workspace.data(), at::cuda::getCurrentCUDAStream()); // allocate memory for workspace auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); @@ -1213,8 +1218,8 @@ std::vector fused_attn_bwd( te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - window_size[0], window_size[1], deterministic, workspace.data(), - at::cuda::getCurrentCUDAStream()); + window_size[0], window_size[1], bottom_right_diagonal, deterministic, + workspace.data(), at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 7c3da9a73f9..9ee5bd572a8 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -147,11 +147,21 @@ class TransformerLayer(torch.nn.Module): distinguishes them based on `self_attn_mask_type` or `enc_dec_attn_mask_type`. Similar to :attr:`self_attn_mask_type`, `window_size` can be overridden by :attr:`window_size` in `forward` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `self_attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'}, default = `no_mask` type of attention mask passed into softmax operation for decoder. enc_dec_window_size: Optional[Tuple[int, int]], default = `None` sliding window size for local attention in decoder. + enc_dec_bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the decoder. + If `None`, it will be set to `False` for `enc_dec_attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. 
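[Editor's note: illustrative sketch, not part of the patch.] The `TransformerLayer` docstring entries above describe a `None` default that resolves from the mask type: `False` for `'causal'` / `'padding_causal'`, `True` for other mask types. A hedged sketch of that resolution rule as a standalone helper (the function name is hypothetical):

```python
# Illustrative sketch (not part of this patch): the None-default rule from the
# docstring above, written as a hypothetical standalone helper.
def resolve_bottom_right_diagonal(attn_mask_type, bottom_right_diagonal=None):
    """Return the effective diagonal alignment for a given mask type."""
    if bottom_right_diagonal is not None:
        return bottom_right_diagonal
    # None: top-left for plain causal masks, bottom-right for everything else.
    return attn_mask_type not in ("causal", "padding_causal")

assert resolve_bottom_right_diagonal("causal") is False
assert resolve_bottom_right_diagonal("no_mask") is True
assert resolve_bottom_right_diagonal("padding", bottom_right_diagonal=False) is False
```

The `forward` logic added further down additionally pins the flag to `True` for the `'*_bottom_right'` mask types and to `False` for `'causal'` / `'padding_causal'`, so the flag only matters for the remaining mask types.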
zero_centered_gamma : bool, default = 'False' if set to 'True', gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to @@ -247,8 +257,10 @@ def __init__( kv_channels: Optional[int] = None, self_attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: bool = None, enc_dec_attn_mask_type: str = "no_mask", enc_dec_window_size: Optional[Tuple[int, int]] = None, + enc_dec_bottom_right_diagonal: bool = None, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, params_dtype: Optional[torch.dtype] = None, @@ -282,10 +294,12 @@ def __init__( self.self_attn_mask_type = self_attn_mask_type self.window_size = check_set_window_size(self_attn_mask_type, window_size) + self.bottom_right_diagonal = bottom_right_diagonal self.enc_dec_attn_mask_type = enc_dec_attn_mask_type self.enc_dec_window_size = check_set_window_size( enc_dec_attn_mask_type, enc_dec_window_size ) + self.enc_dec_bottom_right_diagonal = enc_dec_bottom_right_diagonal params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad @@ -530,10 +544,12 @@ def forward( attention_mask: Optional[torch.Tensor] = None, self_attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, encoder_output: Optional[torch.Tensor] = None, enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, enc_dec_attn_mask_type: Optional[str] = None, enc_dec_window_size: Optional[Tuple[int, int]] = None, + enc_dec_bottom_right_diagonal: Optional[bool] = None, is_first_microbatch: Optional[bool] = None, checkpoint_core_attention: bool = False, inference_params: Optional[InferenceParams] = None, @@ -575,6 +591,11 @@ def forward( causal masks are aligned to the bottom right corner. window_size: Optional[Tuple[int, int]], default = `None` Sliding window size for local attention in encoder. + bottom_right_diagonal: Optional[bool] = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `self_attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. encoder_output : Optional[torch.Tensor], default = `None` Output of the encoder block to be fed into the decoder block if using `layer_type="decoder"`. @@ -591,6 +612,11 @@ def forward( Type of attention mask passed into softmax operation for decoder. enc_dec_window_size: Optional[Tuple[int, int]], default = `None` Sliding window size for local attention in decoder. + enc_dec_bottom_right_diagonal: Optional[bool] = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the decoder. + If `None`, it will be set to `False` for `enc_dec_attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. 
is_first_microbatch : {True, False, None}, default = None During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split @@ -649,6 +675,24 @@ def forward( if enc_dec_window_size is None: enc_dec_window_size = self.enc_dec_window_size enc_dec_window_size = check_set_window_size(enc_dec_attn_mask_type, enc_dec_window_size) + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if self_attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or self_attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True + if enc_dec_bottom_right_diagonal is None: + enc_dec_bottom_right_diagonal = self.enc_dec_bottom_right_diagonal + if enc_dec_attn_mask_type in {"causal", "padding_causal"}: + enc_dec_bottom_right_diagonal = False + if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + enc_dec_bottom_right_diagonal = True assert ( self_attn_mask_type in AttnMaskTypes @@ -692,6 +736,7 @@ def forward( core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + bottom_right_diagonal=bottom_right_diagonal, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max_seqlen_q, @@ -723,6 +768,7 @@ def forward( core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + bottom_right_diagonal=enc_dec_bottom_right_diagonal, fast_zero_fill=fast_zero_fill, ) if self.apply_residual_connection_post_layernorm:
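[Editor's note: usage sketch, not part of the patch.] Putting the new knobs together: the constructor flag sets the module-level default and the per-call argument overrides it for a single forward pass, as documented above. The constructor values below (hidden sizes, head count, window width) are illustrative only, and running this requires a GPU/cuDNN build that supports the chosen sliding-window and mask combination:

```python
# Illustrative usage sketch (not part of this patch); values are arbitrary.
import torch
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    self_attn_mask_type="no_mask",
    window_size=(128, 0),         # attend to up to 128 keys left of the diagonal
    bottom_right_diagonal=True,   # module-level default: bottom-right alignment
).cuda().half()

# [seq, batch, hidden] is the default input layout for TransformerLayer.
x = torch.randn(512, 2, 1024, dtype=torch.half, device="cuda")

y_default = layer(x)                                # uses the module default
y_top_left = layer(x, bottom_right_diagonal=False)  # per-call override
```

The decoder path is controlled the same way through `enc_dec_bottom_right_diagonal`, both at construction time and per `forward` call.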