Skip to content

Commit 56037df

Browse files
[BugFix] Fix assert batch_descriptor.num_tokens == num_tokens_padded (#30173)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 5dcd593 commit 56037df

File tree

6 files changed

+65
-33
lines changed

6 files changed

+65
-33
lines changed

tests/v1/cudagraph/test_cudagraph_dispatch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,10 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
161161
assert rt_mode == CUDAGraphMode.NONE
162162
assert key == BatchDescriptor(num_tokens=15)
163163

164-
# 4. Cascade attention should have a fall back mode
164+
# 4. disable_full should have a fall back mode (e.g., cascade attention)
165165
desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
166166
rt_mode, key = dispatcher.dispatch(
167-
num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
167+
num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
168168
)
169169
if "PIECEWISE" in cudagraph_mode_str: # string contains check
170170
assert rt_mode == CUDAGraphMode.PIECEWISE

vllm/forward_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def set_forward_context(
292292
if num_tokens_across_dp is None:
293293
assert ubatch_slices is None
294294
assert num_tokens is not None
295-
_, num_tokens_across_dp = coordinate_batch_across_dp(
295+
_, num_tokens_across_dp, _ = coordinate_batch_across_dp(
296296
num_tokens_unpadded=num_tokens,
297297
parallel_config=vllm_config.parallel_config,
298298
allow_microbatching=False,

vllm/v1/cudagraph_dispatcher.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def dispatch(
145145
num_tokens: int,
146146
uniform_decode: bool,
147147
has_lora: bool,
148-
use_cascade_attn: bool = False,
148+
disable_full: bool = False,
149149
) -> tuple[CUDAGraphMode, BatchDescriptor]:
150150
"""
151151
Given conditions (e.g., batch descriptor and whether cascade attention is used),
@@ -165,7 +165,7 @@ def dispatch(
165165
)
166166
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
167167

168-
if not use_cascade_attn:
168+
if not disable_full:
169169
# check if key exists for full cudagraph
170170
if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
171171
return CUDAGraphMode.FULL, batch_desc

vllm/v1/spec_decode/eagle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1258,7 +1258,7 @@ def _pad_batch_across_dp(
12581258
num_tokens_padded: int,
12591259
) -> tuple[int, torch.Tensor]:
12601260
# TODO(Flechman): support DBO ubatching
1261-
should_ubatch, num_toks_across_dp = coordinate_batch_across_dp(
1261+
should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp(
12621262
num_tokens_unpadded=num_tokens_unpadded,
12631263
parallel_config=self.vllm_config.parallel_config,
12641264
allow_microbatching=False,

vllm/v1/worker/dp_utils.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,18 @@ def _run_ar(
4040
should_dp_pad: bool,
4141
orig_num_tokens_per_ubatch: int,
4242
padded_num_tokens_per_ubatch: int,
43+
cudagraph_mode: int,
4344
parallel_config: ParallelConfig,
4445
) -> torch.Tensor:
4546
dp_size = parallel_config.data_parallel_size
4647
dp_rank = parallel_config.data_parallel_rank
4748
device, group = _get_device_and_group(parallel_config)
48-
tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
49+
tensor = torch.zeros(5, dp_size, device=device, dtype=torch.int32)
4950
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
5051
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
5152
tensor[2][dp_rank] = 1 if should_ubatch else 0
5253
tensor[3][dp_rank] = 1 if should_dp_pad else 0
54+
tensor[4][dp_rank] = cudagraph_mode
5355
dist.all_reduce(tensor, group=group)
5456
return tensor
5557

@@ -89,13 +91,23 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch
8991
return num_tokens_across_dp.cpu()
9092

9193

94+
def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
95+
"""
96+
Synchronize cudagraph_mode across DP ranks by taking the minimum.
97+
If any rank has NONE (0), all ranks use NONE.
98+
This ensures all ranks send consistent values (all padded or all unpadded).
99+
"""
100+
return int(tensor[4, :].min().item())
101+
102+
92103
def _synchronize_dp_ranks(
93104
num_tokens_unpadded: int,
94105
num_tokens_padded: int,
95106
should_attempt_ubatching: bool,
96107
should_attempt_dp_padding: bool,
108+
cudagraph_mode: int,
97109
parallel_config: ParallelConfig,
98-
) -> tuple[bool, torch.Tensor | None]:
110+
) -> tuple[bool, torch.Tensor | None, int]:
99111
"""
100112
1. Decides if each DP rank is going to microbatch. Either all ranks
101113
run with microbatching or none of them do.
@@ -104,10 +116,13 @@ def _synchronize_dp_ranks(
104116
When running microbatched or if should_attempt_dp_padding is True, all
105117
ranks will be padded out so that they run with the same number of tokens
106118
119+
3. Synchronizes cudagraph_mode across ranks by taking the minimum.
120+
107121
Returns: tuple[
108122
should_ubatch: Are all DP ranks going to microbatch
109123
num_tokens_after_padding: A tensor containing the total number of
110124
tokens per-microbatch for each DP rank including any DP padding.
125+
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
111126
]
112127
113128
"""
@@ -121,6 +136,7 @@ def _synchronize_dp_ranks(
121136
should_dp_pad=should_attempt_dp_padding,
122137
orig_num_tokens_per_ubatch=num_tokens_unpadded,
123138
padded_num_tokens_per_ubatch=num_tokens_padded,
139+
cudagraph_mode=cudagraph_mode,
124140
parallel_config=parallel_config,
125141
)
126142

@@ -148,7 +164,10 @@ def _synchronize_dp_ranks(
148164
should_dp_pad,
149165
)
150166

151-
return should_ubatch, num_tokens_after_padding
167+
# Synchronize cudagraph_mode across ranks (take min)
168+
synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)
169+
170+
return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode
152171

153172

154173
def coordinate_batch_across_dp(
@@ -159,7 +178,8 @@ def coordinate_batch_across_dp(
159178
num_tokens_padded: int | None = None,
160179
uniform_decode: bool | None = None,
161180
num_scheduled_tokens_per_request: np.ndarray | None = None,
162-
) -> tuple[bool, torch.Tensor | None]:
181+
cudagraph_mode: int = 0,
182+
) -> tuple[bool, torch.Tensor | None, int]:
163183
"""
164184
Coordinates amongst all DP ranks to determine if and how the full batch
165185
should be split into microbatches.
@@ -175,6 +195,7 @@ def coordinate_batch_across_dp(
175195
only contains single token decodes
176196
num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
177197
number of tokens per request.
198+
cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL)
178199
179200
Returns: tuple[
180201
ubatch_slices: if this is set then all DP ranks have agreed to
@@ -183,12 +204,13 @@ def coordinate_batch_across_dp(
183204
tokens per-microbatch for each DP rank including padding. Will be
184205
padded up to the max value across all DP ranks when allow_dp_padding
185206
is True.
207+
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
186208
]
187209
188210
"""
189211
if parallel_config.data_parallel_size == 1:
190212
# Early exit.
191-
return False, None
213+
return False, None, cudagraph_mode
192214

193215
# If the caller has explicitly enabled microbatching.
194216
should_attempt_ubatching = False
@@ -204,12 +226,15 @@ def coordinate_batch_across_dp(
204226
if num_tokens_padded is None:
205227
num_tokens_padded = num_tokens_unpadded
206228

207-
(should_ubatch, num_tokens_after_padding) = _synchronize_dp_ranks(
208-
num_tokens_unpadded,
209-
num_tokens_padded,
210-
should_attempt_ubatching,
211-
allow_dp_padding,
212-
parallel_config,
229+
(should_ubatch, num_tokens_after_padding, synced_cudagraph_mode) = (
230+
_synchronize_dp_ranks(
231+
num_tokens_unpadded,
232+
num_tokens_padded,
233+
should_attempt_ubatching,
234+
allow_dp_padding,
235+
cudagraph_mode,
236+
parallel_config,
237+
)
213238
)
214239

215-
return (should_ubatch, num_tokens_after_padding)
240+
return (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2788,17 +2788,19 @@ def _determine_batch_execution_and_padding(
27882788
)
27892789

27902790
dispatch_cudagraph = (
2791-
lambda num_tokens: self.cudagraph_dispatcher.dispatch(
2791+
lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch(
27922792
num_tokens=num_tokens,
27932793
has_lora=has_lora,
2794-
use_cascade_attn=use_cascade_attn,
27952794
uniform_decode=uniform_decode,
2795+
disable_full=disable_full,
27962796
)
27972797
if not force_eager
27982798
else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
27992799
)
28002800

2801-
cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded)
2801+
cudagraph_mode, batch_descriptor = dispatch_cudagraph(
2802+
num_tokens_padded, use_cascade_attn
2803+
)
28022804
num_tokens_padded = batch_descriptor.num_tokens
28032805

28042806
# Extra coordination when running data-parallel since we need to coordinate
@@ -2813,23 +2815,28 @@ def _determine_batch_execution_and_padding(
28132815
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
28142816
)
28152817

2816-
should_ubatch, num_tokens_across_dp = coordinate_batch_across_dp(
2817-
num_tokens_unpadded=num_tokens,
2818-
parallel_config=self.parallel_config,
2819-
allow_microbatching=allow_microbatching,
2820-
allow_dp_padding=allow_dp_padding,
2821-
num_tokens_padded=num_tokens_padded,
2822-
uniform_decode=uniform_decode,
2823-
num_scheduled_tokens_per_request=num_scheduled_tokens_np,
2818+
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
2819+
coordinate_batch_across_dp(
2820+
num_tokens_unpadded=num_tokens,
2821+
parallel_config=self.parallel_config,
2822+
allow_microbatching=allow_microbatching,
2823+
allow_dp_padding=allow_dp_padding,
2824+
num_tokens_padded=num_tokens_padded,
2825+
uniform_decode=uniform_decode,
2826+
num_scheduled_tokens_per_request=num_scheduled_tokens_np,
2827+
cudagraph_mode=cudagraph_mode.value,
2828+
)
28242829
)
28252830

2826-
# Extract DP padding if there is any
2831+
# Extract DP-synced values
28272832
if num_tokens_across_dp is not None:
28282833
dp_rank = self.parallel_config.data_parallel_rank
28292834
num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
2830-
2831-
# Re-dispatch with DP padding
2832-
cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded)
2835+
# Re-dispatch with DP padding so we have the correct batch_descriptor
2836+
cudagraph_mode, batch_descriptor = dispatch_cudagraph(
2837+
num_tokens_padded,
2838+
disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value,
2839+
)
28332840
# Assert to make sure the agreed-upon token count is correct; otherwise
28342841
# num_tokens_across_dp will no longer be valid
28352842
assert batch_descriptor.num_tokens == num_tokens_padded

0 commit comments

Comments
 (0)