Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 149 additions & 1 deletion recipes_source/distributed_async_checkpoint_recipe.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Asynchronous Saving with Distributed Checkpoint (DCP)
=====================================================

**Author:** `Lucas Pasqualin <https://github.com/lucasllc>`__, `Iris Zhang <https://github.com/wz337>`__, `Rodrigo Kumpera <https://github.com/kumpera>`__, `Chien-Chin Huang <https://github.com/fegin>`__
**Author:** `Lucas Pasqualin <https://github.com/lucasllc>`__, `Iris Zhang <https://github.com/wz337>`__, `Rodrigo Kumpera <https://github.com/kumpera>`__, `Chien-Chin Huang <https://github.com/fegin>`__, `Yunsheng Ni <https://github.com/niyunsheng>`__

Checkpointing is often a bottle-neck in the critical path for distributed training workloads, incurring larger and larger costs as both model and world sizes grow.
One excellent strategy for offsetting this cost is to checkpoint in parallel, asynchronously. Below, we expand the save example
Expand Down Expand Up @@ -279,6 +279,154 @@ checkpoint requests users can take advantage of direct memory access to speed up
)


Fully Asynchronous Staging with DefaultStager
--------------------------------------------

.. versionadded:: 2.9
The ``async_stager`` argument and ``DefaultStager`` class were introduced in PyTorch 2.9.

While ``async_save`` handles the disk write asynchronously, the process of copying data from GPU to CPU (known as "staging") typically happens on the main thread. Even with Pinned Memory, this Device-to-Host (D2H) copy can block the training loop for large models.

To achieve maximum overlap between computation and checkpointing, we can use the ``DefaultStager``. This component offloads the state dictionary creation and the D2H copy to a background thread.

**Timeline Comparison:**

* **Standard async_save:** ``[GPU Compute] -> [CPU Copy (Blocking)] -> [Disk Write (Async)]``
* **With AsyncStager:** ``[GPU Compute] || [CPU Copy (Async)] -> [Disk Write (Async)]``

.. note::
Using ``AsyncStager`` introduces a background thread that consumes CPU resources. Ensure your environment has sufficient CPU cores to handle this without impacting the main training process.

.. code-block:: python

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.staging import DefaultStager
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
"""This is a useful wrapper for checkpointing the Application State. Since this object is compliant
with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
dcp.save/load APIs.

Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
and optimizer.
"""

def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer

def state_dict(self):
# this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {
"model": model_state_dict,
"optim": optimizer_state_dict
}

def load_state_dict(self, state_dict):
# sets our state dicts on the model and optimizer, now that we've loaded
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"]
)

class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(16, 16)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 8)

def forward(self, x):
return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355 "

# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)


def cleanup():
dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
setup(rank, world_size)

# create a model and move it to GPU with id rank
model = ToyModel().to(rank)
model = fully_shard(model)

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

checkpoint_future = None
for step in range(10):
print(f"Step {step} starting...")
optimizer.zero_grad()
model(torch.rand(8, 16, device="cuda")).sum().backward()

# Critical: We must ensure the previous checkpoint's D2H copy (staging)
# is complete before the optimizer modifies the model parameters.
# Placing this await AFTER the backward pass allows us to overlap
# the D2H copy with the current step's Forward and Backward computation.
if checkpoint_future is not None:
checkpoint_future.staging_completion.result()
optimizer.step()

# waits for checkpointing to finish if one exists, avoiding queuing more then one checkpoint request at a time
if checkpoint_future is not None:
checkpoint_future.upload_completion.result()

state_dict = { "app": AppState(model, optimizer) }

# Pass the DefaultStager to enable fully asynchronous staging.
# This offloads the state_dict creation and GPU-to-CPU copy to a background thread.
# The return object (AsyncSaveResponse) exposes distinct futures for staging and upload.
checkpoint_future = dcp.async_save(
state_dict,
checkpoint_id=f"{CHECKPOINT_DIR}_step{step}",
async_stager=DefaultStager(),
)

# Ensure the last checkpoint completes
if checkpoint_future:
checkpoint_future.upload_completion.result()

cleanup()


if __name__ == "__main__":
world_size = torch.cuda.device_count()
print(f"Running async checkpoint example on {world_size} devices.")
mp.spawn(
run_fsdp_checkpoint_save_example,
args=(world_size,),
nprocs=world_size,
join=True,
)

Conclusion
----------
In conclusion, we have learned how to use DCP's :func:`async_save` API to generate checkpoints off the critical training path. We've also learned about the
Expand Down