
Commit 75beab1

YangKai0616 and vasqu authored
Fixed paged|FA2 kernel loading logic and UT. (#42547)
* Fixed UT and kernel loading logic.
* Revision based on comments
* Simplify code
* make style
* simplify CB part
* retrigger ci

---------

Co-authored-by: vasqu <antonprogamer@gmail.com>
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
1 parent d3ee06b commit 75beab1

File tree

3 files changed, +16 -11 lines changed


src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 2 additions & 8 deletions
@@ -763,15 +763,9 @@ def __init__(
             num_kv_padding_intervals: (optional) Number of intervals used to pad the keys/values dimension
             allow_prefix_sharing: (optional) Whether to allow prefix sharing if the model has only full attention layers
         """
+        # Reload paged version if necessary
         if "paged|" not in model.config._attn_implementation:
-            attn_implementation = f"paged|{model.config._attn_implementation}"
-            model.config._attn_implementation = attn_implementation
-
-            # lazy loading flash attention including kernel variations
-            if "flash" in attn_implementation:
-                from ...modeling_flash_attention_utils import lazy_import_paged_flash_attention
-
-                lazy_import_paged_flash_attention(attn_implementation)
+            model.set_attn_implementation(f"paged|{model.config._attn_implementation}")
 
         self.model = model.eval()
         generation_config = model.generation_config if generation_config is None else generation_config
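For readers skimming the diff: the old path mutated model.config._attn_implementation by hand and then imported the paged flash-attention kernels itself, while the new path delegates both steps to the public setter. Below is a minimal, hypothetical sketch of what that looks like from the caller's side; the checkpoint id is a placeholder, and only set_attn_implementation and the config attribute it updates come from this change.

# Hypothetical usage sketch -- only set_attn_implementation and
# config._attn_implementation are taken from this commit.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-causal-lm",            # placeholder checkpoint id
    attn_implementation="flash_attention_2",
)

# Mirrors the new __init__ logic: re-register the backend with the "paged|"
# prefix through the setter instead of editing the config field and calling
# lazy_import_paged_flash_attention manually.
if "paged|" not in model.config._attn_implementation:
    model.set_attn_implementation(f"paged|{model.config._attn_implementation}")

# model.config._attn_implementation is now expected to read "paged|flash_attention_2".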

src/transformers/modeling_utils.py

Lines changed: 12 additions & 3 deletions
@@ -85,7 +85,7 @@
     verify_tp_plan,
 )
 from .loss.loss_utils import LOSS_MAPPING
-from .modeling_flash_attention_utils import lazy_import_flash_attention
+from .modeling_flash_attention_utils import lazy_import_flash_attention, lazy_import_paged_flash_attention
 from .pytorch_utils import id_tensor_storage
 from .quantizers import HfQuantizer
 from .quantizers.auto import get_hf_quantizer

@@ -1763,9 +1763,12 @@ def _check_and_adjust_attn_implementation(
         """
         applicable_attn_implementation = attn_implementation
 
+        is_paged = attn_implementation is not None and attn_implementation.startswith("paged|")
+
         # If FA not installed, do not fail but use kernels instead
         requested_original_flash_attn = attn_implementation is not None and (
-            attn_implementation == "flash_attention_2" or attn_implementation == "flash_attention_3"
+            attn_implementation.removeprefix("paged|") == "flash_attention_2"
+            or attn_implementation.removeprefix("paged|") == "flash_attention_3"
         )
         if (
             requested_original_flash_attn

@@ -1783,10 +1786,16 @@ def _check_and_adjust_attn_implementation(
             else:
                 applicable_attn_implementation = "kernels-community/vllm-flash-attn3"
 
+            if is_paged:
+                applicable_attn_implementation = f"paged|{applicable_attn_implementation}"
+
         if is_kernel(applicable_attn_implementation):
             try:
                 # preload flash attention here to allow compile with fullgraph
-                lazy_import_flash_attention(applicable_attn_implementation)
+                if is_paged:
+                    lazy_import_paged_flash_attention(applicable_attn_implementation)
+                else:
+                    lazy_import_flash_attention(applicable_attn_implementation)
 
                 # log that we used kernel fallback if successful
                 if requested_original_flash_attn:
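To make the new prefix handling easier to follow outside the full method, here is a small illustrative sketch of the same pattern: remember whether "paged|" was requested, compare against the plain flash-attention names with removeprefix, and re-attach the prefix to whatever fallback gets picked. The helper name and the FA2 fallback string are assumptions for illustration; the "paged|" prefix, the removeprefix calls, and the vllm-flash-attn3 fallback are the ones shown in the diff above.

# Illustrative sketch (not the actual transformers method): shows the
# prefix-stripping / re-prefixing pattern added in this commit.
from typing import Optional


def resolve_attn_implementation(requested: Optional[str], flash_installed: bool) -> Optional[str]:
    if requested is None:
        return None

    # Remember whether the caller asked for the paged variant.
    is_paged = requested.startswith("paged|")
    # Compare against the plain names regardless of the "paged|" prefix.
    base = requested.removeprefix("paged|")

    resolved = requested
    if base in ("flash_attention_2", "flash_attention_3") and not flash_installed:
        # Fall back to a kernels-hub implementation; the FA2 fallback name here
        # is assumed, the FA3 one is the string from the diff.
        resolved = "kernels-community/flash-attn" if base.endswith("2") else "kernels-community/vllm-flash-attn3"
        if is_paged:
            # Re-attach the prefix so the paged kernel loader runs later on.
            resolved = f"paged|{resolved}"
    return resolved


# Example: a paged FA2 request without flash-attn installed keeps its prefix.
print(resolve_attn_implementation("paged|flash_attention_2", flash_installed=False))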

tests/generation/test_continuous_batching.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@
 from transformers.generation.continuous_batching.continuous_api import build_attention_mask
 from transformers.testing_utils import (
     Expectations,
+    require_deterministic_for_xpu,
     require_kernels,
     require_read_token,
     require_torch_accelerator,

@@ -137,6 +138,7 @@ def test_attention_mask(
             f"Actual mask:\n{str_mask}"
         )
 
+    @require_deterministic_for_xpu
     def _continuous_batching_parity(
         self, model_id: str, attn_implementation: str, expected_outputs: dict[str, str]
     ) -> None:
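As a usage note, a minimal hypothetical sketch of how the newly imported decorator sits on a test; the class, test body, and model id below are placeholders, and the decorator's exact behavior (keeping kernels deterministic when the test runs on XPU) is inferred from its name rather than from this diff.

# Hypothetical test sketch -- only require_deterministic_for_xpu and its import
# path come from this commit; everything else is a placeholder.
import unittest

from transformers.testing_utils import require_deterministic_for_xpu, require_torch_accelerator


class ContinuousBatchingParitySketch(unittest.TestCase):
    @require_torch_accelerator
    @require_deterministic_for_xpu  # presumably enforces determinism on XPU runs
    def test_parity_placeholder(self):
        model_id = "some-org/some-causal-lm"  # placeholder, not a real checkpoint
        # A real parity test would generate with continuous batching and compare
        # against per-device expected outputs (see Expectations in the imports).
        self.assertIsInstance(model_id, str)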
