diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 0e5ca71b3..0b73970b4 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -197,13 +197,13 @@ def on_initialize(self, state: State, **kwargs) -> bool:
                 architecture=state.model.__class__.__name__
             )
 
-        self._set_resolved_mappings(state.model)
-
         return True
 
     def on_start(self, state: State, event: Event, **kwargs):
         self.started_ = True
 
+        self._set_resolved_mappings(state.model)
+
         # register quantization calibration hooks
         # assume quantization has been initialized by this modifier or one before it
         QuantizationMixin.start_calibration(self, state.model)
diff --git a/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py b/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py
index a2f60a33f..7b171edc4 100644
--- a/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py
+++ b/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py
@@ -109,15 +109,19 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
 
         :param state: session state storing input model and calibration data
         """
+        # infer module and sequential targets
+        self.sequential_targets = self._infer_sequential_targets(state.model)
+
+        return True
+
+    def on_start(self, state: State, event: Event, **kwargs):
+        self.started_ = True
+
+        # find target layers
         model: torch.nn.Module = state.model
         dataloader: torch.utils.data.DataLoader = state.data.calib
-
-        # infer module and sequential targets
-        self.sequential_targets = self._infer_sequential_targets(model)
         layers = get_layers(self.sequential_targets, model)
-        self._target_layers = get_layers(
-            self.targets, model
-        )  # layers containing targets
+        self._target_layers = get_layers(self.targets, model)
 
         # infer layer sparsities
         if self.sparsity_profile == "owl":
@@ -127,7 +131,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
             )
             self.sparsity = self._infer_owl_layer_sparsity(model, layers, dataloader)
 
-        # get layers and validate sparsity
+        # validate sparsity
         if isinstance(self.sparsity, (list, dict)) and len(self._target_layers) != len(
             self.sparsity
         ):
@@ -136,11 +140,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
                 f"sparsities values, but model has {len(layers)} target layers"
             )
 
-        return True
-
-    def on_start(self, state: State, event: Event, **kwargs):
-        self.started_ = True
-
         # register hooks
         for index, (layer_name, layer) in enumerate(self._target_layers.items()):
             match self.sparsity:
diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py
index ab23e4fad..eb1270553 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/base.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -163,6 +163,11 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         if QuantizationMixin.has_config(self):
             QuantizationMixin.initialize_quantization(self, state.model)
 
+        return True
+
+    def on_start(self, state: State, event: Event, **kwargs):
+        self.started_ = True
+
         # prepare module names
         self._module_names = {
             m: name
@@ -171,11 +176,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             )
         }
 
-        return True
-
-    def on_start(self, state: State, event: Event, **kwargs):
-        self.started_ = True
-
         # register quantization calibration hooks
         # assume quantization has been initialized by this modifier or one before it
         QuantizationMixin.start_calibration(self, state.model)
diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py
index 908d03254..886e76ebf 100644
--- a/src/llmcompressor/modifiers/smoothquant/base.py
+++ b/src/llmcompressor/modifiers/smoothquant/base.py
@@ -131,13 +131,13 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             )
         self.ignore = [] if not self.ignore else self.ignore
         self.mappings = self._infer_mappings_from_model(state.model)
-        self.resolved_mappings_ = self._resolve_mappings(state.model)
         self.scales_ = {}
 
         return True
 
     def on_start(self, state: State, event: Event, **kwargs):
         self.started_ = True
+        self.resolved_mappings_ = self._resolve_mappings(state.model)
         self._setup_scale_hooks()
 
     def on_event(self, state: State, event: Event, **kwargs):
diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py
index 4a776a7d5..cd7d3fd52 100644
--- a/src/llmcompressor/pipelines/sequential/helpers.py
+++ b/src/llmcompressor/pipelines/sequential/helpers.py
@@ -6,13 +6,7 @@
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple
 
 import torch
-from accelerate.hooks import remove_hook_from_module
-from compressed_tensors.utils import (
-    has_offloaded_params,
-    offloaded_dispatch,
-    patch_attr,
-    remove_dispatch,
-)
+from compressed_tensors.utils import patch_attr
 from compressed_tensors.utils.match import match_targets
 from loguru import logger
 from torch.fx import Graph, GraphModule, Node
@@ -37,7 +31,6 @@
     "trace_subgraphs",
     "Subgraph",
    "get_sequential_targets",
-    "dispatch_for_sequential",
 ]
 
 
@@ -104,10 +97,9 @@ def trace_subgraphs(
     # find modules
     targets = match_modules(model, sequential_targets)
     ancestors = get_sequential_ancestors(model, targets)
-    offloaded = set(m for m in model.modules() if has_offloaded_params(m))
 
     # initialize arguments
-    tracer = SequentialTracer(ancestors, offloaded)
+    tracer = SequentialTracer(ancestors)
     concrete_args = populate_concrete_args(model, sample_input)
 
     with contextlib.ExitStack() as stack:
@@ -168,32 +160,18 @@ class SequentialTracer(HFTracer):
     """
     Get a tracer specialized for the given model. The resulting tracer will not trace
     inside of sequential targets, nor any modules which are not call graph ancestors of
-    sequential targets
-
-    Tracing within sequential targets is unnecessary, and tracing within offloaded
-    modules may result in meta tensors being added to the model graph
+    sequential targets. Tracing outside of call ancestors of sequential targets will be
+    skipped
 
     :param ancestors: modules which are ancestors of sequential targets
-    :param offloaded: modules which have offloaded params and should not be traced
     """
 
-    def __init__(self, ancestors: Set[Module], offloaded: Set[Module]):
+    def __init__(self, ancestors: Set[Module]):
         self.ancestors = ancestors
-        self.offloaded = offloaded
 
         # skip any mask creation functions not already caught by the autowrapper
         super().__init__(autowrap_functions=_get_autowrap_functions())
 
-        # check unlikely case that ancestors have direct params which are offloaded
-        offloaded_ancestors = offloaded & ancestors
-        for ancestor in offloaded_ancestors:
-            remove_hook_from_module(ancestor, recurse=False)
-            self.offloaded.remove(ancestor)
-            logger.warning(
-                f"Direct parameters attached to {ancestor.__class__.__name__} have "
-                "been onloaded in order to ensure safe graph capture and execution"
-            )
-
     def create_arg(self, a: Any) -> Argument:
         # special extension allows models which depend on config values to be traced
         if isinstance(a, PretrainedConfig):
@@ -204,8 +182,8 @@ def create_arg(self, a: Any) -> Argument:
         return super().create_arg(a)
 
     def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool:
-        # do not trace non-ancestors or modules with offloaded params
-        return module not in self.ancestors or module in self.offloaded
+        # do not trace non-ancestors
+        return module not in self.ancestors
 
 
 def populate_concrete_args(model: Module, sample_input: Dict) -> Dict:
@@ -526,27 +504,6 @@ def is_ancestor(module: Module) -> bool:
     return ancestors
 
 
-def dispatch_for_sequential(model: PreTrainedModel) -> PreTrainedModel:
-    """
-    Dispatch a model for sequential calibration using a sequential pipeline.
-    The model will be offloaded to the CPU and dispatched to CUDA/XPU device
-    if available. Removes any existing hooks.
-
-    :param model: model to dispatch
-    :return: dispatched model
-    """
-    remove_dispatch(model)
-
-    if torch.cuda.is_available():
-        offloaded_dispatch(model, execution_device=torch.device("cuda:0"))
-    elif hasattr(torch, "xpu") and torch.xpu.is_available():
-        offloaded_dispatch(model, execution_device=torch.device("xpu:0"))
-    else:
-        logger.warning("CUDA/XPU is not available! Compressing model on CPU instead")
-
-    return model
-
-
 def _get_autowrap_functions() -> Tuple[Callable[[Any], Any], ...]:
     try:
         from transformers.masking_utils import LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING
diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py
index 91516f280..f8ccf1d0a 100644
--- a/src/llmcompressor/pipelines/sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/sequential/pipeline.py
@@ -2,7 +2,8 @@
 from typing import TYPE_CHECKING
 
 import torch
-from compressed_tensors.utils import disable_offloading, get_execution_device
+from compressed_tensors.offload import offload_model
+from compressed_tensors.utils import disable_offloading
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm
 
@@ -11,10 +12,10 @@
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pipelines.sequential.helpers import (
-    dispatch_for_sequential,
     get_sequential_targets,
     trace_subgraphs,
 )
+from llmcompressor.utils.dev import get_main_device
 from llmcompressor.utils.helpers import (
     DISABLE_QAC_MODIFIERS,
     DisableQuantization,
@@ -59,10 +60,6 @@ def __call__(
         """
         session = active_session()
 
-        # prepare model for sequential onloading
-        dispatch_for_sequential(model)
-        model_device = get_execution_device(model)
-
         # prepare to trace subgraphs
         modifiers = session.lifecycle.recipe.modifiers
         sequential_targets = get_sequential_targets(modifiers, model, dataset_args)
@@ -73,6 +70,10 @@ def __call__(
         subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore)
         num_subgraphs = len(subgraphs)
 
+        # prepare model for sequential onloading
+        model_device = get_main_device()
+        offload_model(model, onload_device=model_device, offload_device="cpu")
+
         LifecycleCallbacks.calibration_epoch_start()
 
         # TODO: remove this to enable quantization aware calibration
diff --git a/src/llmcompressor/utils/dev.py b/src/llmcompressor/utils/dev.py
index a227ffa06..8cdedadf5 100644
--- a/src/llmcompressor/utils/dev.py
+++ b/src/llmcompressor/utils/dev.py
@@ -2,13 +2,14 @@
 import logging
 import os
 import tempfile
+from functools import wraps
 from typing import Type
 
 import torch
-from accelerate import dispatch_model, infer_auto_device_map
-from accelerate.utils import get_balanced_memory
-from compressed_tensors.utils import patch_attr, remove_dispatch
+from compressed_tensors.offload import dispatch_model
+from compressed_tensors.utils import patch_attr
 from huggingface_hub import snapshot_download
+from loguru import logger
 from safetensors.torch import save_file
 from transformers import AutoModelForCausalLM, PreTrainedModel
 from transformers.modeling_utils import TORCH_INIT_FUNCTIONS
@@ -17,6 +18,7 @@
 __all__ = [
     "skip_weights_download",
     "patch_transformers_logger_level",
+    "get_main_device",
     "dispatch_for_generation",
 ]
 
@@ -113,28 +115,30 @@ def patch_transformers_logger_level(level: int = logging.ERROR):
         transformers_logger.setLevel(level=restore_log_level)
 
 
-def dispatch_for_generation(model: PreTrainedModel) -> PreTrainedModel:
+def get_main_device() -> torch.device:
+    if torch.cuda.is_available():
+        return torch.device("cuda:0")
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        return torch.device("xpu:0")
+    else:
+        logger.warning("CUDA/XPU is not available! Compressing model on CPU instead")
+        return torch.device("cpu")
+
+
+@wraps(dispatch_model)
+def dispatch_for_generation(*args, **kwargs) -> PreTrainedModel:
     """
     Dispatch a model autoregressive generation. This means that modules are dispatched
-    evenly across avaiable devices and kept onloaded if possible. Removes any HF hooks
-    that may have existed previously.
+    evenly across available devices and kept onloaded if possible.
 
     :param model: model to dispatch
-    :return: model which is dispatched
+    :param hint_batch_size: reserve memory for batch size of inputs
+    :param hint_batch_seq_len: reserve memory for sequence length of inputs
+    :param hint_model_dtype: reserve memory for model's dtype.
+        Will be inferred from model if none is provided
+    :param hint_extra_memory: extra memory reserved for model serving
+    :param no_split_modules: names of module classes which should not be split
+        across multiple devices
+    :return: dispatched model
     """
-    remove_dispatch(model)
-
-    no_split_module_classes = model._get_no_split_modules("auto")
-    max_memory = get_balanced_memory(
-        model,
-        dtype=model.dtype,
-        no_split_module_classes=no_split_module_classes,
-    )
-    device_map = infer_auto_device_map(
-        model,
-        dtype=model.dtype,
-        max_memory=max_memory,
-        no_split_module_classes=no_split_module_classes,
-    )
-
-    return dispatch_model(model, device_map=device_map)
+    return dispatch_model(*args, **kwargs)
diff --git a/src/llmcompressor/utils/transformers.py b/src/llmcompressor/utils/transformers.py
index 9b4831a29..b8e61f725 100644
--- a/src/llmcompressor/utils/transformers.py
+++ b/src/llmcompressor/utils/transformers.py
@@ -1,5 +1,4 @@
 import torch
-from compressed_tensors import has_offloaded_params, register_offload_parameter
 from loguru import logger
 from torch.nn import Parameter
 from transformers import PreTrainedModel
@@ -28,14 +27,9 @@ def untie_word_embeddings(model: PreTrainedModel):
 
     # clone data to untie
     for module in (input_embed, output_embed):
-        if not has_offloaded_params(module):
-            data = module.weight.data
-        else:
-            data = module._hf_hook.weights_map["weight"]
-
-        requires_grad = module.weight.requires_grad
-        untied_param = Parameter(data.clone(), requires_grad=requires_grad)
-        register_offload_parameter(module, "weight", untied_param)
+        weight = module.weight
+        param = Parameter(weight.data.clone(), requires_grad=weight.requires_grad)
+        module.register_parameter("weight", param)
 
     # modify model config
     if hasattr(model.config, "tie_word_embeddings"):
diff --git a/tests/llmcompressor/utils/test_helpers.py b/tests/llmcompressor/utils/test_helpers.py
index b6535bf72..8a58ba0f5 100644
--- a/tests/llmcompressor/utils/test_helpers.py
+++ b/tests/llmcompressor/utils/test_helpers.py
@@ -1,11 +1,11 @@
 import pytest
 import torch
+from compressed_tensors.offload import dispatch_model
 from transformers import (
     AutoModelForCausalLM,
     MllamaForConditionalGeneration,
 )
 
-from llmcompressor.pipelines.sequential.helpers import dispatch_for_sequential
 from llmcompressor.utils import (
     ALL_TOKEN,
     DisableQuantization,
@@ -17,7 +17,7 @@
     interpolate,
     validate_str_iterable,
 )
-from llmcompressor.utils.dev import dispatch_for_generation, skip_weights_download
+from llmcompressor.utils.dev import skip_weights_download
 from tests.testing_utils import requires_gpu
 
 
@@ -153,14 +153,12 @@ def test_disable_cache(model_cls, model_stub):
 
 
 @requires_gpu
-@pytest.mark.parametrize("offload", ["sequential", "basic", "none"])
-def test_disable_lm_head(offload):
+@pytest.mark.parametrize("dispatch", (True, False))
+def test_disable_lm_head(dispatch):
     model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2")
-    if offload == "sequential":
-        dispatch_for_sequential(model)
-    if offload == "basic":
-        dispatch_for_generation(model)
-    if offload == "none":
+    if dispatch:
+        dispatch_model(model, "cuda")
+    else:
         model = model.to("cuda")
 
     lm_input_device = None