@@ -40,16 +40,18 @@ def _run_ar(
     should_dp_pad: bool,
     orig_num_tokens_per_ubatch: int,
     padded_num_tokens_per_ubatch: int,
+    cudagraph_mode: int,
     parallel_config: ParallelConfig,
 ) -> torch.Tensor:
     dp_size = parallel_config.data_parallel_size
     dp_rank = parallel_config.data_parallel_rank
     device, group = _get_device_and_group(parallel_config)
-    tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
+    tensor = torch.zeros(5, dp_size, device=device, dtype=torch.int32)
     tensor[0][dp_rank] = orig_num_tokens_per_ubatch
     tensor[1][dp_rank] = padded_num_tokens_per_ubatch
     tensor[2][dp_rank] = 1 if should_ubatch else 0
     tensor[3][dp_rank] = 1 if should_dp_pad else 0
+    tensor[4][dp_rank] = cudagraph_mode
     dist.all_reduce(tensor, group=group)
     return tensor

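Packing per-rank scalars into a zeroed `(rows, dp_size)` matrix and summing with `dist.all_reduce` is effectively an all-gather: each rank contributes only its own column, so the SUM reduction leaves every rank holding the full per-rank table, with the new row 4 carrying each rank's `cudagraph_mode`. A minimal single-process sketch of that trick (the rank values are invented for illustration):

```python
import torch

dp_size = 2
rows = 5  # orig tokens, padded tokens, should_ubatch, should_dp_pad, cudagraph_mode

# Hypothetical per-rank values; in _run_ar each rank fills only its own column.
per_rank = [
    (8, 16, 1, 1, 2),  # rank 0 proposes cudagraph_mode=2 (FULL)
    (5, 16, 1, 1, 1),  # rank 1 proposes cudagraph_mode=1 (PIECEWISE)
]

contributions = []
for dp_rank, vals in enumerate(per_rank):
    t = torch.zeros(rows, dp_size, dtype=torch.int32)
    for row, value in enumerate(vals):
        t[row][dp_rank] = value
    contributions.append(t)

# dist.all_reduce defaults to SUM, so after the collective every rank holds the
# same matrix; adding the per-rank contributions reproduces that locally.
reduced = contributions[0] + contributions[1]
print(reduced[4])  # tensor([2, 1], dtype=torch.int32) -- one mode per rank
```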
@@ -89,13 +91,23 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch
     return num_tokens_across_dp.cpu()


+def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
+    """
+    Synchronize cudagraph_mode across DP ranks by taking the minimum.
+    If any rank has NONE (0), all ranks use NONE.
+    This ensures all ranks send consistent values (all padded or all unpadded).
+    """
+    return int(tensor[4, :].min().item())
+
+
 def _synchronize_dp_ranks(
     num_tokens_unpadded: int,
     num_tokens_padded: int,
     should_attempt_ubatching: bool,
     should_attempt_dp_padding: bool,
+    cudagraph_mode: int,
     parallel_config: ParallelConfig,
-) -> tuple[bool, torch.Tensor | None]:
+) -> tuple[bool, torch.Tensor | None, int]:
     """
     1. Decides if each DP rank is going to microbatch. Either all ranks
     run with microbatching or none of them do.
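To make the min-reduction concrete, here is a standalone check of `_post_process_cudagraph_mode`'s rule against a made-up reduced tensor (not real vLLM output):

```python
import torch

# Row 4 holds each rank's cudagraph_mode after the all-reduce:
# 0=NONE, 1=PIECEWISE, 2=FULL. Values are invented for the example.
reduced = torch.tensor(
    [[8, 5], [16, 16], [1, 1], [1, 1], [2, 1]], dtype=torch.int32
)
synced = int(reduced[4, :].min().item())
assert synced == 1  # the FULL (2) rank downgrades to PIECEWISE (1)
# Had any rank reported NONE (0), every rank would fall back to NONE.
```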
@@ -104,10 +116,13 @@ def _synchronize_dp_ranks(
     When running microbatched or if should_attempt_dp_padding is True, all
     ranks will be padded out so that they run with the same number of tokens.

+    3. Synchronizes cudagraph_mode across ranks by taking the minimum.
+
     Returns: tuple[
         should_ubatch: Are all DP ranks going to microbatch
         num_tokens_after_padding: A tensor containing the total number of
             tokens per-microbatch for each DP rank including any DP padding.
+        synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
     ]

     """
@@ -121,6 +136,7 @@ def _synchronize_dp_ranks(
         should_dp_pad=should_attempt_dp_padding,
         orig_num_tokens_per_ubatch=num_tokens_unpadded,
         padded_num_tokens_per_ubatch=num_tokens_padded,
+        cudagraph_mode=cudagraph_mode,
         parallel_config=parallel_config,
     )

@@ -148,7 +164,10 @@ def _synchronize_dp_ranks(
         should_dp_pad,
     )

-    return should_ubatch, num_tokens_after_padding
+    # Synchronize cudagraph_mode across ranks (take min)
+    synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)
+
+    return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode


 def coordinate_batch_across_dp(
@@ -159,7 +178,8 @@ def coordinate_batch_across_dp(
     num_tokens_padded: int | None = None,
     uniform_decode: bool | None = None,
     num_scheduled_tokens_per_request: np.ndarray | None = None,
-) -> tuple[bool, torch.Tensor | None]:
+    cudagraph_mode: int = 0,
+) -> tuple[bool, torch.Tensor | None, int]:
     """
     Coordinates amongst all DP ranks to determine if and how the full batch
     should be split into microbatches.
@@ -175,6 +195,7 @@ def coordinate_batch_across_dp(
             only contains single token decodes
         num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
             number of tokens per request.
+        cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL)

     Returns: tuple[
         ubatch_slices: if this is set then all DP ranks have agreed to
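The docstring fixes the integer encoding at 0=NONE, 1=PIECEWISE, 2=FULL; that ordering is what makes `min` a sensible consensus rule. A hypothetical `IntEnum` mirror of the encoding (the real vLLM enum may differ in name and members):

```python
from enum import IntEnum

# Hypothetical mirror of the 0/1/2 encoding documented above.
class CudagraphMode(IntEnum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2

# Modes are ordered by capability, so min() yields the most conservative mode
# that every DP rank can support.
assert min(CudagraphMode.FULL, CudagraphMode.NONE) is CudagraphMode.NONE
```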
@@ -183,12 +204,13 @@ def coordinate_batch_across_dp(
183204 tokens per-microbatch for each DP rank including padding. Will be
184205 padded up to the max value across all DP ranks when allow_dp_padding
185206 is True.
207+ synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
186208 ]
187209
188210 """
189211 if parallel_config .data_parallel_size == 1 :
190212 # Early exit.
191- return False , None
213+ return False , None , cudagraph_mode
192214
193215 # If the caller has explicitly enabled microbatching.
194216 should_attempt_ubatching = False
@@ -204,12 +226,15 @@ def coordinate_batch_across_dp(
     if num_tokens_padded is None:
         num_tokens_padded = num_tokens_unpadded

-    (should_ubatch, num_tokens_after_padding) = _synchronize_dp_ranks(
-        num_tokens_unpadded,
-        num_tokens_padded,
-        should_attempt_ubatching,
-        allow_dp_padding,
-        parallel_config,
+    (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode) = (
+        _synchronize_dp_ranks(
+            num_tokens_unpadded,
+            num_tokens_padded,
+            should_attempt_ubatching,
+            allow_dp_padding,
+            cudagraph_mode,
+            parallel_config,
+        )
     )

-    return (should_ubatch, num_tokens_after_padding)
+    return (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode)
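Callers of `coordinate_batch_across_dp` now unpack a 3-tuple instead of a 2-tuple. A toy stub (not the real implementation) illustrating the widened contract on the `data_parallel_size == 1` early-exit path:

```python
# Stand-in with the same return shape as the new coordinate_batch_across_dp;
# the multi-rank path needs a real process group and is omitted here.
def coordinate_batch_across_dp_stub(
    data_parallel_size: int, cudagraph_mode: int = 0
) -> tuple[bool, None, int]:
    if data_parallel_size == 1:
        # Single DP rank: nothing to synchronize, the local mode stands.
        return False, None, cudagraph_mode
    raise NotImplementedError("multi-rank path requires torch.distributed")

should_ubatch, num_tokens_after_padding, synced_mode = (
    coordinate_batch_across_dp_stub(data_parallel_size=1, cudagraph_mode=2)
)
assert (should_ubatch, num_tokens_after_padding, synced_mode) == (False, None, 2)
```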