
Commit 324be33

buptzyb and ksivaman authored
[PyTorch] Support cudagraph recomputation (#2518)
* replace autograd.grad with autograd.backward

  Signed-off-by: Robin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* get/set graphable rng state

  Signed-off-by: Robin Zhang <robinz@nvidia.com>

* fix lint

  Signed-off-by: Robin Zhang <robinz@nvidia.com>

---------

Signed-off-by: Robin Zhang <robinz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent 697b52c commit 324be33
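The first bullet swaps torch.autograd.grad for torch.autograd.backward. The two are interchangeable for this purpose: autograd.grad returns the gradients directly, while autograd.backward writes them into the leaves' .grad attributes, which is what the updated capture code then reads via input.grad. A minimal standalone sketch (not part of the commit; tensor names are illustrative only):

    import torch

    x = torch.randn(4, requires_grad=True)
    w = torch.randn(4, requires_grad=True)
    y = (x * w).sum()

    # torch.autograd.grad hands the gradients back directly ...
    gx, gw = torch.autograd.grad(y, (x, w), retain_graph=True)

    # ... while torch.autograd.backward accumulates them into .grad instead.
    x.grad, w.grad = None, None
    torch.autograd.backward(y)
    assert torch.equal(gx, x.grad) and torch.equal(gw, w.grad)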

File tree

2 files changed, +81 -31 lines changed


transformer_engine/pytorch/distributed.py

Lines changed: 31 additions & 10 deletions
@@ -90,6 +90,11 @@ def graph_safe_rng_available() -> bool:
     )


+def is_graph_safe_rng_state(state: Union[torch.Tensor, torch.Generator]) -> bool:
+    """Returns whether the rng state is a graph safe version."""
+    return graph_safe_rng_available() and isinstance(state, torch.Generator)
+
+
 def _get_cuda_rng_state(
     device: Union[int, str, torch.device] = "cuda",
     clone: bool = False,
@@ -340,9 +345,16 @@ def forward(

         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
-        ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
         if get_rng_state_tracker is not None:
             ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
+            ctx.graph_safe_rng_state = (
+                is_graph_safe_rng_state(next(iter(ctx.fwd_cuda_rng_state_tracker.values())))
+                if ctx.fwd_cuda_rng_state_tracker
+                else False
+            )
+        else:
+            ctx.graph_safe_rng_state = False
+        ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state)

         if context_fn is not None:
             forward_ctx, recompute_ctx = context_fn()
@@ -406,13 +418,13 @@ def backward(

         # Store the current states.
         bwd_cpu_rng_state = torch.get_rng_state()
-        bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
+        bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state)
         if get_rng_state_tracker is not None:
             bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()

         # Set the states to what it used to be before the forward pass.
         torch.set_rng_state(ctx.fwd_cpu_rng_state)
-        _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False)
+        _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state)
         if get_rng_state_tracker is not None:
             get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

@@ -427,7 +439,7 @@ def backward(

         # Set the states back to what it was at the start of this function.
         torch.set_rng_state(bwd_cpu_rng_state)
-        _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False)
+        _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state)
         if get_rng_state_tracker is not None:
             get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker)

@@ -470,12 +482,21 @@ def __init__(self, recompute_fn: Callable, get_rng_state_tracker: Callable):

     def cache_rng_states(self, forward=True):
         """Cache fwd/bwd RNG states in the frame to restore later."""
-        rng_states = (
-            torch.get_rng_state(),
-            _get_cuda_rng_state(graph_safe=False),
-        )
+        rng_states = (torch.get_rng_state(),)
         if self.get_rng_state_tracker is not None:
-            rng_states += (self.get_rng_state_tracker().get_states(),)
+            tracker_states = self.get_rng_state_tracker().get_states()
+            self.graph_safe_rng_state = (
+                is_graph_safe_rng_state(next(iter(tracker_states.values())))
+                if tracker_states
+                else False
+            )
+            rng_states += (
+                _get_cuda_rng_state(graph_safe=self.graph_safe_rng_state),
+                tracker_states,
+            )
+        else:
+            self.graph_safe_rng_state = False
+            rng_states += (_get_cuda_rng_state(graph_safe=self.graph_safe_rng_state),)

         if forward:
             self.fwd_rng_states = rng_states
@@ -490,7 +511,7 @@ def restore_rng_states(self, forward=True):
             rng_states = self.bwd_rng_states

         torch.set_rng_state(rng_states[0])
-        _set_cuda_rng_state(rng_states[1], graph_safe=False)
+        _set_cuda_rng_state(rng_states[1], graph_safe=self.graph_safe_rng_state)
         if self.get_rng_state_tracker is not None:
             self.get_rng_state_tracker().set_states(rng_states[2])
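For context on the new helper (not part of the diff): a graph-safe RNG state is a live torch.Generator object, whereas the legacy path snapshots the state into a ByteTensor, so is_graph_safe_rng_state keys on the type of the tracker's stored states in addition to the graph_safe_rng_available() check. A CPU-only sketch of the same distinction, with a hypothetical helper name:

    import torch

    def looks_graph_safe(state) -> bool:
        # Hypothetical stand-in for is_graph_safe_rng_state(); the real helper
        # additionally requires graph_safe_rng_available() to return True.
        return isinstance(state, torch.Generator)

    assert not looks_graph_safe(torch.get_rng_state())   # ByteTensor snapshot
    assert looks_graph_safe(torch.default_generator)     # live Generator object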

transformer_engine/pytorch/graph.py

Lines changed: 50 additions & 21 deletions
@@ -62,6 +62,21 @@ def graph_pool_handle():
     return _graph_pool_handle()


+@contextlib.contextmanager
+def _none_grad_context_wrapper(inputs):
+    """
+    Wrapper to set the gradients of the inputs to None,
+    in case the backward pass makes grad accumulations.
+    """
+    original_input_grads = []
+    for input_tensor in inputs:
+        original_input_grads.append(input_tensor.grad)
+        input_tensor.grad = None
+    yield
+    for input_tensor, original_grad in zip(inputs, original_input_grads):
+        input_tensor.grad = original_grad
+
+
 @contextlib.contextmanager
 def _graph_context_wrapper(*args, **kwargs):
     """Wrapper around `torch.cuda.graph`.
@@ -434,13 +449,15 @@ def hook_fn(
             for hook in hooks:
                 hook.remove()
             if is_training:
-                grad_inputs = torch.autograd.grad(
-                    outputs=tuple(o for o in outputs if o.requires_grad),
-                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                    grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
-                    only_inputs=True,
-                    allow_unused=allow_unused_input,
-                )
+                inputs = tuple(i for i in static_input_surface if i.requires_grad)
+                with _none_grad_context_wrapper(inputs):
+                    torch.autograd.backward(
+                        tuple(o for o in outputs if o.requires_grad),
+                        grad_tensors=tuple(
+                            torch.empty_like(o) for o in outputs if o.requires_grad
+                        ),
+                    )
+                    grad_inputs = tuple(input.grad for input in inputs)

                 # Filter module params that get None grad from grad_inputs and remove them
                 # from static_input_surface. This is to ensure that the backward hooks
@@ -455,6 +472,14 @@ def hook_fn(
                 module_params_with_grad = []
                 for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx):
                     if (
+                        grad_inputs[grad_inputs_idx] is None
+                        and grad_inputs_idx < num_required_grad_sample_args
+                    ):
+                        assert allow_unused_input, (
+                            "The input tensor requires grad, but the grad is None after"
+                            " backward pass."
+                        )
+                    elif (
                         grad_inputs[grad_inputs_idx] is not None
                         and grad_inputs_idx >= num_required_grad_sample_args
                     ):
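The new branch handles inputs that request gradients but receive none: with the .grad-based collection, an input that never influences the outputs simply keeps grad is None after backward, the case torch.autograd.grad would only tolerate with allow_unused=True. A tiny illustration (not from the commit):

    import torch

    used = torch.randn(2, requires_grad=True)
    unused = torch.randn(2, requires_grad=True)   # requires grad, never used

    used.grad, unused.grad = None, None
    torch.autograd.backward((3.0 * used).sum())
    assert used.grad is not None
    assert unused.grad is None   # torch.autograd.grad would need allow_unused=True here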
@@ -606,15 +631,17 @@ def hook_fn(
             torch.empty_like(o) if o.requires_grad else None for o in static_outputs
         )
         if is_training:
-            with _graph_context_wrapper(bwd_graph, pool=mempool):
-                grad_inputs = torch.autograd.grad(
-                    outputs=tuple(o for o in static_outputs if o.requires_grad),
-                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                    grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
-                    only_inputs=True,
-                    allow_unused=allow_unused_input,
+            inputs = tuple(i for i in static_input_surface if i.requires_grad)
+            with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
+                bwd_graph, pool=mempool
+            ):
+                torch.autograd.backward(
+                    tuple(o for o in static_outputs if o.requires_grad),
+                    grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                     retain_graph=retain_graph_in_backward,
                 )
+                grad_inputs = tuple(input.grad for input in inputs)
+
         # Constructs a tuple suitable for returning from Graphed.backward:
         # Pads out the actually-needed grads with Nones in gradient slots for inputs
         # that don't require grad. I couldn't think of a one-liner for this pattern.
@@ -695,15 +722,17 @@ def hook_fn(
             torch.empty_like(o) if o.requires_grad else None for o in static_outputs
         )
         if is_training:
-            with _graph_context_wrapper(bwd_graph, pool=mempool):
-                grad_inputs = torch.autograd.grad(
-                    outputs=tuple(o for o in static_outputs if o.requires_grad),
-                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                    grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
-                    only_inputs=True,
-                    allow_unused=allow_unused_input,
+            inputs = tuple(i for i in static_input_surface if i.requires_grad)
+            with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
+                bwd_graph, pool=mempool
+            ):
+                torch.autograd.backward(
+                    tuple(o for o in static_outputs if o.requires_grad),
+                    grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                     retain_graph=retain_graph_in_backward,
                 )
+                grad_inputs = tuple(input.grad for input in inputs)
+
         if need_bwd_dw_graph[bwd_idx]:
             with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
                 for module in visited_te_modules[bwd_idx]:
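For orientation (not part of the diff): capturing the backward pass this way follows the standard torch.cuda.graph recipe, i.e. warm up on a side stream, set .grad to None so capture allocates the gradient buffers from the graph's private pool, then read those static buffers after each replay. A bare-bones sketch assuming a CUDA device and no Transformer Engine modules:

    import torch

    assert torch.cuda.is_available()
    x = torch.randn(8, device="cuda", requires_grad=True)
    w = torch.randn(8, device="cuda", requires_grad=True)

    # Warm up off the default stream before capturing.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            x.grad, w.grad = None, None
            torch.autograd.backward((x * w).sum())
    torch.cuda.current_stream().wait_stream(s)

    # Capture one forward+backward iteration; .grad becomes a static buffer.
    g = torch.cuda.CUDAGraph()
    x.grad, w.grad = None, None
    with torch.cuda.graph(g):
        torch.autograd.backward((x * w).sum())

    g.replay()                     # rerun the captured kernels
    print(x.grad, w.grad)          # gradients land in the static .grad buffers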
