From 7bad846ae67d363213827d4a8c91b104d82d3bfd Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Sun, 14 Dec 2025 04:06:14 -0800 Subject: [PATCH 1/4] replace autograd.grad with autograd.backward Signed-off-by: Robin Zhang --- transformer_engine/pytorch/graph.py | 73 ++++++++++++++++++++--------- 1 file changed, 51 insertions(+), 22 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 92826735f98..b4f53ca457c 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -62,6 +62,21 @@ def graph_pool_handle(): return _graph_pool_handle() +@contextlib.contextmanager +def _none_grad_context_wrapper(inputs): + """ + Wrapper to set the gradients of the inputs to None, + in case the backward pass makes grad accumulations. + """ + original_input_grads = [] + for input in inputs: + original_input_grads.append(input.grad) + input.grad = None + yield + for input, original_grad in zip(inputs, original_input_grads): + input.grad = original_grad + + @contextlib.contextmanager def _graph_context_wrapper(*args, **kwargs): """Wrapper around `torch.cuda.graph`. @@ -434,13 +449,15 @@ def hook_fn( for hook in hooks: hook.remove() if is_training: - grad_inputs = torch.autograd.grad( - outputs=tuple(o for o in outputs if o.requires_grad), - inputs=tuple(i for i in static_input_surface if i.requires_grad), - grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad), - only_inputs=True, - allow_unused=allow_unused_input, - ) + inputs = tuple(i for i in static_input_surface if i.requires_grad) + with _none_grad_context_wrapper(inputs): + torch.autograd.backward( + tuple(o for o in outputs if o.requires_grad), + grad_tensors=tuple( + torch.empty_like(o) for o in outputs if o.requires_grad + ), + ) + grad_inputs = tuple(input.grad for input in inputs) # Filter module params that get None grad from grad_inputs and remove them # from static_input_surface. This is to ensure that the backward hooks @@ -454,7 +471,9 @@ def hook_fn( required_grad_input_idx.append(i) module_params_with_grad = [] for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx): - if ( + if grad_inputs[grad_inputs_idx] is None and grad_inputs_idx < num_required_grad_sample_args: + assert allow_unused_input, "The input tensor requires grad, but the grad is None after backward pass." 
+ elif ( grad_inputs[grad_inputs_idx] is not None and grad_inputs_idx >= num_required_grad_sample_args ): @@ -606,15 +625,21 @@ def hook_fn( torch.empty_like(o) if o.requires_grad else None for o in static_outputs ) if is_training: - with _graph_context_wrapper(bwd_graph, pool=mempool): - grad_inputs = torch.autograd.grad( - outputs=tuple(o for o in static_outputs if o.requires_grad), - inputs=tuple(i for i in static_input_surface if i.requires_grad), - grad_outputs=tuple(o for o in static_grad_outputs if o is not None), - only_inputs=True, - allow_unused=allow_unused_input, + inputs = tuple( + i for i in static_input_surface if i.requires_grad + ) + with _none_grad_context_wrapper(inputs), _graph_context_wrapper( + bwd_graph, pool=mempool + ): + torch.autograd.backward( + tuple(o for o in static_outputs if o.requires_grad), + grad_tensors=tuple( + o for o in static_grad_outputs if o is not None + ), retain_graph=retain_graph_in_backward, ) + grad_inputs = tuple(input.grad for input in inputs) + # Constructs a tuple suitable for returning from Graphed.backward: # Pads out the actually-needed grads with Nones in gradient slots for inputs # that don't require grad. I couldn't think of a one-liner for this pattern. @@ -695,15 +720,19 @@ def hook_fn( torch.empty_like(o) if o.requires_grad else None for o in static_outputs ) if is_training: - with _graph_context_wrapper(bwd_graph, pool=mempool): - grad_inputs = torch.autograd.grad( - outputs=tuple(o for o in static_outputs if o.requires_grad), - inputs=tuple(i for i in static_input_surface if i.requires_grad), - grad_outputs=tuple(o for o in static_grad_outputs if o is not None), - only_inputs=True, - allow_unused=allow_unused_input, + inputs = tuple(i for i in static_input_surface if i.requires_grad) + with _none_grad_context_wrapper(inputs), _graph_context_wrapper( + bwd_graph, pool=mempool + ): + torch.autograd.backward( + tuple(o for o in static_outputs if o.requires_grad), + grad_tensors=tuple( + o for o in static_grad_outputs if o is not None + ), retain_graph=retain_graph_in_backward, ) + grad_inputs = tuple(input.grad for input in inputs) + if need_bwd_dw_graph[bwd_idx]: with _graph_context_wrapper(bwd_dw_graph, pool=mempool): for module in visited_te_modules[bwd_idx]: From 4ee6870aa9ac026322705cee39176e78b59a4481 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Dec 2025 02:14:36 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/graph.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index b4f53ca457c..a8811032b8e 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -471,8 +471,14 @@ def hook_fn( required_grad_input_idx.append(i) module_params_with_grad = [] for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx): - if grad_inputs[grad_inputs_idx] is None and grad_inputs_idx < num_required_grad_sample_args: - assert allow_unused_input, "The input tensor requires grad, but the grad is None after backward pass." + if ( + grad_inputs[grad_inputs_idx] is None + and grad_inputs_idx < num_required_grad_sample_args + ): + assert allow_unused_input, ( + "The input tensor requires grad, but the grad is None after" + " backward pass." 
+ ) elif ( grad_inputs[grad_inputs_idx] is not None and grad_inputs_idx >= num_required_grad_sample_args @@ -625,17 +631,13 @@ def hook_fn( torch.empty_like(o) if o.requires_grad else None for o in static_outputs ) if is_training: - inputs = tuple( - i for i in static_input_surface if i.requires_grad - ) + inputs = tuple(i for i in static_input_surface if i.requires_grad) with _none_grad_context_wrapper(inputs), _graph_context_wrapper( bwd_graph, pool=mempool ): torch.autograd.backward( tuple(o for o in static_outputs if o.requires_grad), - grad_tensors=tuple( - o for o in static_grad_outputs if o is not None - ), + grad_tensors=tuple(o for o in static_grad_outputs if o is not None), retain_graph=retain_graph_in_backward, ) grad_inputs = tuple(input.grad for input in inputs) @@ -726,9 +728,7 @@ def hook_fn( ): torch.autograd.backward( tuple(o for o in static_outputs if o.requires_grad), - grad_tensors=tuple( - o for o in static_grad_outputs if o is not None - ), + grad_tensors=tuple(o for o in static_grad_outputs if o is not None), retain_graph=retain_graph_in_backward, ) grad_inputs = tuple(input.grad for input in inputs) From 1dfd60b04a500848567ad8db4584bbf3706d414f Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Mon, 15 Dec 2025 22:29:28 -0800 Subject: [PATCH 3/4] get/set graphable rng state Signed-off-by: Robin Zhang --- transformer_engine/pytorch/distributed.py | 41 +++++++++++++++++------ 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index deb9b3ff918..9f589498a4d 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -90,6 +90,11 @@ def graph_safe_rng_available() -> bool: ) +def is_graph_safe_rng_state(state: Union[torch.Tensor, torch.Generator]) -> bool: + """Returns whether the rng state is a graph safe version.""" + return graph_safe_rng_available() and isinstance(state, torch.Generator) + + def _get_cuda_rng_state( device: Union[int, str, torch.device] = "cuda", clone: bool = False, @@ -340,9 +345,16 @@ def forward( # Copy the rng states. ctx.fwd_cpu_rng_state = torch.get_rng_state() - ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False) if get_rng_state_tracker is not None: ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states() + ctx.graph_safe_rng_state = ( + is_graph_safe_rng_state(next(iter(ctx.fwd_cuda_rng_state_tracker.values()))) + if ctx.fwd_cuda_rng_state_tracker + else False + ) + else: + ctx.graph_safe_rng_state = False + ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state) if context_fn is not None: forward_ctx, recompute_ctx = context_fn() @@ -406,13 +418,13 @@ def backward( # Store the current states. bwd_cpu_rng_state = torch.get_rng_state() - bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False) + bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state) if get_rng_state_tracker is not None: bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states() # Set the states to what it used to be before the forward pass. torch.set_rng_state(ctx.fwd_cpu_rng_state) - _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False) + _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state) if get_rng_state_tracker is not None: get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) @@ -427,7 +439,7 @@ def backward( # Set the states back to what it was at the start of this function. 
torch.set_rng_state(bwd_cpu_rng_state) - _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False) + _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state) if get_rng_state_tracker is not None: get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker) @@ -470,12 +482,21 @@ def __init__(self, recompute_fn: Callable, get_rng_state_tracker: Callable): def cache_rng_states(self, forward=True): """Cache fwd/bwd RNG states in the frame to restore later.""" - rng_states = ( - torch.get_rng_state(), - _get_cuda_rng_state(graph_safe=False), - ) + rng_states = (torch.get_rng_state(),) if self.get_rng_state_tracker is not None: - rng_states += (self.get_rng_state_tracker().get_states(),) + tracker_states = self.get_rng_state_tracker().get_states() + self.graph_safe_rng_state = ( + is_graph_safe_rng_state(next(iter(tracker_states.values()))) + if tracker_states + else False + ) + rng_states += ( + _get_cuda_rng_state(graph_safe=self.graph_safe_rng_state), + tracker_states, + ) + else: + self.graph_safe_rng_state = False + rng_states += (_get_cuda_rng_state(graph_safe=self.graph_safe_rng_state),) if forward: self.fwd_rng_states = rng_states @@ -490,7 +511,7 @@ def restore_rng_states(self, forward=True): rng_states = self.bwd_rng_states torch.set_rng_state(rng_states[0]) - _set_cuda_rng_state(rng_states[1], graph_safe=False) + _set_cuda_rng_state(rng_states[1], graph_safe=self.graph_safe_rng_state) if self.get_rng_state_tracker is not None: self.get_rng_state_tracker().set_states(rng_states[2]) From de9392005dbacdf35bf0db9e33d541e4b370bc72 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Mon, 15 Dec 2025 22:35:35 -0800 Subject: [PATCH 4/4] fix lint Signed-off-by: Robin Zhang --- transformer_engine/pytorch/graph.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index a8811032b8e..1822c47d8bf 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -69,12 +69,12 @@ def _none_grad_context_wrapper(inputs): in case the backward pass makes grad accumulations. """ original_input_grads = [] - for input in inputs: - original_input_grads.append(input.grad) - input.grad = None + for input_tensor in inputs: + original_input_grads.append(input_tensor.grad) + input_tensor.grad = None yield - for input, original_grad in zip(inputs, original_input_grads): - input.grad = original_grad + for input_tensor, original_grad in zip(inputs, original_input_grads): + input_tensor.grad = original_grad @contextlib.contextmanager
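
For illustration only (not part of the patches above): a minimal, self-contained sketch of the torch.autograd.grad -> torch.autograd.backward replacement that PATCH 1/4 applies inside the graph-capture path. The helper and tensor names here (none_grad, x, w, y) are invented for this sketch; the real change uses _none_grad_context_wrapper over the static input surface and reads input.grad to build grad_inputs.

    import contextlib
    import torch


    @contextlib.contextmanager
    def none_grad(tensors):
        # Temporarily clear .grad so backward() writes fresh gradients instead
        # of accumulating into pre-existing ones; restore the old .grad on exit.
        saved = [t.grad for t in tensors]
        for t in tensors:
            t.grad = None
        try:
            yield
        finally:
            for t, g in zip(tensors, saved):
                t.grad = g


    x = torch.randn(4, requires_grad=True)
    w = torch.randn(4, requires_grad=True)
    y = (x * w).sum()

    inputs = (x, w)
    with none_grad(inputs):
        # torch.autograd.backward accumulates into .grad, unlike
        # torch.autograd.grad, which returns the gradients directly.
        torch.autograd.backward((y,), grad_tensors=(torch.ones_like(y),))
        # Read the gradients before the wrapper restores the original .grad.
        grad_inputs = tuple(t.grad for t in inputs)

    print([g.shape for g in grad_inputs])

The same collect-from-.grad pattern is what the captured backward graphs above rely on, which is why the wrapper clears any pre-existing gradient accumulation before replay and restores it afterwards.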