Skip to content

Commit 6cf8d29

Browse files
authored
[Bugfix] Disable lm_head output device movement for multigpu dispatch (#2108)
## Purpose ## * As of #2081, the lm_head now produces outputs on the meta device. However, in the case of multi-gpu dispatch, accelerate will try to move lm_head outputs to the model input device. This behavior needs to be disabled <details><summary>Traceback</summary> ``` @pytest.mark.integration def test_infer_owl_layer_sparsity(): target_sparsity = 0.7 vocab_size = 512 seq_len = 2048 ds_size = 16 with create_session() as session: session.initialize() modifier = SparseGPTModifier( sparsity=0.7, sparsity_profile="owl", owl_m=5, owl_lmbda=0.05 ) model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2") dataset = Dataset.from_dict( {"input_ids": torch.randint(0, vocab_size, (ds_size, seq_len))} ) dataloader = format_calibration_data(dataset) sequential_targets = modifier._infer_sequential_targets(model) layers = get_layers(sequential_targets, model) > sparsities = modifier._infer_owl_layer_sparsity(model, layers, dataloader) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ tests/llmcompressor/transformers/sparsegpt/test_sparsegpt_owl.py:33: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py:210: in _infer_owl_layer_sparsity activations = self._get_activations(model, dataloader) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py:269: in _get_activations run_calibration(model, dataloader) src/llmcompressor/pipelines/basic/pipeline.py:58: in run_calibration pipeline(model, dataloader, None) src/llmcompressor/pipelines/basic/pipeline.py:51: in __call__ model(**batch) ../venv/llmcomp-latest/lib/python3.12/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ../venv/llmcomp-latest/lib/python3.12/site-packages/torch/nn/modules/module.py:1786: in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ../venv/llmcomp-latest/lib/python3.12/site-packages/accelerate/hooks.py:176: in new_forward return module._hf_hook.post_forward(module, output) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ../venv/llmcomp-latest/lib/python3.12/site-packages/accelerate/hooks.py:402: in post_forward output = send_to_device(output, self.input_device, skip_keys=self.skip_keys) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ../venv/llmcomp-latest/lib/python3.12/site-packages/accelerate/utils/operations.py:180: in send_to_device k: t if k in skip_keys else send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ tensor = tensor(..., device='meta', size=(1, 2048, 1035)), device = device(type='cuda', index=0), non_blocking = False, skip_keys = [] def send_to_device(tensor, device, non_blocking=False, skip_keys=None): """ Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device. Args: tensor (nested list/tuple/dictionary of `torch.Tensor`): The data to send to a given device. device (`torch.device`): The device to send the data to. Returns: The same data structure as `tensor` with all tensors sent to the proper device. """ if is_torch_tensor(tensor) or hasattr(tensor, "to"): # `torch.Tensor.to("npu")` could not find context when called for the first time (see this [issue](https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue)). if device == "npu": device = "npu:0" try: > return tensor.to(device, non_blocking=non_blocking) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E NotImplementedError: Cannot copy out of meta tensor; no data! ../venv/llmcomp-latest/lib/python3.12/site-packages/accelerate/utils/operations.py:154: NotImplementedError ``` </details> ## Changes ## * Disable model output device movement when lm_head is disabled ## Testing ## * Confirmed that failing test `test_sparsegpt_owl.py` now passes when dispatched with multiple gpus Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent b46655b commit 6cf8d29

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

src/llmcompressor/utils/helpers.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1119,7 +1119,13 @@ def disable_lm_head(model: torch.nn.Module):
11191119
def dummy_forward(self, input: torch.Tensor) -> torch.Tensor:
11201120
return input.to("meta") @ dummy_weight.T
11211121

1122-
with patch_attr(lm_head, "forward", dummy_forward.__get__(lm_head)):
1122+
with contextlib.ExitStack() as stack:
1123+
lm_head_forward = dummy_forward.__get__(lm_head)
1124+
stack.enter_context(patch_attr(lm_head, "forward", lm_head_forward))
1125+
1126+
if hasattr(model, "_hf_hook"):
1127+
stack.enter_context(patch_attr(model._hf_hook, "io_same_device", False))
1128+
11231129
yield
11241130

11251131

0 commit comments

Comments
 (0)