Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 46 additions & 11 deletions transformer_engine/pytorch/cpu_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .quantized_tensor import (
restore_from_saved,
prepare_for_saving,
QuantizedTensor,
)


Expand Down Expand Up @@ -255,6 +256,8 @@ def start_offload(self):
Start offloading of tensors. Puts copy from GPU to CPU tasks on offload stream.
Before each copy event, the offload stream waits for the event signalling that the tensor is ready to be offloaded.
This event is recorded in the start_offload or push_tensor call.

Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor).
"""
self._validate_state(func_name="start_offload", allowed_states=["not_offloaded"])
self.state = "offload_started"
Expand All @@ -275,19 +278,18 @@ def start_offload(self):

with torch.cuda.stream(self.offload_stream):
if allocate_cpu_buffers:
# empty_like is defined also for QuantizedTensors
offloaded_tensor = torch.empty_like(
tensor, device=torch.device("cpu"), pin_memory=True
)
self.cpu_tensor_group.tensor_list.append(offloaded_tensor)
else:
assert self.cpu_tensor_group.tensor_list[tensor_id].shape == tensor.shape, (
offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id]
assert offloaded_tensor.shape == tensor.shape, (
"CPU buffer shape does not match the offloaded tensor shape:"
f" {self.cpu_tensor_group.tensor_list[tensor_id].shape} != {tensor.shape} "
" Make sure that tensor shaped do not change between"
f" {offloaded_tensor.shape} != {tensor.shape} "
"Make sure that tensor shapes do not change between"
" iterations if retain_pinned_cpu_buffers is True."
)
offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id]
offloaded_tensor.copy_(tensor, non_blocking=True)

# aux is a dictionary that contains auxiliary data like information which tensors were deduplicated,
Expand Down Expand Up @@ -318,6 +320,9 @@ def start_reload(self):
"""
Start reloading of tensors.
It allocates new tensors on GPU and puts copy from CPU tasks on offload stream.

Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor
and reconstructed in pop_tensor).
"""
self._validate_state(func_name="start_reload", allowed_states=["offload_finished"])
self.state = "reload_started"
Expand All @@ -330,7 +335,6 @@ def start_reload(self):
# cannot move tensors from pool of one stream to another without
# calling cudaFree and cudaMalloc again.

# empty_like is defined also for QuantizedTensors.
reloaded_tensor = torch.empty_like(tensor, device=torch.device("cuda"))
self.offload_stream.wait_stream(torch.cuda.current_stream())

Expand All @@ -347,15 +351,26 @@ def start_reload(self):
self.bwd_gpu_tensor_group
)

def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
"""
It is called when a tensor is saved for backward pass.

If tensor is offloaded, returns int representing the index of the tensor in the offloaded tensor group.
If tensor is not offloaded, returns the tensor itself.
For QuantizedTensor, returns (list of push results for each component, tensor_objs) tuple.
"""
self._validate_state(func_name="push_tensor", allowed_states=["not_offloaded"])

# For QuantizedTensor: decompose into component tensors, push each one recursively
if isinstance(tensor, QuantizedTensor):
# Make a copy because prepare_for_saving modifies the object (sets fields to None)
tensor_copy = tensor.detach()
# Inline prepare_for_saving logic - QuantizedTensor is a torch.Tensor subclass,
# so the generic prepare_for_saving would not call tensor.prepare_for_saving()
saved_tensors, tensor_obj = tensor_copy.prepare_for_saving()
push_results = [self.push_tensor(t) if t is not None else None for t in saved_tensors]
return (push_results, [tensor_obj])

if self._check_if_offload(tensor):
self.fwd_gpu_tensor_group.tensor_list.append(tensor)
# The group is processed and offloaded at the end of the forward pass of current layer.
Expand All @@ -370,23 +385,39 @@ def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
return len(self.fwd_gpu_tensor_group.tensor_list) - 1
return tensor

def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor:
def pop_tensor(
self, tensor_or_tensor_id: torch.Tensor | int | tuple[list, list]
) -> torch.Tensor:
"""
It is called when a tensor is used in backward pass.
Returns the tensor. If tensor was offloaded/reloaded, wait for the reload of a tensor to finish.
For QuantizedTensor (tuple input), reconstructs from component tensors.
"""
self._validate_state(
func_name="pop_tensor", allowed_states=["not_offloaded", "reload_started"]
)

# 1. tensor not offloaded
# 1. tensor not offloaded (regular tensor returned as-is from push)
if isinstance(tensor_or_tensor_id, torch.Tensor):
return tensor_or_tensor_id
# 2. the layer was not offloaded at all

# 2. QuantizedTensor case: tuple of (push_results, tensor_objs)
if isinstance(tensor_or_tensor_id, tuple):
push_results, tensor_objs = tensor_or_tensor_id
# Recursively pop each component
reloaded_tensors = [
self.pop_tensor(pr) if pr is not None else None for pr in push_results
]
# Inline restore_from_saved - tensor_objs[0] is the QuantizedTensor copy
tensor_obj = tensor_objs[0]
tensor_obj.restore_from_saved(reloaded_tensors)
return tensor_obj

# 3. Regular tensor index case
if self.state == "not_offloaded":
return self.fwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id]

# 3. the layer was offloaded
# 4. the layer was offloaded
assert self.state == "reload_started"
# wait for the tensor to be reloaded
torch.cuda.current_stream().wait_event(
Expand Down Expand Up @@ -419,6 +450,10 @@ def _check_if_offload(self, t: torch.Tensor) -> bool:
)
return False

# Only offload tensors with at least 256k elements (~1MB for float32)
if t.numel() < 256 * 1024:
return False
Comment on lines +453 to +455
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand, this is the reason we need to expose an option to disable bulk allocation in split_quantize? Bulk-allocated tensors hold on to memory untill all are deallocated, but this condition means that some small tensor might keep a large memory block alive.


return True
return False

Expand Down
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &ten

std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list);
std::vector<py::handle> quantizer_list,
bool disable_bulk_allocation = false);

/***************************************************************************************************
* Bias gradient fusions
Expand Down
37 changes: 20 additions & 17 deletions transformer_engine/pytorch/csrc/extensions/cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,8 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,

std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list) {
std::vector<py::handle> quantizer_list,
bool disable_bulk_allocation) {
init_extension();

// Check number of tensors
Expand Down Expand Up @@ -996,22 +997,24 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
enum class QuantizationMethod { UNFUSED, FUSED_NVFP4 };
AllocationMethod allocation_method = AllocationMethod::UNFUSED;
QuantizationMethod quantization_method = QuantizationMethod::UNFUSED;
if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
[](const py::handle &quantizer) -> bool {
return detail::IsFloat8BlockwiseQuantizers(quantizer.ptr());
})) {
allocation_method = AllocationMethod::BULK_FP8_BLOCKWISE;
} else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
[](const py::handle &quantizer) -> bool {
return detail::IsMXFP8Quantizers(quantizer.ptr());
})) {
allocation_method = AllocationMethod::BULK_MXFP8;
} else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
[](const py::handle &quantizer) -> bool {
return detail::IsNVFP4Quantizers(quantizer.ptr());
})) {
allocation_method = AllocationMethod::BULK_NVFP4;
quantization_method = QuantizationMethod::FUSED_NVFP4;
if (!disable_bulk_allocation) {
if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
[](const py::handle &quantizer) -> bool {
return detail::IsFloat8BlockwiseQuantizers(quantizer.ptr());
})) {
allocation_method = AllocationMethod::BULK_FP8_BLOCKWISE;
} else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
[](const py::handle &quantizer) -> bool {
return detail::IsMXFP8Quantizers(quantizer.ptr());
})) {
allocation_method = AllocationMethod::BULK_MXFP8;
} else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
[](const py::handle &quantizer) -> bool {
return detail::IsNVFP4Quantizers(quantizer.ptr());
})) {
allocation_method = AllocationMethod::BULK_NVFP4;
quantization_method = QuantizationMethod::FUSED_NVFP4;
}
}

// Allocate output tensors
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/csrc/extensions/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"));
m.def("split_quantize", &transformer_engine::pytorch::split_quantize,
"Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"),
py::arg("quantizer_list"));
py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false);
m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
"Grouped GEMM");
m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
Expand Down
4 changes: 3 additions & 1 deletion transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def forward(
inp_view = inp.reshape(-1, in_features)
inputmats: list
if fp8 and not debug:
inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
inputmats = tex.split_quantize(
inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading
)
elif debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
inp_view, input_quantizers, m_splits, activation_dtype
Expand Down
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,8 @@ def forward(
# weights if weights are externally touched outside this module
ctx.weight_object = weight

mark_not_offload(weight, weightmat, bias)
if cpu_offloading:
mark_not_offload(weight, weightmat, bias)
# TODO(ksivamani): Check memory usage
tensors_to_save, tensor_objects = prepare_for_saving(
saved_inputmat,
Expand Down