diff --git a/recipes_source/distributed_async_checkpoint_recipe.rst b/recipes_source/distributed_async_checkpoint_recipe.rst index e959883a25b..e89644a294a 100644 --- a/recipes_source/distributed_async_checkpoint_recipe.rst +++ b/recipes_source/distributed_async_checkpoint_recipe.rst @@ -1,7 +1,7 @@ Asynchronous Saving with Distributed Checkpoint (DCP) ===================================================== -**Author:** `Lucas Pasqualin `__, `Iris Zhang `__, `Rodrigo Kumpera `__, `Chien-Chin Huang `__ +**Author:** `Lucas Pasqualin `__, `Iris Zhang `__, `Rodrigo Kumpera `__, `Chien-Chin Huang `__, `Yunsheng Ni `__ Checkpointing is often a bottle-neck in the critical path for distributed training workloads, incurring larger and larger costs as both model and world sizes grow. One excellent strategy for offsetting this cost is to checkpoint in parallel, asynchronously. Below, we expand the save example @@ -279,6 +279,154 @@ checkpoint requests users can take advantage of direct memory access to speed up ) +Fully Asynchronous Staging with DefaultStager +-------------------------------------------- + +.. versionadded:: 2.9 + The ``async_stager`` argument and ``DefaultStager`` class were introduced in PyTorch 2.9. + +While ``async_save`` handles the disk write asynchronously, the process of copying data from GPU to CPU (known as "staging") typically happens on the main thread. Even with Pinned Memory, this Device-to-Host (D2H) copy can block the training loop for large models. + +To achieve maximum overlap between computation and checkpointing, we can use the ``DefaultStager``. This component offloads the state dictionary creation and the D2H copy to a background thread. + +**Timeline Comparison:** + +* **Standard async_save:** ``[GPU Compute] -> [CPU Copy (Blocking)] -> [Disk Write (Async)]`` +* **With AsyncStager:** ``[GPU Compute] || [CPU Copy (Async)] -> [Disk Write (Async)]`` + +.. note:: + Using ``AsyncStager`` introduces a background thread that consumes CPU resources. Ensure your environment has sufficient CPU cores to handle this without impacting the main training process. + +.. code-block:: python + + import os + + import torch + import torch.distributed as dist + import torch.distributed.checkpoint as dcp + import torch.multiprocessing as mp + import torch.nn as nn + + from torch.distributed.fsdp import fully_shard + from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict + from torch.distributed.checkpoint.stateful import Stateful + from torch.distributed.checkpoint.staging import DefaultStager + from torch.nn.modules.linear import NonDynamicallyQuantizableLinear + + CHECKPOINT_DIR = "checkpoint" + + + class AppState(Stateful): + """This is a useful wrapper for checkpointing the Application State. Since this object is compliant + with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the + dcp.save/load APIs. + + Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model + and optimizer. + """ + + def __init__(self, model, optimizer=None): + self.model = model + self.optimizer = optimizer + + def state_dict(self): + # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT + model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer) + return { + "model": model_state_dict, + "optim": optimizer_state_dict + } + + def load_state_dict(self, state_dict): + # sets our state dicts on the model and optimizer, now that we've loaded + set_state_dict( + self.model, + self.optimizer, + model_state_dict=state_dict["model"], + optim_state_dict=state_dict["optim"] + ) + + class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + self.net1 = nn.Linear(16, 16) + self.relu = nn.ReLU() + self.net2 = nn.Linear(16, 8) + + def forward(self, x): + return self.net2(self.relu(self.net1(x))) + + + def setup(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355 " + + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + + + def cleanup(): + dist.destroy_process_group() + + + def run_fsdp_checkpoint_save_example(rank, world_size): + print(f"Running basic FSDP checkpoint saving example on rank {rank}.") + setup(rank, world_size) + + # create a model and move it to GPU with id rank + model = ToyModel().to(rank) + model = fully_shard(model) + + loss_fn = nn.MSELoss() + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + + checkpoint_future = None + for step in range(10): + print(f"Step {step} starting...") + optimizer.zero_grad() + model(torch.rand(8, 16, device="cuda")).sum().backward() + + # Critical: We must ensure the previous checkpoint's D2H copy (staging) + # is complete before the optimizer modifies the model parameters. + # Placing this await AFTER the backward pass allows us to overlap + # the D2H copy with the current step's Forward and Backward computation. + if checkpoint_future is not None: + checkpoint_future.staging_completion.result() + optimizer.step() + + # waits for checkpointing to finish if one exists, avoiding queuing more then one checkpoint request at a time + if checkpoint_future is not None: + checkpoint_future.upload_completion.result() + + state_dict = { "app": AppState(model, optimizer) } + + # Pass the DefaultStager to enable fully asynchronous staging. + # This offloads the state_dict creation and GPU-to-CPU copy to a background thread. + # The return object (AsyncSaveResponse) exposes distinct futures for staging and upload. + checkpoint_future = dcp.async_save( + state_dict, + checkpoint_id=f"{CHECKPOINT_DIR}_step{step}", + async_stager=DefaultStager(), + ) + + # Ensure the last checkpoint completes + if checkpoint_future: + checkpoint_future.upload_completion.result() + + cleanup() + + + if __name__ == "__main__": + world_size = torch.cuda.device_count() + print(f"Running async checkpoint example on {world_size} devices.") + mp.spawn( + run_fsdp_checkpoint_save_example, + args=(world_size,), + nprocs=world_size, + join=True, + ) + Conclusion ---------- In conclusion, we have learned how to use DCP's :func:`async_save` API to generate checkpoints off the critical training path. We've also learned about the