From 1e0d49c3bb7bf4d08177df3404120cfd7c2d8638 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 17 Dec 2025 22:36:59 +0000 Subject: [PATCH 1/5] dispatch after tracing Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/sequential/pipeline.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 91516f280..1bb2e8a0e 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -59,10 +59,6 @@ def __call__( """ session = active_session() - # prepare model for sequential onloading - dispatch_for_sequential(model) - model_device = get_execution_device(model) - # prepare to trace subgraphs modifiers = session.lifecycle.recipe.modifiers sequential_targets = get_sequential_targets(modifiers, model, dataset_args) @@ -73,6 +69,10 @@ def __call__( subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore) num_subgraphs = len(subgraphs) + # prepare model for sequential onloading + dispatch_for_sequential(model) + model_device = get_execution_device(model) + LifecycleCallbacks.calibration_epoch_start() # TODO: remove this to enable quantization aware calibration From 7f0234aabeb8e908a87aeeda23f8284fb004eed4 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 18 Dec 2025 05:39:30 +0000 Subject: [PATCH 2/5] necessary changes Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/awq/base.py | 4 +- .../modifiers/pruning/sparsegpt/sgpt_base.py | 23 ++++---- .../modifiers/quantization/gptq/base.py | 10 ++-- .../modifiers/smoothquant/base.py | 2 +- .../pipelines/sequential/helpers.py | 57 +++---------------- .../pipelines/sequential/pipeline.py | 8 +-- src/llmcompressor/utils/transformers.py | 12 +--- tests/llmcompressor/utils/test_helpers.py | 16 +++--- 8 files changed, 40 insertions(+), 92 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 0e5ca71b3..0b73970b4 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -197,13 +197,13 @@ def on_initialize(self, state: State, **kwargs) -> bool: architecture=state.model.__class__.__name__ ) - self._set_resolved_mappings(state.model) - return True def on_start(self, state: State, event: Event, **kwargs): self.started_ = True + self._set_resolved_mappings(state.model) + # register quantization calibration hooks # assume quantization has been initialized by this modifier or one before it QuantizationMixin.start_calibration(self, state.model) diff --git a/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py b/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py index a2f60a33f..7b171edc4 100644 --- a/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py +++ b/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py @@ -109,15 +109,19 @@ def on_initialize(self, state: "State", **kwargs) -> bool: :param state: session state storing input model and calibration data """ + # infer module and sequential targets + self.sequential_targets = self._infer_sequential_targets(state.model) + + return True + + def on_start(self, state: State, event: Event, **kwargs): + self.started_ = True + + # find target layers model: torch.nn.Module = state.model dataloader: torch.utils.data.DataLoader = state.data.calib - - # infer module and sequential targets - self.sequential_targets = self._infer_sequential_targets(model) layers = 
get_layers(self.sequential_targets, model) - self._target_layers = get_layers( - self.targets, model - ) # layers containing targets + self._target_layers = get_layers(self.targets, model) # infer layer sparsities if self.sparsity_profile == "owl": @@ -127,7 +131,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: ) self.sparsity = self._infer_owl_layer_sparsity(model, layers, dataloader) - # get layers and validate sparsity + # validate sparsity if isinstance(self.sparsity, (list, dict)) and len(self._target_layers) != len( self.sparsity ): @@ -136,11 +140,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool: f"sparsities values, but model has {len(layers)} target layers" ) - return True - - def on_start(self, state: State, event: Event, **kwargs): - self.started_ = True - # register hooks for index, (layer_name, layer) in enumerate(self._target_layers.items()): match self.sparsity: diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index ab23e4fad..eb1270553 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -163,6 +163,11 @@ def on_initialize(self, state: State, **kwargs) -> bool: if QuantizationMixin.has_config(self): QuantizationMixin.initialize_quantization(self, state.model) + return True + + def on_start(self, state: State, event: Event, **kwargs): + self.started_ = True + # prepare module names self._module_names = { m: name @@ -171,11 +176,6 @@ def on_initialize(self, state: State, **kwargs) -> bool: ) } - return True - - def on_start(self, state: State, event: Event, **kwargs): - self.started_ = True - # register quantization calibration hooks # assume quantization has been initialized by this modifier or one before it QuantizationMixin.start_calibration(self, state.model) diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py index 908d03254..886e76ebf 100644 --- a/src/llmcompressor/modifiers/smoothquant/base.py +++ b/src/llmcompressor/modifiers/smoothquant/base.py @@ -131,13 +131,13 @@ def on_initialize(self, state: State, **kwargs) -> bool: ) self.ignore = [] if not self.ignore else self.ignore self.mappings = self._infer_mappings_from_model(state.model) - self.resolved_mappings_ = self._resolve_mappings(state.model) self.scales_ = {} return True def on_start(self, state: State, event: Event, **kwargs): self.started_ = True + self.resolved_mappings_ = self._resolve_mappings(state.model) self._setup_scale_hooks() def on_event(self, state: State, event: Event, **kwargs): diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index 4a776a7d5..cd7d3fd52 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -6,13 +6,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple import torch -from accelerate.hooks import remove_hook_from_module -from compressed_tensors.utils import ( - has_offloaded_params, - offloaded_dispatch, - patch_attr, - remove_dispatch, -) +from compressed_tensors.utils import patch_attr from compressed_tensors.utils.match import match_targets from loguru import logger from torch.fx import Graph, GraphModule, Node @@ -37,7 +31,6 @@ "trace_subgraphs", "Subgraph", "get_sequential_targets", - "dispatch_for_sequential", ] @@ -104,10 +97,9 @@ def trace_subgraphs( # find modules targets = 
match_modules(model, sequential_targets) ancestors = get_sequential_ancestors(model, targets) - offloaded = set(m for m in model.modules() if has_offloaded_params(m)) # initialize arguments - tracer = SequentialTracer(ancestors, offloaded) + tracer = SequentialTracer(ancestors) concrete_args = populate_concrete_args(model, sample_input) with contextlib.ExitStack() as stack: @@ -168,32 +160,18 @@ class SequentialTracer(HFTracer): """ Get a tracer specialized for the given model. The resulting tracer will not trace inside of sequential targets, nor any modules which are not call graph ancestors of - sequential targets - - Tracing within sequential targets is unnecessary, and tracing within offloaded - modules may result in meta tensors being added to the model graph + sequential targets. Tracing outside of call ancestors of sequential targets will be + skipped :param ancestors: modules which are ancestors of sequential targets - :param offloaded: modules which have offloaded params and should not be traced """ - def __init__(self, ancestors: Set[Module], offloaded: Set[Module]): + def __init__(self, ancestors: Set[Module]): self.ancestors = ancestors - self.offloaded = offloaded # skip any mask creation functions not already caught by the autowrapper super().__init__(autowrap_functions=_get_autowrap_functions()) - # check unlikely case that ancestors have direct params which are offloaded - offloaded_ancestors = offloaded & ancestors - for ancestor in offloaded_ancestors: - remove_hook_from_module(ancestor, recurse=False) - self.offloaded.remove(ancestor) - logger.warning( - f"Direct parameters attached to {ancestor.__class__.__name__} have " - "been onloaded in order to ensure safe graph capture and execution" - ) - def create_arg(self, a: Any) -> Argument: # special extension allows models which depend on config values to be traced if isinstance(a, PretrainedConfig): @@ -204,8 +182,8 @@ def create_arg(self, a: Any) -> Argument: return super().create_arg(a) def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool: - # do not trace non-ancestors or modules with offloaded params - return module not in self.ancestors or module in self.offloaded + # do not trace non-ancestors + return module not in self.ancestors def populate_concrete_args(model: Module, sample_input: Dict) -> Dict: @@ -526,27 +504,6 @@ def is_ancestor(module: Module) -> bool: return ancestors -def dispatch_for_sequential(model: PreTrainedModel) -> PreTrainedModel: - """ - Dispatch a model for sequential calibration using a sequential pipeline. - The model will be offloaded to the CPU and dispatched to CUDA/XPU device - if available. Removes any existing hooks. - - :param model: model to dispatch - :return: dispatched model - """ - remove_dispatch(model) - - if torch.cuda.is_available(): - offloaded_dispatch(model, execution_device=torch.device("cuda:0")) - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - offloaded_dispatch(model, execution_device=torch.device("xpu:0")) - else: - logger.warning("CUDA/XPU is not available! 
Compressing model on CPU instead") - - return model - - def _get_autowrap_functions() -> Tuple[Callable[[Any], Any], ...]: try: from transformers.masking_utils import LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 1bb2e8a0e..d0e048847 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING import torch -from compressed_tensors.utils import disable_offloading, get_execution_device +from compressed_tensors.offload import dispatch_model +from compressed_tensors.utils import disable_offloading from torch.utils.data.dataloader import DataLoader from tqdm import tqdm @@ -11,7 +12,6 @@ from llmcompressor.pipelines.cache import IntermediatesCache from llmcompressor.pipelines.registry import CalibrationPipeline from llmcompressor.pipelines.sequential.helpers import ( - dispatch_for_sequential, get_sequential_targets, trace_subgraphs, ) @@ -70,8 +70,8 @@ def __call__( num_subgraphs = len(subgraphs) # prepare model for sequential onloading - dispatch_for_sequential(model) - model_device = get_execution_device(model) + model_device = "cuda" if torch.cuda.is_available() else "cpu" + dispatch_model(model, model_device) LifecycleCallbacks.calibration_epoch_start() diff --git a/src/llmcompressor/utils/transformers.py b/src/llmcompressor/utils/transformers.py index 9b4831a29..b8e61f725 100644 --- a/src/llmcompressor/utils/transformers.py +++ b/src/llmcompressor/utils/transformers.py @@ -1,5 +1,4 @@ import torch -from compressed_tensors import has_offloaded_params, register_offload_parameter from loguru import logger from torch.nn import Parameter from transformers import PreTrainedModel @@ -28,14 +27,9 @@ def untie_word_embeddings(model: PreTrainedModel): # clone data to untie for module in (input_embed, output_embed): - if not has_offloaded_params(module): - data = module.weight.data - else: - data = module._hf_hook.weights_map["weight"] - - requires_grad = module.weight.requires_grad - untied_param = Parameter(data.clone(), requires_grad=requires_grad) - register_offload_parameter(module, "weight", untied_param) + weight = module.weight + param = Parameter(weight.data.clone(), requires_grad=weight.requires_grad) + module.register_parameter("weight", param) # modify model config if hasattr(model.config, "tie_word_embeddings"): diff --git a/tests/llmcompressor/utils/test_helpers.py b/tests/llmcompressor/utils/test_helpers.py index b6535bf72..8a58ba0f5 100644 --- a/tests/llmcompressor/utils/test_helpers.py +++ b/tests/llmcompressor/utils/test_helpers.py @@ -1,11 +1,11 @@ import pytest import torch +from compressed_tensors.offload import dispatch_model from transformers import ( AutoModelForCausalLM, MllamaForConditionalGeneration, ) -from llmcompressor.pipelines.sequential.helpers import dispatch_for_sequential from llmcompressor.utils import ( ALL_TOKEN, DisableQuantization, @@ -17,7 +17,7 @@ interpolate, validate_str_iterable, ) -from llmcompressor.utils.dev import dispatch_for_generation, skip_weights_download +from llmcompressor.utils.dev import skip_weights_download from tests.testing_utils import requires_gpu @@ -153,14 +153,12 @@ def test_disable_cache(model_cls, model_stub): @requires_gpu -@pytest.mark.parametrize("offload", ["sequential", "basic", "none"]) -def test_disable_lm_head(offload): +@pytest.mark.parametrize("dispatch", (True, False)) +def 
test_disable_lm_head(dispatch): model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2") - if offload == "sequential": - dispatch_for_sequential(model) - if offload == "basic": - dispatch_for_generation(model) - if offload == "none": + if dispatch: + dispatch_model(model, "cuda") + else: model = model.to("cuda") lm_input_device = None From 01c51e943a0a08d173576018a9219236734d38f7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 18 Dec 2025 05:55:07 +0000 Subject: [PATCH 3/5] dispatching Signed-off-by: Kyle Sayers --- .../pipelines/sequential/helpers.py | 25 +++++++++++++++++++ .../pipelines/sequential/pipeline.py | 7 +++--- src/llmcompressor/utils/dev.py | 21 +++++++++++++++- 3 files changed, 49 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index cd7d3fd52..75cee8a8b 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple import torch +from compressed_tensors.offload import dispatch_model from compressed_tensors.utils import patch_attr from compressed_tensors.utils.match import match_targets from loguru import logger @@ -19,6 +20,7 @@ from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.sequential.transformers_helpers import HFTracer +from llmcompressor.utils.dev import remove_dispatch from llmcompressor.utils.helpers import calibration_forward_context from llmcompressor.utils.pytorch.module import get_no_split_params @@ -504,6 +506,29 @@ def is_ancestor(module: Module) -> bool: return ancestors +def dispatch_for_sequential(model: PreTrainedModel) -> PreTrainedModel: + """ + Dispatch a model for sequential calibration using a sequential pipeline. + The model will be offloaded to the CPU and dispatched to CUDA/XPU device + if available. Removes any existing hooks. + + :param model: model to dispatch + :return: dispatched model + """ + if torch.cuda.is_available(): + model_device = "cuda:0" + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + model_device = "xpu:0" + else: + logger.warning("CUDA/XPU is not available! 
Compressing model on CPU instead") + model_device = "cpu" + + remove_dispatch(model) # remove accelerate dispatches + model = dispatch_model(model, model_device) + + return model + + def _get_autowrap_functions() -> Tuple[Callable[[Any], Any], ...]: try: from transformers.masking_utils import LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index d0e048847..6bc379a4e 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING import torch -from compressed_tensors.offload import dispatch_model +from compressed_tensors.offload import get_execution_device from compressed_tensors.utils import disable_offloading from torch.utils.data.dataloader import DataLoader from tqdm import tqdm @@ -19,6 +19,7 @@ DISABLE_QAC_MODIFIERS, DisableQuantization, calibration_forward_context, + dispatch_for_sequential, ) if TYPE_CHECKING: @@ -70,8 +71,8 @@ def __call__( num_subgraphs = len(subgraphs) # prepare model for sequential onloading - model_device = "cuda" if torch.cuda.is_available() else "cpu" - dispatch_model(model, model_device) + dispatch_for_sequential(model) + model_device = get_execution_device(model) LifecycleCallbacks.calibration_epoch_start() diff --git a/src/llmcompressor/utils/dev.py b/src/llmcompressor/utils/dev.py index a227ffa06..4c2641d05 100644 --- a/src/llmcompressor/utils/dev.py +++ b/src/llmcompressor/utils/dev.py @@ -6,8 +6,10 @@ import torch from accelerate import dispatch_model, infer_auto_device_map +from accelerate.hooks import remove_hook_from_module from accelerate.utils import get_balanced_memory -from compressed_tensors.utils import patch_attr, remove_dispatch +from compressed_tensors.offload import remove_dispatch as remove_torch_offload_dispatch +from compressed_tensors.utils import patch_attr from huggingface_hub import snapshot_download from safetensors.torch import save_file from transformers import AutoModelForCausalLM, PreTrainedModel @@ -18,6 +20,7 @@ "skip_weights_download", "patch_transformers_logger_level", "dispatch_for_generation", + "remove_dispatch", ] @@ -123,6 +126,7 @@ def dispatch_for_generation(model: PreTrainedModel) -> PreTrainedModel: :return: model which is dispatched """ remove_dispatch(model) + remove_torch_offload_dispatch(model) no_split_module_classes = model._get_no_split_modules("auto") max_memory = get_balanced_memory( @@ -138,3 +142,18 @@ def dispatch_for_generation(model: PreTrainedModel) -> PreTrainedModel: ) return dispatch_model(model, device_map=device_map) + + +def remove_dispatch(module: torch.nn.Module) -> torch.nn.Module: + """ + Remove any existing accelerate dispatches from module + + :param module: module which may be dispatched with hf hooks + :return: module without dispatch + """ + remove_hook_from_module(module, recurse=True) + if hasattr(module, "hf_device_map"): + delattr(module, "hf_device_map") + module.to("cpu") + + return module From 677030054470c57f65e9d67a03f15b14a3c7a50b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 19 Dec 2025 02:12:31 +0000 Subject: [PATCH 4/5] WIP: awq Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/awq/base.py | 147 +++++++++--------- .../pipelines/sequential/helpers.py | 25 --- .../pipelines/sequential/pipeline.py | 8 +- src/llmcompressor/utils/dev.py | 68 ++++---- 4 files changed, 104 insertions(+), 144 deletions(-) diff --git 
a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 0b73970b4..349887d2a 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -10,7 +10,6 @@ ) from compressed_tensors.quantization.utils import strategy_cdiv from compressed_tensors.utils import ( - align_modules, get_execution_device, get_lowest_common_ancestor_name, getattr_chain, @@ -409,91 +408,90 @@ def _apply_smoothing(self, model: Module) -> None: balance_layers = mapping.balance_layers parent_module = mapping.parent - with ( - align_modules([parent_module, smooth_layer, *balance_layers]), - calibration_forward_context(model), - HooksMixin.disable_hooks(), + # Compute output of unquantized module + fp16_outputs = self._run_samples(model, parent_module) + if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs): + logger.info( + f"Skipping smooth_layer {mapping.smooth_name}, no activations " + "found to scale. This can occasionally occur in MoE models " + "when certain experts are not activated by calibration samples." + ) + del self._smooth_activation_means[mapping.smooth_name] + continue + if not all( + [fp16_output.isfinite().all() for fp16_output in fp16_outputs] ): - # Compute output of unquantized module - fp16_outputs = self._run_samples(parent_module) - if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs): - logger.info( - f"Skipping smooth_layer {mapping.smooth_name}, no activations " - "found to scale. This can occasionally occur in MoE models " - "when certain experts are not activated by calibration samples." - ) - del self._smooth_activation_means[mapping.smooth_name] - continue - if not all( - [fp16_output.isfinite().all() for fp16_output in fp16_outputs] - ): - logger.warning( - f"Skipping smooth_layer {mapping.smooth_name}, NaN or inf " - "outputs found during forward pass of the parent module " - f"{mapping.parent_name}. The model is either generating NaN " - "output with provided calibration data set, or the mappings " - "are incorrectly set and modifying the model in undesired " - "ways. If you encounter this consistently, raise an issue at " - "https://github.com/vllm-project/llm-compressor/issues" - ) - del self._smooth_activation_means[mapping.smooth_name] - continue + logger.warning( + f"Skipping smooth_layer {mapping.smooth_name}, NaN or inf " + "outputs found during forward pass of the parent module " + f"{mapping.parent_name}. The model is either generating NaN " + "output with provided calibration data set, or the mappings " + "are incorrectly set and modifying the model in undesired " + "ways. 
If you encounter this consistently, raise an issue at " + "https://github.com/vllm-project/llm-compressor/issues" + ) + del self._smooth_activation_means[mapping.smooth_name] + continue - best_scales = self._compute_best_scale(mapping, fp16_outputs) + best_scales = self._compute_best_scale(model, mapping, fp16_outputs) - @torch.no_grad() - def _smooth(module: Module): - scales = best_scales.to(module.weight.device) - if module in balance_layers: + @torch.no_grad() + def _smooth(module: Module): + scales = best_scales.to(module.weight.device) + print(scales) + if module in balance_layers: + update_offload_parameter( + module, + "weight", + module.weight.mul_(scales.view(1, -1)), + ) + elif module == smooth_layer: + if module.weight.ndim == 1: + breakpoint() update_offload_parameter( module, "weight", - module.weight.mul_(scales.view(1, -1)), + module.weight.div_(scales), + ) + else: + # NOTE: edge case when smooth layer number of out_features + # is not equal to balance layer number of in_features + # e.g. when fused qkv_proj is used to smooth o_proj + # in this case, default to scaling the last output features + # because the desired smooth layer is v_proj + # https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/scale.py#L123 + weight = module.weight + weight[-scales.size(0) :].div_(scales.view(-1, 1)) + update_offload_parameter(module, "weight", weight) + if hasattr(module, "bias") and module.bias is not None: + update_offload_parameter( + module, + "bias", + module.bias.div_(scales), ) - elif module == smooth_layer: - if module.weight.ndim == 1: - update_offload_parameter( - module, - "weight", - module.weight.div_(scales), - ) - else: - # NOTE: edge case when smooth layer number of out_features - # is not equal to balance layer number of in_features - # e.g. 
when fused qkv_proj is used to smooth o_proj - # in this case, default to scaling the last output features - # because the desired smooth layer is v_proj - # https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/scale.py#L123 - weight = module.weight - weight[-scales.size(0) :].div_(scales.view(-1, 1)) - update_offload_parameter(module, "weight", weight) - if hasattr(module, "bias") and module.bias is not None: - update_offload_parameter( - module, - "bias", - module.bias.div_(scales), - ) - - parent = get_fsdp_parent(mapping.smooth_name, model) - if parent is not None: - parent.apply(_smooth) - else: - # if we're not running with FSDP we can apply smoothing directly - for layer in balance_layers: - _smooth(layer) - _smooth(smooth_layer) - # remove caches needed to smooth this mapping - del self._smooth_activation_means[mapping.smooth_name] + parent = get_fsdp_parent(mapping.smooth_name, model) + if parent is not None: + parent.apply(_smooth) + else: + # if we're not running with FSDP we can apply smoothing directly + for layer in balance_layers: + _smooth(layer) + _smooth(smooth_layer) + + # remove caches needed to smooth this mapping + del self._smooth_activation_means[mapping.smooth_name] for v in self._parent_args_cache.values(): v.batch_intermediates.clear() self._assert_all_activations_consumed() - def _run_samples(self, module: Module) -> list[torch.Tensor]: - outputs = [ - module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module] - ] + def _run_samples(self, model: Module, module: Module) -> list[torch.Tensor]: + with (HooksMixin.disable_hooks(), calibration_forward_context(model)): + outputs = [ + module(**batch_kwargs) + for batch_kwargs in self._parent_args_cache[module] + ] return [ # If tuple, assume that first argument is the input output[0] if isinstance(output, tuple) else output @@ -502,6 +500,7 @@ def _run_samples(self, module: Module) -> list[torch.Tensor]: def _compute_best_scale( self, + model: Module, mapping: ResolvedMapping, fp16_outputs: list[torch.Tensor], ) -> torch.Tensor: @@ -615,7 +614,7 @@ def _compute_best_scale( ) # W * X - int_w_outputs = self._run_samples(mapping.parent) + int_w_outputs = self._run_samples(model, mapping.parent) # compute mean squared error (L2 norm) loss = self._compute_loss(fp16_outputs, int_w_outputs) diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index 75cee8a8b..cd7d3fd52 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple import torch -from compressed_tensors.offload import dispatch_model from compressed_tensors.utils import patch_attr from compressed_tensors.utils.match import match_targets from loguru import logger @@ -20,7 +19,6 @@ from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.sequential.transformers_helpers import HFTracer -from llmcompressor.utils.dev import remove_dispatch from llmcompressor.utils.helpers import calibration_forward_context from llmcompressor.utils.pytorch.module import get_no_split_params @@ -506,29 +504,6 @@ def is_ancestor(module: Module) -> bool: return ancestors -def dispatch_for_sequential(model: PreTrainedModel) -> PreTrainedModel: - """ - Dispatch a model for sequential calibration using a sequential pipeline. 
- The model will be offloaded to the CPU and dispatched to CUDA/XPU device - if available. Removes any existing hooks. - - :param model: model to dispatch - :return: dispatched model - """ - if torch.cuda.is_available(): - model_device = "cuda:0" - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - model_device = "xpu:0" - else: - logger.warning("CUDA/XPU is not available! Compressing model on CPU instead") - model_device = "cpu" - - remove_dispatch(model) # remove accelerate dispatches - model = dispatch_model(model, model_device) - - return model - - def _get_autowrap_functions() -> Tuple[Callable[[Any], Any], ...]: try: from transformers.masking_utils import LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 6bc379a4e..f8ccf1d0a 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING import torch -from compressed_tensors.offload import get_execution_device +from compressed_tensors.offload import offload_model from compressed_tensors.utils import disable_offloading from torch.utils.data.dataloader import DataLoader from tqdm import tqdm @@ -15,11 +15,11 @@ get_sequential_targets, trace_subgraphs, ) +from llmcompressor.utils.dev import get_main_device from llmcompressor.utils.helpers import ( DISABLE_QAC_MODIFIERS, DisableQuantization, calibration_forward_context, - dispatch_for_sequential, ) if TYPE_CHECKING: @@ -71,8 +71,8 @@ def __call__( num_subgraphs = len(subgraphs) # prepare model for sequential onloading - dispatch_for_sequential(model) - model_device = get_execution_device(model) + model_device = get_main_device() + offload_model(model, onload_device=model_device, offload_device="cpu") LifecycleCallbacks.calibration_epoch_start() diff --git a/src/llmcompressor/utils/dev.py b/src/llmcompressor/utils/dev.py index 4c2641d05..ad3effa9f 100644 --- a/src/llmcompressor/utils/dev.py +++ b/src/llmcompressor/utils/dev.py @@ -4,13 +4,13 @@ import tempfile from typing import Type +from functools import wraps + import torch -from accelerate import dispatch_model, infer_auto_device_map -from accelerate.hooks import remove_hook_from_module -from accelerate.utils import get_balanced_memory -from compressed_tensors.offload import remove_dispatch as remove_torch_offload_dispatch +from compressed_tensors.offload import dispatch_model from compressed_tensors.utils import patch_attr from huggingface_hub import snapshot_download +from loguru import logger from safetensors.torch import save_file from transformers import AutoModelForCausalLM, PreTrainedModel from transformers.modeling_utils import TORCH_INIT_FUNCTIONS @@ -19,8 +19,8 @@ __all__ = [ "skip_weights_download", "patch_transformers_logger_level", + "get_main_device", "dispatch_for_generation", - "remove_dispatch", ] @@ -116,44 +116,30 @@ def patch_transformers_logger_level(level: int = logging.ERROR): transformers_logger.setLevel(level=restore_log_level) -def dispatch_for_generation(model: PreTrainedModel) -> PreTrainedModel: +def get_main_device() -> torch.device: + if torch.cuda.is_available(): + return torch.device("cuda:0") + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + return torch.device("xpu:0") + else: + logger.warning("CUDA/XPU is not available! 
Compressing model on CPU instead") + return torch.device("cpu") + + +@wraps(dispatch_model) +def dispatch_for_generation(*args, **kwargs) -> PreTrainedModel: """ Dispatch a model autoregressive generation. This means that modules are dispatched - evenly across avaiable devices and kept onloaded if possible. Removes any HF hooks - that may have existed previously. + evenly across avaiable devices and kept onloaded if possible. :param model: model to dispatch - :return: model which is dispatched - """ - remove_dispatch(model) - remove_torch_offload_dispatch(model) - - no_split_module_classes = model._get_no_split_modules("auto") - max_memory = get_balanced_memory( - model, - dtype=model.dtype, - no_split_module_classes=no_split_module_classes, - ) - device_map = infer_auto_device_map( - model, - dtype=model.dtype, - max_memory=max_memory, - no_split_module_classes=no_split_module_classes, - ) - - return dispatch_model(model, device_map=device_map) - - -def remove_dispatch(module: torch.nn.Module) -> torch.nn.Module: + :param hint_batch_size: reserve memory for batch size of inputs + :param hint_batch_seq_len: reserve memory for sequence of length of inputs + :param hint_model_dtype: reserve memory for model's dtype. + Will be inferred from model if none is provided + :param hint_extra_memory: extra memory reserved for model serving + :param no_split_modules: names of module classes which should not be split + across multiple devices + :return: dispatched model """ - Remove any existing accelerate dispatches from module - - :param module: module which may be dispatched with hf hooks - :return: module without dispatch - """ - remove_hook_from_module(module, recurse=True) - if hasattr(module, "hf_device_map"): - delattr(module, "hf_device_map") - module.to("cpu") - - return module + return dispatch_model(*args, **kwargs) \ No newline at end of file From e944f39e94b74541cdff9d29b7fa2e28cbaf3b7d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 19 Dec 2025 02:52:35 +0000 Subject: [PATCH 5/5] reduce diff Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/awq/base.py | 147 ++++++++++++------------ src/llmcompressor/utils/dev.py | 7 +- 2 files changed, 77 insertions(+), 77 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 349887d2a..0b73970b4 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -10,6 +10,7 @@ ) from compressed_tensors.quantization.utils import strategy_cdiv from compressed_tensors.utils import ( + align_modules, get_execution_device, get_lowest_common_ancestor_name, getattr_chain, @@ -408,90 +409,91 @@ def _apply_smoothing(self, model: Module) -> None: balance_layers = mapping.balance_layers parent_module = mapping.parent - # Compute output of unquantized module - fp16_outputs = self._run_samples(model, parent_module) - if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs): - logger.info( - f"Skipping smooth_layer {mapping.smooth_name}, no activations " - "found to scale. This can occasionally occur in MoE models " - "when certain experts are not activated by calibration samples." 
- ) - del self._smooth_activation_means[mapping.smooth_name] - continue - if not all( - [fp16_output.isfinite().all() for fp16_output in fp16_outputs] + with ( + align_modules([parent_module, smooth_layer, *balance_layers]), + calibration_forward_context(model), + HooksMixin.disable_hooks(), ): - logger.warning( - f"Skipping smooth_layer {mapping.smooth_name}, NaN or inf " - "outputs found during forward pass of the parent module " - f"{mapping.parent_name}. The model is either generating NaN " - "output with provided calibration data set, or the mappings " - "are incorrectly set and modifying the model in undesired " - "ways. If you encounter this consistently, raise an issue at " - "https://github.com/vllm-project/llm-compressor/issues" - ) - del self._smooth_activation_means[mapping.smooth_name] - continue + # Compute output of unquantized module + fp16_outputs = self._run_samples(parent_module) + if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs): + logger.info( + f"Skipping smooth_layer {mapping.smooth_name}, no activations " + "found to scale. This can occasionally occur in MoE models " + "when certain experts are not activated by calibration samples." + ) + del self._smooth_activation_means[mapping.smooth_name] + continue + if not all( + [fp16_output.isfinite().all() for fp16_output in fp16_outputs] + ): + logger.warning( + f"Skipping smooth_layer {mapping.smooth_name}, NaN or inf " + "outputs found during forward pass of the parent module " + f"{mapping.parent_name}. The model is either generating NaN " + "output with provided calibration data set, or the mappings " + "are incorrectly set and modifying the model in undesired " + "ways. If you encounter this consistently, raise an issue at " + "https://github.com/vllm-project/llm-compressor/issues" + ) + del self._smooth_activation_means[mapping.smooth_name] + continue - best_scales = self._compute_best_scale(model, mapping, fp16_outputs) + best_scales = self._compute_best_scale(mapping, fp16_outputs) - @torch.no_grad() - def _smooth(module: Module): - scales = best_scales.to(module.weight.device) - print(scales) - if module in balance_layers: - update_offload_parameter( - module, - "weight", - module.weight.mul_(scales.view(1, -1)), - ) - elif module == smooth_layer: - if module.weight.ndim == 1: - breakpoint() + @torch.no_grad() + def _smooth(module: Module): + scales = best_scales.to(module.weight.device) + if module in balance_layers: update_offload_parameter( module, "weight", - module.weight.div_(scales), + module.weight.mul_(scales.view(1, -1)), ) - else: - # NOTE: edge case when smooth layer number of out_features - # is not equal to balance layer number of in_features - # e.g. 
when fused qkv_proj is used to smooth o_proj - # in this case, default to scaling the last output features - # because the desired smooth layer is v_proj - # https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/scale.py#L123 - weight = module.weight - weight[-scales.size(0) :].div_(scales.view(-1, 1)) - update_offload_parameter(module, "weight", weight) - if hasattr(module, "bias") and module.bias is not None: - update_offload_parameter( - module, - "bias", - module.bias.div_(scales), - ) - - parent = get_fsdp_parent(mapping.smooth_name, model) - if parent is not None: - parent.apply(_smooth) - else: - # if we're not running with FSDP we can apply smoothing directly - for layer in balance_layers: - _smooth(layer) - _smooth(smooth_layer) + elif module == smooth_layer: + if module.weight.ndim == 1: + update_offload_parameter( + module, + "weight", + module.weight.div_(scales), + ) + else: + # NOTE: edge case when smooth layer number of out_features + # is not equal to balance layer number of in_features + # e.g. when fused qkv_proj is used to smooth o_proj + # in this case, default to scaling the last output features + # because the desired smooth layer is v_proj + # https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/scale.py#L123 + weight = module.weight + weight[-scales.size(0) :].div_(scales.view(-1, 1)) + update_offload_parameter(module, "weight", weight) + if hasattr(module, "bias") and module.bias is not None: + update_offload_parameter( + module, + "bias", + module.bias.div_(scales), + ) + + parent = get_fsdp_parent(mapping.smooth_name, model) + if parent is not None: + parent.apply(_smooth) + else: + # if we're not running with FSDP we can apply smoothing directly + for layer in balance_layers: + _smooth(layer) + _smooth(smooth_layer) - # remove caches needed to smooth this mapping - del self._smooth_activation_means[mapping.smooth_name] + # remove caches needed to smooth this mapping + del self._smooth_activation_means[mapping.smooth_name] for v in self._parent_args_cache.values(): v.batch_intermediates.clear() self._assert_all_activations_consumed() - def _run_samples(self, model: Module, module: Module) -> list[torch.Tensor]: - with (HooksMixin.disable_hooks(), calibration_forward_context(model)): - outputs = [ - module(**batch_kwargs) - for batch_kwargs in self._parent_args_cache[module] - ] + def _run_samples(self, module: Module) -> list[torch.Tensor]: + outputs = [ + module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module] + ] return [ # If tuple, assume that first argument is the input output[0] if isinstance(output, tuple) else output @@ -500,7 +502,6 @@ def _run_samples(self, model: Module, module: Module) -> list[torch.Tensor]: def _compute_best_scale( self, - model: Module, mapping: ResolvedMapping, fp16_outputs: list[torch.Tensor], ) -> torch.Tensor: @@ -614,7 +615,7 @@ def _compute_best_scale( ) # W * X - int_w_outputs = self._run_samples(model, mapping.parent) + int_w_outputs = self._run_samples(mapping.parent) # compute mean squared error (L2 norm) loss = self._compute_loss(fp16_outputs, int_w_outputs) diff --git a/src/llmcompressor/utils/dev.py b/src/llmcompressor/utils/dev.py index ad3effa9f..8cdedadf5 100644 --- a/src/llmcompressor/utils/dev.py +++ b/src/llmcompressor/utils/dev.py @@ -2,9 +2,8 @@ import logging import os import tempfile -from typing import Type - from functools import wraps +from typing import Type import torch from compressed_tensors.offload import dispatch_model @@ -139,7 +138,7 @@ def 
dispatch_for_generation(*args, **kwargs) -> PreTrainedModel: Will be inferred from model if none is provided :param hint_extra_memory: extra memory reserved for model serving :param no_split_modules: names of module classes which should not be split - across multiple devices + across multiple devices :return: dispatched model """ - return dispatch_model(*args, **kwargs) \ No newline at end of file + return dispatch_model(*args, **kwargs)
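
For context, a minimal sketch of the calibration-time flow after this series: tracing now happens before any device dispatch, and dispatch uses the compressed_tensors.offload helpers rather than accelerate hooks. The sketch reuses only calls that appear in the patches above; the offload_model and get_main_device signatures are assumed from these diffs (not verified against a released compressed_tensors version), and calibrate_sketch itself is a hypothetical harness, not part of the change.

    # Hypothetical harness illustrating the reworked order: trace first, dispatch after.
    # offload_model / get_main_device signatures are assumed from the patches above.
    from compressed_tensors.offload import offload_model

    from llmcompressor.pipelines.sequential.helpers import (
        get_sequential_targets,
        trace_subgraphs,
    )
    from llmcompressor.utils.dev import get_main_device


    def calibrate_sketch(model, modifiers, dataset_args, sample_input, ignore):
        # 1. Trace while the model is still un-dispatched, so offloaded hooks
        #    cannot introduce meta tensors into the captured graph.
        targets = get_sequential_targets(modifiers, model, dataset_args)
        subgraphs = trace_subgraphs(model, sample_input, targets, ignore)

        # 2. Only then prepare sequential onloading: keep weights on CPU and
        #    onload each subgraph to the main device during its forward pass.
        device = get_main_device()  # cuda:0, xpu:0, or cpu (with a warning)
        offload_model(model, onload_device=device, offload_device="cpu")

        return subgraphs, device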