4 changes: 2 additions & 2 deletions src/llmcompressor/modifiers/awq/base.py
@@ -197,13 +197,13 @@ def on_initialize(self, state: State, **kwargs) -> bool:
architecture=state.model.__class__.__name__
)

self._set_resolved_mappings(state.model)

return True

def on_start(self, state: State, event: Event, **kwargs):
self.started_ = True

self._set_resolved_mappings(state.model)

# register quantization calibration hooks
# assume quantization has been initialized by this modifier or one before it
QuantizationMixin.start_calibration(self, state.model)
23 changes: 11 additions & 12 deletions src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py
@@ -109,15 +109,19 @@ def on_initialize(self, state: "State", **kwargs) -> bool:

:param state: session state storing input model and calibration data
"""
# infer module and sequential targets
self.sequential_targets = self._infer_sequential_targets(state.model)

return True

def on_start(self, state: State, event: Event, **kwargs):
self.started_ = True

# find target layers
model: torch.nn.Module = state.model
dataloader: torch.utils.data.DataLoader = state.data.calib

# infer module and sequential targets
self.sequential_targets = self._infer_sequential_targets(model)
layers = get_layers(self.sequential_targets, model)
self._target_layers = get_layers(
self.targets, model
) # layers containing targets
self._target_layers = get_layers(self.targets, model)

# infer layer sparsities
if self.sparsity_profile == "owl":
@@ -127,7 +131,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
)
self.sparsity = self._infer_owl_layer_sparsity(model, layers, dataloader)

# get layers and validate sparsity
# validate sparsity
if isinstance(self.sparsity, (list, dict)) and len(self._target_layers) != len(
self.sparsity
):
@@ -136,11 +140,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
f"sparsities values, but model has {len(layers)} target layers"
)

return True

def on_start(self, state: State, event: Event, **kwargs):
self.started_ = True

# register hooks
for index, (layer_name, layer) in enumerate(self._target_layers.items()):
match self.sparsity:
10 changes: 5 additions & 5 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -163,6 +163,11 @@ def on_initialize(self, state: State, **kwargs) -> bool:
if QuantizationMixin.has_config(self):
QuantizationMixin.initialize_quantization(self, state.model)

return True

def on_start(self, state: State, event: Event, **kwargs):
self.started_ = True

# prepare module names
self._module_names = {
m: name
@@ -171,11 +176,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:
)
}

return True

def on_start(self, state: State, event: Event, **kwargs):
self.started_ = True

# register quantization calibration hooks
# assume quantization has been initialized by this modifier or one before it
QuantizationMixin.start_calibration(self, state.model)
2 changes: 1 addition & 1 deletion src/llmcompressor/modifiers/smoothquant/base.py
@@ -131,13 +131,13 @@ def on_initialize(self, state: State, **kwargs) -> bool:
)
self.ignore = [] if not self.ignore else self.ignore
self.mappings = self._infer_mappings_from_model(state.model)
self.resolved_mappings_ = self._resolve_mappings(state.model)
self.scales_ = {}

return True

def on_start(self, state: State, event: Event, **kwargs):
self.started_ = True
self.resolved_mappings_ = self._resolve_mappings(state.model)
self._setup_scale_hooks()

def on_event(self, state: State, event: Event, **kwargs):
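The four modifier diffs above (AWQ, SparseGPT, GPTQ, SmoothQuant) share one refactor: resolution work that inspects the model (mappings, target layers, module names) moves from on_initialize to on_start, so it runs only after the model is in its final dispatched state. A minimal sketch of that lifecycle split, using a hypothetical ToyModifier rather than the library's real Modifier base class:

from dataclasses import dataclass, field
from typing import Dict, List

import torch


@dataclass
class ToyModifier:
    # Hypothetical modifier; names and structure are illustrative only.
    targets: List[str] = field(default_factory=lambda: ["Linear"])
    started_: bool = False
    resolved_: Dict[str, torch.nn.Module] = field(default_factory=dict)

    def on_initialize(self, model: torch.nn.Module) -> bool:
        # Record configuration only; do not inspect the model layout yet,
        # since it may still be re-dispatched before calibration starts.
        return True

    def on_start(self, model: torch.nn.Module) -> None:
        # Resolve targets right before calibration, against the final model.
        self.started_ = True
        self.resolved_ = {
            name: module
            for name, module in model.named_modules()
            if module.__class__.__name__ in self.targets
        }


model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
modifier = ToyModifier()
modifier.on_initialize(model)
modifier.on_start(model)
print(list(modifier.resolved_))  # ['0'] -- the Linear layer

Keeping on_initialize layout-agnostic is what lets the pipeline below re-dispatch or offload the model between initialization and the start of calibration.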
57 changes: 7 additions & 50 deletions src/llmcompressor/pipelines/sequential/helpers.py
@@ -6,13 +6,7 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple

import torch
from accelerate.hooks import remove_hook_from_module
from compressed_tensors.utils import (
has_offloaded_params,
offloaded_dispatch,
patch_attr,
remove_dispatch,
)
from compressed_tensors.utils import patch_attr
from compressed_tensors.utils.match import match_targets
from loguru import logger
from torch.fx import Graph, GraphModule, Node
@@ -37,7 +31,6 @@
"trace_subgraphs",
"Subgraph",
"get_sequential_targets",
"dispatch_for_sequential",
]


@@ -104,10 +97,9 @@ def trace_subgraphs(
# find modules
targets = match_modules(model, sequential_targets)
ancestors = get_sequential_ancestors(model, targets)
offloaded = set(m for m in model.modules() if has_offloaded_params(m))

# initialize arguments
tracer = SequentialTracer(ancestors, offloaded)
tracer = SequentialTracer(ancestors)
concrete_args = populate_concrete_args(model, sample_input)

with contextlib.ExitStack() as stack:
@@ -168,32 +160,18 @@ class SequentialTracer(HFTracer):
"""
Get a tracer specialized for the given model. The resulting tracer will not trace
inside of sequential targets, nor any modules which are not call graph ancestors of
sequential targets

Tracing within sequential targets is unnecessary, and tracing within offloaded
modules may result in meta tensors being added to the model graph
sequential targets. Tracing outside of call ancestors of sequential targets will be
skipped

:param ancestors: modules which are ancestors of sequential targets
:param offloaded: modules which have offloaded params and should not be traced
"""

def __init__(self, ancestors: Set[Module], offloaded: Set[Module]):
def __init__(self, ancestors: Set[Module]):
self.ancestors = ancestors
self.offloaded = offloaded

# skip any mask creation functions not already caught by the autowrapper
super().__init__(autowrap_functions=_get_autowrap_functions())

# check unlikely case that ancestors have direct params which are offloaded
offloaded_ancestors = offloaded & ancestors
for ancestor in offloaded_ancestors:
remove_hook_from_module(ancestor, recurse=False)
self.offloaded.remove(ancestor)
logger.warning(
f"Direct parameters attached to {ancestor.__class__.__name__} have "
"been onloaded in order to ensure safe graph capture and execution"
)

def create_arg(self, a: Any) -> Argument:
# special extension allows models which depend on config values to be traced
if isinstance(a, PretrainedConfig):
@@ -204,8 +182,8 @@ def create_arg(self, a: Any) -> Argument:
return super().create_arg(a)

def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool:
# do not trace non-ancestors or modules with offloaded params
return module not in self.ancestors or module in self.offloaded
# do not trace non-ancestors
return module not in self.ancestors


def populate_concrete_args(model: Module, sample_input: Dict) -> Dict:
@@ -526,27 +504,6 @@ def is_ancestor(module: Module) -> bool:
return ancestors


def dispatch_for_sequential(model: PreTrainedModel) -> PreTrainedModel:
"""
Dispatch a model for sequential calibration using a sequential pipeline.
The model will be offloaded to the CPU and dispatched to CUDA/XPU device
if available. Removes any existing hooks.

:param model: model to dispatch
:return: dispatched model
"""
remove_dispatch(model)

if torch.cuda.is_available():
offloaded_dispatch(model, execution_device=torch.device("cuda:0"))
elif hasattr(torch, "xpu") and torch.xpu.is_available():
offloaded_dispatch(model, execution_device=torch.device("xpu:0"))
else:
logger.warning("CUDA/XPU is not available! Compressing model on CPU instead")

return model


def _get_autowrap_functions() -> Tuple[Callable[[Any], Any], ...]:
try:
from transformers.masking_utils import LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING
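With the offload tracking removed, the tracer's only remaining rule is the ancestor check in is_leaf_module. A rough standalone sketch of that idea on a plain torch.fx.Tracer (not the HF-specific SequentialTracer above); the toy Block/Model classes are invented for illustration:

from typing import Set

import torch
from torch import nn
from torch.fx import Tracer


class AncestorOnlyTracer(Tracer):
    # Trace only inside modules that are call-graph ancestors of the targets;
    # everything else stays opaque as a single call_module node.
    def __init__(self, ancestors: Set[nn.Module]):
        super().__init__()
        self.ancestors = ancestors

    def is_leaf_module(self, module: nn.Module, module_qualified_name: str) -> bool:
        return module not in self.ancestors


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()

    def forward(self, x):
        return self.block(x)


model = Model()
tracer = AncestorOnlyTracer(ancestors={model})  # only the root is an ancestor
graph = tracer.trace(model)
print(graph)  # `block` appears as one opaque call_module node

Marking non-ancestors as leaves keeps sequential targets opaque in the traced graph, so each traced subgraph can be executed and calibrated as a unit.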
13 changes: 7 additions & 6 deletions src/llmcompressor/pipelines/sequential/pipeline.py
@@ -2,7 +2,8 @@
from typing import TYPE_CHECKING

import torch
from compressed_tensors.utils import disable_offloading, get_execution_device
from compressed_tensors.offload import offload_model
from compressed_tensors.utils import disable_offloading
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

@@ -11,10 +12,10 @@
from llmcompressor.pipelines.cache import IntermediatesCache
from llmcompressor.pipelines.registry import CalibrationPipeline
from llmcompressor.pipelines.sequential.helpers import (
dispatch_for_sequential,
get_sequential_targets,
trace_subgraphs,
)
from llmcompressor.utils.dev import get_main_device
from llmcompressor.utils.helpers import (
DISABLE_QAC_MODIFIERS,
DisableQuantization,
@@ -59,10 +60,6 @@ def __call__(
"""
session = active_session()

# prepare model for sequential onloading
dispatch_for_sequential(model)
model_device = get_execution_device(model)

# prepare to trace subgraphs
modifiers = session.lifecycle.recipe.modifiers
sequential_targets = get_sequential_targets(modifiers, model, dataset_args)
@@ -73,6 +70,10 @@ def __call__(
subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore)
num_subgraphs = len(subgraphs)

# prepare model for sequential onloading
model_device = get_main_device()
offload_model(model, onload_device=model_device, offload_device="cpu")

LifecycleCallbacks.calibration_epoch_start()

# TODO: remove this to enable quantization aware calibration
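The pipeline now owns dispatch: after tracing, it picks a main device and offloads the model itself. A short usage sketch of that new flow, assuming offload_model and get_main_device behave as they are imported above (the model stub is borrowed from the test file below):

from compressed_tensors.offload import offload_model
from transformers import AutoModelForCausalLM

from llmcompressor.utils.dev import get_main_device

model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2")

# cuda:0 / xpu:0 if available, otherwise CPU with a warning
model_device = get_main_device()

# keep weights on CPU and onload them to the main device only while needed,
# mirroring the sequential pipeline change above
offload_model(model, onload_device=model_device, offload_device="cpu")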
50 changes: 27 additions & 23 deletions src/llmcompressor/utils/dev.py
@@ -2,13 +2,14 @@
import logging
import os
import tempfile
from functools import wraps
from typing import Type

import torch
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory
from compressed_tensors.utils import patch_attr, remove_dispatch
from compressed_tensors.offload import dispatch_model
from compressed_tensors.utils import patch_attr
from huggingface_hub import snapshot_download
from loguru import logger
from safetensors.torch import save_file
from transformers import AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_utils import TORCH_INIT_FUNCTIONS
@@ -17,6 +18,7 @@
__all__ = [
"skip_weights_download",
"patch_transformers_logger_level",
"get_main_device",
"dispatch_for_generation",
]

@@ -113,28 +115,30 @@ def patch_transformers_logger_level(level: int = logging.ERROR):
transformers_logger.setLevel(level=restore_log_level)


def dispatch_for_generation(model: PreTrainedModel) -> PreTrainedModel:
def get_main_device() -> torch.device:
if torch.cuda.is_available():
return torch.device("cuda:0")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
return torch.device("xpu:0")
else:
logger.warning("CUDA/XPU is not available! Compressing model on CPU instead")
return torch.device("cpu")


@wraps(dispatch_model)
def dispatch_for_generation(*args, **kwargs) -> PreTrainedModel:
"""
Dispatch a model for autoregressive generation. This means that modules are dispatched
evenly across available devices and kept onloaded if possible. Removes any HF hooks
that may have existed previously.
evenly across available devices and kept onloaded if possible.

:param model: model to dispatch
:return: model which is dispatched
:param hint_batch_size: reserve memory for batch size of inputs
:param hint_batch_seq_len: reserve memory for sequence length of inputs
:param hint_model_dtype: reserve memory for model's dtype.
Will be inferred from model if none is provided
:param hint_extra_memory: extra memory reserved for model serving
:param no_split_modules: names of module classes which should not be split
across multiple devices
:return: dispatched model
"""
remove_dispatch(model)

no_split_module_classes = model._get_no_split_modules("auto")
max_memory = get_balanced_memory(
model,
dtype=model.dtype,
no_split_module_classes=no_split_module_classes,
)
device_map = infer_auto_device_map(
model,
dtype=model.dtype,
max_memory=max_memory,
no_split_module_classes=no_split_module_classes,
)

return dispatch_model(model, device_map=device_map)
return dispatch_model(*args, **kwargs)
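Since dispatch_for_generation is now a @wraps pass-through to compressed_tensors.offload.dispatch_model, callers hand any memory-reservation hints straight to that function. A rough usage sketch; the hint keyword arguments are taken from the docstring above, and the exact underlying signature is an assumption here:

import torch
from transformers import AutoModelForCausalLM

from llmcompressor.utils.dev import dispatch_for_generation

model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2")

# spread modules evenly across available devices; the hint kwargs (assumed)
# reserve headroom for generation-time activations
model = dispatch_for_generation(model, hint_batch_size=4, hint_batch_seq_len=2048)

input_ids = torch.tensor([[1, 2, 3]], device=next(model.parameters()).device)
print(model.generate(input_ids, max_new_tokens=8))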
12 changes: 3 additions & 9 deletions src/llmcompressor/utils/transformers.py
@@ -1,5 +1,4 @@
import torch
from compressed_tensors import has_offloaded_params, register_offload_parameter
from loguru import logger
from torch.nn import Parameter
from transformers import PreTrainedModel
@@ -28,14 +27,9 @@ def untie_word_embeddings(model: PreTrainedModel):

# clone data to untie
for module in (input_embed, output_embed):
if not has_offloaded_params(module):
data = module.weight.data
else:
data = module._hf_hook.weights_map["weight"]

requires_grad = module.weight.requires_grad
untied_param = Parameter(data.clone(), requires_grad=requires_grad)
register_offload_parameter(module, "weight", untied_param)
weight = module.weight
param = Parameter(weight.data.clone(), requires_grad=weight.requires_grad)
module.register_parameter("weight", param)

# modify model config
if hasattr(model.config, "tie_word_embeddings"):
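The simplified untie_word_embeddings no longer branches on offloaded parameters; it just clones the shared weight and re-registers it on both modules. A self-contained sketch of that untying pattern on a toy tied model (TinyTiedLM is invented for illustration):

import torch
from torch.nn import Embedding, Linear, Parameter


class TinyTiedLM(torch.nn.Module):
    # toy model whose input embedding and output head share one weight
    def __init__(self, vocab: int = 16, dim: int = 8):
        super().__init__()
        self.embed = Embedding(vocab, dim)
        self.head = Linear(dim, vocab, bias=False)
        self.head.weight = self.embed.weight  # tie

    def forward(self, ids):
        return self.head(self.embed(ids))


model = TinyTiedLM()
assert model.head.weight is model.embed.weight  # tied

# clone the shared data and register fresh Parameters, as in the diff above
for module in (model.embed, model.head):
    weight = module.weight
    param = Parameter(weight.data.clone(), requires_grad=weight.requires_grad)
    module.register_parameter("weight", param)

assert model.head.weight is not model.embed.weight  # untied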
16 changes: 7 additions & 9 deletions tests/llmcompressor/utils/test_helpers.py
@@ -1,11 +1,11 @@
import pytest
import torch
from compressed_tensors.offload import dispatch_model
from transformers import (
AutoModelForCausalLM,
MllamaForConditionalGeneration,
)

from llmcompressor.pipelines.sequential.helpers import dispatch_for_sequential
from llmcompressor.utils import (
ALL_TOKEN,
DisableQuantization,
@@ -17,7 +17,7 @@
interpolate,
validate_str_iterable,
)
from llmcompressor.utils.dev import dispatch_for_generation, skip_weights_download
from llmcompressor.utils.dev import skip_weights_download
from tests.testing_utils import requires_gpu


@@ -153,14 +153,12 @@ def test_disable_cache(model_cls, model_stub):


@requires_gpu
@pytest.mark.parametrize("offload", ["sequential", "basic", "none"])
def test_disable_lm_head(offload):
@pytest.mark.parametrize("dispatch", (True, False))
def test_disable_lm_head(dispatch):
model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2")
if offload == "sequential":
dispatch_for_sequential(model)
if offload == "basic":
dispatch_for_generation(model)
if offload == "none":
if dispatch:
dispatch_model(model, "cuda")
else:
model = model.to("cuda")

lm_input_device = None