diff --git a/tests/e2e/e2e_utils.py b/tests/e2e/e2e_utils.py index 765d864cc..05bc3b623 100644 --- a/tests/e2e/e2e_utils.py +++ b/tests/e2e/e2e_utils.py @@ -1,8 +1,10 @@ +from typing import Callable + import torch import transformers from datasets import load_dataset from loguru import logger -from transformers import AutoProcessor +from transformers import AutoProcessor, DefaultDataCollator from llmcompressor import oneshot from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier @@ -34,9 +36,12 @@ def run_oneshot_for_e2e_testing( dataset_config: str, scheme: str, quant_type: str, + shuffle_calibration_samples: bool = True, + data_collator: str | Callable = DefaultDataCollator(), ): # Load model. oneshot_kwargs = {} + oneshot_kwargs["data_collator"] = data_collator loaded_model = load_model(model=model, model_class=model_class) processor = AutoProcessor.from_pretrained(model) @@ -74,6 +79,7 @@ def data_collator(batch): oneshot_kwargs["data_collator"] = data_collator oneshot_kwargs["model"] = loaded_model + oneshot_kwargs["shuffle_calibration_samples"] = shuffle_calibration_samples if recipe: oneshot_kwargs["recipe"] = recipe else: