1+ from typing import Callable
2+
13import torch
24import transformers
35from datasets import load_dataset
46from loguru import logger
5- from transformers import AutoProcessor
7+ from transformers import AutoProcessor , DefaultDataCollator
68
79from llmcompressor import oneshot
810from llmcompressor .modifiers .quantization import GPTQModifier , QuantizationModifier
911from tests .test_timer .timer_utils import log_time
1012from tests .testing_utils import process_dataset
11- from transformers import DefaultDataCollator
1213
1314
1415def load_model (model : str , model_class : str , device_map : str | None = None ):
@@ -35,9 +36,12 @@ def run_oneshot_for_e2e_testing(
3536 dataset_config : str ,
3637 scheme : str ,
3738 quant_type : str ,
39+ shuffle_calibration_samples : bool = True ,
40+ data_collator : str | Callable = DefaultDataCollator (),
3841):
3942 # Load model.
4043 oneshot_kwargs = {}
44+ oneshot_kwargs ["data_collator" ] = data_collator
4145
4246 loaded_model = load_model (model = model , model_class = model_class )
4347 processor = AutoProcessor .from_pretrained (model )
@@ -75,6 +79,7 @@ def data_collator(batch):
7579 oneshot_kwargs ["data_collator" ] = data_collator
7680
7781 oneshot_kwargs ["model" ] = loaded_model
82+ oneshot_kwargs ["shuffle_calibration_samples" ] = shuffle_calibration_samples
7883 if recipe :
7984 oneshot_kwargs ["recipe" ] = recipe
8085 else :
@@ -95,11 +100,8 @@ def data_collator(batch):
95100 )
96101
97102 # Apply quantization.
98-
103+ breakpoint ()
99104 logger .info ("ONESHOT KWARGS" , oneshot_kwargs )
100-
101- oneshot_kwargs ["shuffle_calibration_samples" ] = True
102- oneshot_kwargs ["data_collator" ] = DefaultDataCollator ()
103105 _run_oneshot (** oneshot_kwargs )
104106
105107 return oneshot_kwargs ["model" ], processor
0 commit comments