Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion tests/e2e/e2e_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Callable

import torch
import transformers
from datasets import load_dataset
from loguru import logger
from transformers import AutoProcessor
from transformers import AutoProcessor, DefaultDataCollator

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
Expand Down Expand Up @@ -34,9 +36,12 @@ def run_oneshot_for_e2e_testing(
dataset_config: str,
scheme: str,
quant_type: str,
shuffle_calibration_samples: bool = True,
data_collator: str | Callable = DefaultDataCollator(),
):
# Load model.
oneshot_kwargs = {}
oneshot_kwargs["data_collator"] = data_collator

loaded_model = load_model(model=model, model_class=model_class)
processor = AutoProcessor.from_pretrained(model)
Expand Down Expand Up @@ -74,6 +79,7 @@ def data_collator(batch):
oneshot_kwargs["data_collator"] = data_collator

oneshot_kwargs["model"] = loaded_model
oneshot_kwargs["shuffle_calibration_samples"] = shuffle_calibration_samples
if recipe:
oneshot_kwargs["recipe"] = recipe
else:
Expand Down