@@ -62,6 +62,21 @@ def graph_pool_handle():
     return _graph_pool_handle()
 
 
+@contextlib.contextmanager
+def _none_grad_context_wrapper(inputs):
+    """
+    Wrapper that sets the gradients of the inputs to None,
+    in case the backward pass accumulates gradients.
+    """
+    original_input_grads = []
+    for input_tensor in inputs:
+        original_input_grads.append(input_tensor.grad)
+        input_tensor.grad = None
+    yield
+    for input_tensor, original_grad in zip(inputs, original_input_grads):
+        input_tensor.grad = original_grad
+
+
 @contextlib.contextmanager
 def _graph_context_wrapper(*args, **kwargs):
     """Wrapper around `torch.cuda.graph`.
@@ -434,13 +449,15 @@ def hook_fn(
             for hook in hooks:
                 hook.remove()
             if is_training:
-                grad_inputs = torch.autograd.grad(
-                    outputs=tuple(o for o in outputs if o.requires_grad),
-                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                    grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
-                    only_inputs=True,
-                    allow_unused=allow_unused_input,
-                )
+                inputs = tuple(i for i in static_input_surface if i.requires_grad)
+                with _none_grad_context_wrapper(inputs):
+                    torch.autograd.backward(
+                        tuple(o for o in outputs if o.requires_grad),
+                        grad_tensors=tuple(
+                            torch.empty_like(o) for o in outputs if o.requires_grad
+                        ),
+                    )
+                grad_inputs = tuple(input.grad for input in inputs)
 
             # Filter module params that get None grad from grad_inputs and remove them
             # from static_input_surface. This is to ensure that the backward hooks
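
The swap from `torch.autograd.grad` to `torch.autograd.backward` changes how the gradients come back: `grad` returns them as a tuple in input order and leaves `.grad` untouched, while `backward` accumulates them into each leaf's `.grad`, which the new code then reads back. A small hedged sketch with toy tensors (not this module's API) showing the two call styles yield identical gradients:

```python
import torch

x = torch.randn(4, requires_grad=True)
w = torch.randn(4, requires_grad=True)
out = (x * w).sum()

# Old style: gradients returned positionally; .grad is left untouched.
gx, gw = torch.autograd.grad(
    outputs=(out,),
    inputs=(x, w),
    grad_outputs=(torch.ones_like(out),),
    retain_graph=True,  # keep the graph alive for the second backward below
)
assert x.grad is None and w.grad is None

# New style: gradients accumulated into .grad and read back afterwards.
torch.autograd.backward((out,), grad_tensors=(torch.ones_like(out),))
assert torch.equal(gx, x.grad) and torch.equal(gw, w.grad)
```
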
@@ -455,6 +472,14 @@ def hook_fn(
             module_params_with_grad = []
             for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx):
                 if (
+                    grad_inputs[grad_inputs_idx] is None
+                    and grad_inputs_idx < num_required_grad_sample_args
+                ):
+                    assert allow_unused_input, (
+                        "The input tensor requires grad, but the grad is None after"
+                        " the backward pass."
+                    )
+                elif (
                     grad_inputs[grad_inputs_idx] is not None
                     and grad_inputs_idx >= num_required_grad_sample_args
                 ):
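
The new `is None` branch compensates for a flag that `torch.autograd.backward` does not have: `torch.autograd.grad` only tolerates inputs unreachable from the outputs when `allow_unused=True` is passed, whereas `backward` silently leaves such a leaf's `.grad` as `None`. The assert re-imposes the old contract by requiring `allow_unused_input` whenever a sample argument ends up with no gradient. A toy sketch of the underlying behavior (names here are illustrative, not from the module):

```python
import torch

used = torch.randn(2, requires_grad=True)
unused = torch.randn(2, requires_grad=True)  # never contributes to the output
out = (3 * used).sum()

# torch.autograd.grad() raises for `unused` unless allow_unused=True is passed ...
grads = torch.autograd.grad((out,), (used, unused), allow_unused=True, retain_graph=True)
assert grads[1] is None

# ... while torch.autograd.backward() has no such flag and simply leaves .grad as None,
# which is the condition the new assert above checks against allow_unused_input.
torch.autograd.backward((out,))
assert used.grad is not None and unused.grad is None
```
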
@@ -606,15 +631,17 @@ def hook_fn(
             torch.empty_like(o) if o.requires_grad else None for o in static_outputs
         )
         if is_training:
-            with _graph_context_wrapper(bwd_graph, pool=mempool):
-                grad_inputs = torch.autograd.grad(
-                    outputs=tuple(o for o in static_outputs if o.requires_grad),
-                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                    grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
-                    only_inputs=True,
-                    allow_unused=allow_unused_input,
+            inputs = tuple(i for i in static_input_surface if i.requires_grad)
+            with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
+                bwd_graph, pool=mempool
+            ):
+                torch.autograd.backward(
+                    tuple(o for o in static_outputs if o.requires_grad),
+                    grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                     retain_graph=retain_graph_in_backward,
                 )
+            grad_inputs = tuple(input.grad for input in inputs)
+
         # Constructs a tuple suitable for returning from Graphed.backward:
         # Pads out the actually-needed grads with Nones in gradient slots for inputs
         # that don't require grad. I couldn't think of a one-liner for this pattern.
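
In both capture-time hunks the grad-nulling wrapper and the graph-capture wrapper share a single `with` statement. Context managers chained this way enter left to right and exit right to left, so `.grad` is cleared before capture starts and restored only after capture ends; the nulling and restoration themselves are not recorded in the captured graph. A tiny sketch (illustrative names only) of that ordering:

```python
import contextlib


@contextlib.contextmanager
def _traced(label, log):
    log.append(f"enter {label}")
    yield
    log.append(f"exit {label}")


log = []
with _traced("none_grad", log), _traced("graph_capture", log):
    log.append("capture body")

assert log == [
    "enter none_grad",
    "enter graph_capture",
    "capture body",
    "exit graph_capture",
    "exit none_grad",
]
```
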
@@ -695,15 +722,17 @@ def hook_fn(
             torch.empty_like(o) if o.requires_grad else None for o in static_outputs
         )
         if is_training:
-            with _graph_context_wrapper(bwd_graph, pool=mempool):
-                grad_inputs = torch.autograd.grad(
-                    outputs=tuple(o for o in static_outputs if o.requires_grad),
-                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                    grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
-                    only_inputs=True,
-                    allow_unused=allow_unused_input,
+            inputs = tuple(i for i in static_input_surface if i.requires_grad)
+            with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
+                bwd_graph, pool=mempool
+            ):
+                torch.autograd.backward(
+                    tuple(o for o in static_outputs if o.requires_grad),
+                    grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                     retain_graph=retain_graph_in_backward,
                 )
+            grad_inputs = tuple(input.grad for input in inputs)
+
         if need_bwd_dw_graph[bwd_idx]:
             with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
                 for module in visited_te_modules[bwd_idx]: