Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions keras/src/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,13 @@ def dtype_policy(self, value):
self._dtype_policy = policy
if policy.quantization_mode is not None:
if self.built and not getattr(self, "_is_quantized", False):
if policy.quantization_mode == "gptq":
raise ValueError(
f"{value=} enables GPTQ quantization mode."
"This is unsupported since GPTQ requires "
"a calibration dataset and a GPTQConfig."
"Use the `.quantize()` method instead."
)
self.quantize(policy.quantization_mode)

@property
Expand Down
10 changes: 10 additions & 0 deletions keras/src/layers/layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,16 @@ def test_quantized_layer_with_remat(self):
self.assertLen(mock_remat.rematted_functions, 1)
next(iter(mock_remat.rematted_functions.values())).assert_called()

def test_gptq_quantization_by_setting_dtype(self):
"""Tests error being raised when dtype is set to GPTQ."""
with self.assertRaisesRegex(
ValueError,
"enables GPTQ quantization mode.This is unsupported",
):
layer = layers.Dense(3)
layer.build((2, 4))
layer.dtype_policy = "gptq/4/-1_from_float32"

def test_functional_model_with_remat(self):
if backend.backend() in ("openvino", "numpy"):
self.skipTest(
Expand Down