Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions keras/src/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,15 @@ def dtype_policy(self, value):
self._dtype_policy = policy
if policy.quantization_mode is not None:
if self.built and not getattr(self, "_is_quantized", False):
if policy.quantization_mode == "gptq":
raise ValueError(
"Implicitly enabling GPTQ quantization by setting "
f"`dtype_policy` to '{value}' is not supported. "
"GPTQ requires a calibration dataset and a "
"`GPTQConfig` object.\n\n"
"Please use the `.quantize('gptq', config=...)` method "
"on the layer or model instead."
)
self.quantize(policy.quantization_mode)

@property
Expand Down
10 changes: 10 additions & 0 deletions keras/src/layers/layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,16 @@ def test_quantized_layer_with_remat(self):
self.assertLen(mock_remat.rematted_functions, 1)
next(iter(mock_remat.rematted_functions.values())).assert_called()

def test_gptq_quantization_by_setting_dtype(self):
"""Tests error being raised when dtype is set to GPTQ."""
with self.assertRaisesRegex(
ValueError,
"Implicitly enabling GPTQ quantization.*is not supported",
):
layer = layers.Dense(3)
layer.build((2, 4))
layer.dtype_policy = "gptq/4/-1_from_float32"

def test_functional_model_with_remat(self):
if backend.backend() in ("openvino", "numpy"):
self.skipTest(
Expand Down