Commit 58e50e0
[Misc] Rename offload_sequential_activations to sequential_offload_device (#2134)
## Purpose ##

* Enable users to offload activations to another GPU
* Because GPU-to-GPU transfer is much faster than GPU-to-CPU transfer, there should theoretically be runtime improvements from this option

## Changes ##

* Rename `offload_sequential_activations` -> `sequential_offload_device`

## TODO ##

* Demonstrate in a test that using `cuda:1` leads to runtime improvements

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

1 parent ef45976

File tree

2 files changed: +6 additions, -7 deletions

src/llmcompressor/args/dataset_arguments.py

Lines changed: 5 additions & 5 deletions

```diff
@@ -230,12 +230,12 @@ class DatasetArguments(CustomDatasetArguments):
             "definition"
         },
     )
-    offload_sequential_activations: bool = field(
-        default=True,
+    sequential_offload_device: str = field(
+        default="cpu",
         metadata={
-            "help": "Whether to offload intermediate activations between sequential "
-            "layers to the CPU. Disabling offloading is much faster, but uses "
-            "signficiantly more memory. Default is True."
+            "help": "Device used to offload intermediate activations between "
+            "sequential layers. It is recommended to use `cuda:1` if using more "
+            "than one gpu. Default is cpu."
         },
     )
     quantization_aware_calibration: bool = field(
```
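The shape of the renamed field can be sketched in isolation. The following is a minimal, self-contained approximation using only the standard library — `SequentialArgsSketch` is a hypothetical stand-in, not the real `DatasetArguments` class — showing how a string-valued device option replaces the old boolean flag while keeping CPU offload as the default:

```python
from dataclasses import dataclass, field


# Hypothetical sketch (not the real llmcompressor class): the old boolean
# flag `offload_sequential_activations` becomes a string device spec.
@dataclass
class SequentialArgsSketch:
    # Any device string ("cpu", "cuda:1", ...) is accepted; the default
    # preserves the previous behavior of offloading to the CPU.
    sequential_offload_device: str = field(
        default="cpu",
        metadata={
            "help": "Device used to offload intermediate activations between "
            "sequential layers."
        },
    )


default_args = SequentialArgsSketch()
gpu_args = SequentialArgsSketch(sequential_offload_device="cuda:1")
print(default_args.sequential_offload_device)  # cpu
print(gpu_args.sequential_offload_device)      # cuda:1
```

A string argument also leaves room for future device targets without another rename, which a boolean could not express.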

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -89,8 +89,7 @@ def __call__(
             stack.enter_context(DisableQuantization(model))

             # prepare intermediates cache
-            cache_offload = dataset_args.offload_sequential_activations
-            offload_device = torch.device("cpu") if cache_offload else None
+            offload_device = torch.device(dataset_args.sequential_offload_device)
             activations = IntermediatesCache.from_dataloader(
                 dataloader, model_device, offload_device=offload_device
             )
```
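The pipeline-side change collapses the old bool-to-device mapping into a direct conversion of the user-supplied string. A hedged sketch of the before/after logic, with plain strings standing in for `torch.device` objects so it runs without torch:

```python
from typing import Optional


# Old behavior (sketch): a boolean flag chose between CPU offload and
# no offload at all (None).
def old_offload_device(offload_sequential_activations: bool) -> Optional[str]:
    return "cpu" if offload_sequential_activations else None


# New behavior (sketch): the user-supplied device string is used directly,
# so "cuda:1" enables fast GPU-to-GPU offload instead of GPU-to-CPU.
# In the real pipeline this is torch.device(sequential_offload_device).
def new_offload_device(sequential_offload_device: str) -> str:
    return sequential_offload_device


print(old_offload_device(True))      # cpu
print(new_offload_device("cuda:1"))  # cuda:1
```

One observable difference in the diff: the old flag could disable offloading entirely (`False` -> `None`), whereas the new string argument as shown always constructs a device.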
