diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index deb9b3ff918..9f589498a4d 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -90,6 +90,11 @@ def graph_safe_rng_available() -> bool:
     )
 
 
+def is_graph_safe_rng_state(state: Union[torch.Tensor, torch.Generator]) -> bool:
+    """Returns whether the RNG state is the graph-safe variant (a torch.Generator)."""
+    return graph_safe_rng_available() and isinstance(state, torch.Generator)
+
+
 def _get_cuda_rng_state(
     device: Union[int, str, torch.device] = "cuda",
     clone: bool = False,
@@ -340,9 +345,16 @@ def forward(
 
         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
-        ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
         if get_rng_state_tracker is not None:
             ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
+            ctx.graph_safe_rng_state = (
+                is_graph_safe_rng_state(next(iter(ctx.fwd_cuda_rng_state_tracker.values())))
+                if ctx.fwd_cuda_rng_state_tracker
+                else False
+            )
+        else:
+            ctx.graph_safe_rng_state = False
+        ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state)
 
         if context_fn is not None:
             forward_ctx, recompute_ctx = context_fn()
@@ -406,13 +418,13 @@ def backward(
 
         # Store the current states.
         bwd_cpu_rng_state = torch.get_rng_state()
-        bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
+        bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state)
         if get_rng_state_tracker is not None:
             bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
 
         # Set the states to what it used to be before the forward pass.
         torch.set_rng_state(ctx.fwd_cpu_rng_state)
-        _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False)
+        _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state)
         if get_rng_state_tracker is not None:
             get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
 
@@ -427,7 +439,7 @@ def backward(
 
         # Set the states back to what it was at the start of this function.
         torch.set_rng_state(bwd_cpu_rng_state)
-        _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False)
+        _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state)
         if get_rng_state_tracker is not None:
             get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker)
 
@@ -470,12 +482,21 @@ def __init__(self, recompute_fn: Callable, get_rng_state_tracker: Callable):
 
     def cache_rng_states(self, forward=True):
         """Cache fwd/bwd RNG states in the frame to restore later."""
-        rng_states = (
-            torch.get_rng_state(),
-            _get_cuda_rng_state(graph_safe=False),
-        )
+        rng_states = (torch.get_rng_state(),)
         if self.get_rng_state_tracker is not None:
-            rng_states += (self.get_rng_state_tracker().get_states(),)
+            tracker_states = self.get_rng_state_tracker().get_states()
+            self.graph_safe_rng_state = (
+                is_graph_safe_rng_state(next(iter(tracker_states.values())))
+                if tracker_states
+                else False
+            )
+            rng_states += (
+                _get_cuda_rng_state(graph_safe=self.graph_safe_rng_state),
+                tracker_states,
+            )
+        else:
+            self.graph_safe_rng_state = False
+            rng_states += (_get_cuda_rng_state(graph_safe=self.graph_safe_rng_state),)
 
         if forward:
             self.fwd_rng_states = rng_states
@@ -490,7 +511,7 @@ def restore_rng_states(self, forward=True):
             rng_states = self.bwd_rng_states
 
         torch.set_rng_state(rng_states[0])
-        _set_cuda_rng_state(rng_states[1], graph_safe=False)
+        _set_cuda_rng_state(rng_states[1], graph_safe=self.graph_safe_rng_state)
         if self.get_rng_state_tracker is not None:
             self.get_rng_state_tracker().set_states(rng_states[2])
 
diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index 92826735f98..1822c47d8bf 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -62,6 +62,21 @@ def graph_pool_handle():
     return _graph_pool_handle()
 
 
+@contextlib.contextmanager
+def _none_grad_context_wrapper(inputs):
+    """
+    Wrapper that temporarily sets the gradients of the inputs to None,
+    since the backward pass accumulates gradients instead of returning them.
+    """
+    original_input_grads = []
+    for input_tensor in inputs:
+        original_input_grads.append(input_tensor.grad)
+        input_tensor.grad = None
+    yield
+    for input_tensor, original_grad in zip(inputs, original_input_grads):
+        input_tensor.grad = original_grad
+
+
 @contextlib.contextmanager
 def _graph_context_wrapper(*args, **kwargs):
     """Wrapper around `torch.cuda.graph`.
@@ -434,13 +449,15 @@ def hook_fn(
                 for hook in hooks:
                     hook.remove()
                 if is_training:
-                    grad_inputs = torch.autograd.grad(
-                        outputs=tuple(o for o in outputs if o.requires_grad),
-                        inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                        grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
-                        only_inputs=True,
-                        allow_unused=allow_unused_input,
-                    )
+                    inputs = tuple(i for i in static_input_surface if i.requires_grad)
+                    with _none_grad_context_wrapper(inputs):
+                        torch.autograd.backward(
+                            tuple(o for o in outputs if o.requires_grad),
+                            grad_tensors=tuple(
+                                torch.empty_like(o) for o in outputs if o.requires_grad
+                            ),
+                        )
+                        grad_inputs = tuple(input.grad for input in inputs)
 
                     # Filter module params that get None grad from grad_inputs and remove them
                     # from static_input_surface. This is to ensure that the backward hooks
@@ -455,6 +472,14 @@ def hook_fn(
                     module_params_with_grad = []
                     for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx):
                         if (
+                            grad_inputs[grad_inputs_idx] is None
+                            and grad_inputs_idx < num_required_grad_sample_args
+                        ):
+                            assert allow_unused_input, (
+                                "The input tensor requires grad, but the grad is None after"
+                                " the backward pass."
+                            )
+                        elif (
                             grad_inputs[grad_inputs_idx] is not None
                             and grad_inputs_idx >= num_required_grad_sample_args
                         ):
@@ -606,15 +631,17 @@ def hook_fn(
                 torch.empty_like(o) if o.requires_grad else None for o in static_outputs
             )
             if is_training:
-                with _graph_context_wrapper(bwd_graph, pool=mempool):
-                    grad_inputs = torch.autograd.grad(
-                        outputs=tuple(o for o in static_outputs if o.requires_grad),
-                        inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                        grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
-                        only_inputs=True,
-                        allow_unused=allow_unused_input,
+                inputs = tuple(i for i in static_input_surface if i.requires_grad)
+                with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
+                    bwd_graph, pool=mempool
+                ):
+                    torch.autograd.backward(
+                        tuple(o for o in static_outputs if o.requires_grad),
+                        grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                         retain_graph=retain_graph_in_backward,
                     )
+                    grad_inputs = tuple(input.grad for input in inputs)
+
             # Constructs a tuple suitable for returning from Graphed.backward:
             # Pads out the actually-needed grads with Nones in gradient slots for inputs
             # that don't require grad. I couldn't think of a one-liner for this pattern.
@@ -695,15 +722,17 @@ def hook_fn(
                 torch.empty_like(o) if o.requires_grad else None for o in static_outputs
             )
             if is_training:
-                with _graph_context_wrapper(bwd_graph, pool=mempool):
-                    grad_inputs = torch.autograd.grad(
-                        outputs=tuple(o for o in static_outputs if o.requires_grad),
-                        inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                        grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
-                        only_inputs=True,
-                        allow_unused=allow_unused_input,
+                inputs = tuple(i for i in static_input_surface if i.requires_grad)
+                with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
+                    bwd_graph, pool=mempool
+                ):
+                    torch.autograd.backward(
+                        tuple(o for o in static_outputs if o.requires_grad),
+                        grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                         retain_graph=retain_graph_in_backward,
                     )
+                    grad_inputs = tuple(input.grad for input in inputs)
+
             if need_bwd_dw_graph[bwd_idx]:
                 with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
                     for module in visited_te_modules[bwd_idx]:
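Note on the `distributed.py` hunks: the checkpoint code now infers whether to use the graph-safe RNG path from the type of the state stored in the RNG tracker. The sketch below is a simplified, standalone illustration of that type-based dispatch; `infer_graph_safe`, `legacy_states`, and `graph_safe_states` are illustrative stand-ins that are not part of this change, and the simplified helper omits the `graph_safe_rng_available()` capability check that the real `is_graph_safe_rng_state` performs.

```python
import torch


def is_graph_safe_rng_state(state):
    # Simplified: the real helper in distributed.py also requires graph_safe_rng_available().
    return isinstance(state, torch.Generator)


# Illustrative stand-ins for the two kinds of state an RNG tracker may hold.
legacy_states = {"model-parallel-rng": torch.get_rng_state()}  # ByteTensor snapshot
graph_safe_states = {"model-parallel-rng": torch.Generator()}  # Generator object


def infer_graph_safe(tracker_states):
    # Mirrors the dispatch added in CheckpointFunction.forward: peek at one stored
    # state and pass the matching graph_safe flag to _get/_set_cuda_rng_state.
    return (
        is_graph_safe_rng_state(next(iter(tracker_states.values())))
        if tracker_states
        else False
    )


assert not infer_graph_safe(legacy_states)
assert infer_graph_safe(graph_safe_states)
assert not infer_graph_safe({})  # empty tracker falls back to the non-graph-safe path
```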
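Note on the `graph.py` hunks: `torch.autograd.grad` is replaced by `torch.autograd.backward`, which accumulates into each input's `.grad` rather than returning gradients. `_none_grad_context_wrapper` clears `.grad` first so the harvested gradients are not mixed with previously accumulated values, and restores the caller's gradients on exit; the gradients therefore have to be read back while still inside the wrapper. The snippet below is a minimal CPU-only sketch of that pattern under those assumptions; the `none_grad` helper re-states the wrapper's idea for the demo and is not the code from this diff.

```python
import contextlib

import torch


@contextlib.contextmanager
def none_grad(tensors):
    # Stash and clear .grad so backward() accumulates into fresh buffers,
    # then restore the caller's original gradients on exit.
    saved = [t.grad for t in tensors]
    for t in tensors:
        t.grad = None
    try:
        yield
    finally:
        for t, g in zip(tensors, saved):
            t.grad = g


x = torch.randn(4, requires_grad=True)
x.grad = torch.full((4,), 123.0)  # pre-existing gradient that must survive untouched
y = (x * 2).sum()

with none_grad([x]):
    # backward() accumulates into .grad; harvest the references before the
    # wrapper restores the original gradients on exit.
    torch.autograd.backward(y)
    grad_inputs = tuple(t.grad for t in [x])

assert torch.equal(grad_inputs[0], torch.full((4,), 2.0))  # freshly computed gradient
assert torch.equal(x.grad, torch.full((4,), 123.0))  # caller's gradient restored
```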