Skip to content

Commit 11bf7ff

Browse files
committed
clean-up
1 parent 0819192 commit 11bf7ff

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

tests/e2e/e2e_utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1+
from typing import Callable
2+
13
import torch
24
import transformers
35
from datasets import load_dataset
46
from loguru import logger
5-
from transformers import AutoProcessor
7+
from transformers import AutoProcessor, DefaultDataCollator
68

79
from llmcompressor import oneshot
810
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
911
from tests.test_timer.timer_utils import log_time
1012
from tests.testing_utils import process_dataset
11-
from transformers import DefaultDataCollator
1213

1314

1415
def load_model(model: str, model_class: str, device_map: str | None = None):
@@ -35,9 +36,12 @@ def run_oneshot_for_e2e_testing(
3536
dataset_config: str,
3637
scheme: str,
3738
quant_type: str,
39+
shuffle_calibration_samples: bool = True,
40+
data_collator: str | Callable = DefaultDataCollator(),
3841
):
3942
# Load model.
4043
oneshot_kwargs = {}
44+
oneshot_kwargs["data_collator"] = data_collator
4145

4246
loaded_model = load_model(model=model, model_class=model_class)
4347
processor = AutoProcessor.from_pretrained(model)
@@ -75,6 +79,7 @@ def data_collator(batch):
7579
oneshot_kwargs["data_collator"] = data_collator
7680

7781
oneshot_kwargs["model"] = loaded_model
82+
oneshot_kwargs["shuffle_calibration_samples"] = shuffle_calibration_samples
7883
if recipe:
7984
oneshot_kwargs["recipe"] = recipe
8085
else:
@@ -95,11 +100,8 @@ def data_collator(batch):
95100
)
96101

97102
# Apply quantization.
98-
103+
breakpoint()
99104
logger.info("ONESHOT KWARGS", oneshot_kwargs)
100-
101-
oneshot_kwargs["shuffle_calibration_samples"] = True
102-
oneshot_kwargs["data_collator"] = DefaultDataCollator()
103105
_run_oneshot(**oneshot_kwargs)
104106

105107
return oneshot_kwargs["model"], processor

0 commit comments

Comments
 (0)