From 00bad8242194b9284676b0a4fbbcc2cd5fe83759 Mon Sep 17 00:00:00 2001 From: Brandon Groth Date: Thu, 26 Jun 2025 20:03:26 -0400 Subject: [PATCH 1/2] fix: Fixed recipe being overwritten in qconfig_init Signed-off-by: Brandon Groth --- fms_mo/utils/qconfig_utils.py | 57 ++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/fms_mo/utils/qconfig_utils.py b/fms_mo/utils/qconfig_utils.py index e2c13355..c8e9a093 100644 --- a/fms_mo/utils/qconfig_utils.py +++ b/fms_mo/utils/qconfig_utils.py @@ -713,36 +713,43 @@ def remove_unwanted_from_config( return config, dump -def get_unwanted_defaults() -> dict: +def get_unserializable_defaults() -> dict: """Add back those unserializable items if needed""" - unwanted_items = [ - ("sweep_cv_percentile", False), - ("tb_writer", None), - ( - "mapping", - { - nn.Conv2d: QConv2d, - nn.ConvTranspose2d: QConvTranspose2d, - nn.Linear: QLinear, - nn.LSTM: QLSTM, - "matmul_or_bmm": QBmm, - }, - ), - ("checkQerr_frequency", False), - ("newlySwappedModules", []), - ("force_calib_once", False), + unserializable_items = { + "sweep_cv_percentile": False, + "tb_writer": None, + "mapping": { + nn.Conv2d: QConv2d, + nn.ConvTranspose2d: QConvTranspose2d, + nn.Linear: QLinear, + nn.LSTM: QLSTM, + "matmul_or_bmm": QBmm, + }, + "checkQerr_frequency": False, + "newlySwappedModules": [], + "force_calib_once": False, # if we keep the follwing LUTs, it will save the entire model - ("LUTmodule_name", {}), - ] - return unwanted_items + "LUTmodule_name": {}, + } + return unserializable_items + + +def add_if_not_present(config: dict, items_to_add: dict) -> None: + """ + Add items to config dict only if they aren't present + + Args: + config (dict): Quantized config + items_to_add (dict): Items that will be added if not present in config + """ + for key, val in items_to_add.items(): + if key not in config: + config[key] = val def add_required_defaults_to_config(config: dict) -> None: """Recover "unserializable" items that are previously removed from config""" - unwanted_items = get_unwanted_defaults() - for key, default_val in unwanted_items: - if key not in config: - config[key] = default_val + add_if_not_present(config, get_unserializable_defaults()) def add_wanted_defaults_to_config(config: dict, minimal: bool = True) -> None: @@ -750,7 +757,7 @@ def add_wanted_defaults_to_config(config: dict, minimal: bool = True) -> None: if a wanted item is not in the config, add it w/ default value """ if not minimal: - config.update(config_defaults()) + add_if_not_present(config, config_defaults()) def qconfig_save( From a6bd15aace8338ebf03e53343941b24e7a3d16bf Mon Sep 17 00:00:00 2001 From: Brandon Groth Date: Thu, 26 Jun 2025 20:05:32 -0400 Subject: [PATCH 2/2] test: Added qcfg recipe test to for qconfig_save + qconfig_init Signed-off-by: Brandon Groth --- tests/models/test_saveconfig.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/models/test_saveconfig.py b/tests/models/test_saveconfig.py index 6596b2bd..1fac4ce3 100644 --- a/tests/models/test_saveconfig.py +++ b/tests/models/test_saveconfig.py @@ -20,6 +20,7 @@ import pytest # Local +from fms_mo import qconfig_init from fms_mo.utils.qconfig_utils import qconfig_load, qconfig_save from tests.models.test_model_utils import ( delete_file, @@ -298,3 +299,31 @@ def test_load_config_required_pair( loaded_config = qconfig_load("qcfg.json") assert loaded_config.get(key) == default_val + + +def test_save_init_recipe( + config_int8: dict, +): + """ + Change a config, save it, + + Args: + config_fp32 (dict): Config for fp32 quantization + """ + # Change some elements of config to ensure its being saved/loaded properly + config_int8["qa_mode"] = "minmax" + config_int8["qa_mode"] = "pertokenmax" + config_int8["qmodel_calibration"] = 17 + config_int8["qskip_layer_name"] = ["lm_head"] + + qconfig_save(config_int8) + recipe_config = qconfig_init(recipe="qcfg.json") + + # Remove date field from recipe_config - only added at save + del recipe_config["date"] + + assert len(recipe_config) == len(config_int8) + + for key, val in config_int8.items(): + assert key in recipe_config + assert recipe_config.get(key) == val