Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion tests/e2e/e2e_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Callable

import torch
import transformers
from datasets import load_dataset
from loguru import logger
from transformers import AutoProcessor
from transformers import AutoProcessor, DefaultDataCollator

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
Expand Down Expand Up @@ -34,9 +36,12 @@ def run_oneshot_for_e2e_testing(
dataset_config: str,
scheme: str,
quant_type: str,
shuffle_calibration_samples: bool = True,
data_collator: str | Callable = DefaultDataCollator(),
):
# Load model.
oneshot_kwargs = {}
oneshot_kwargs["data_collator"] = data_collator

loaded_model = load_model(model=model, model_class=model_class)
processor = AutoProcessor.from_pretrained(model)
Expand Down Expand Up @@ -74,6 +79,7 @@ def data_collator(batch):
oneshot_kwargs["data_collator"] = data_collator

oneshot_kwargs["model"] = loaded_model
oneshot_kwargs["shuffle_calibration_samples"] = shuffle_calibration_samples
if recipe:
oneshot_kwargs["recipe"] = recipe
else:
Expand All @@ -94,6 +100,7 @@ def data_collator(batch):
)

# Apply quantization.
breakpoint()
logger.info("ONESHOT KWARGS", oneshot_kwargs)
_run_oneshot(**oneshot_kwargs)

Expand Down
Loading