From c61dd519181d65818af47e2dd7b7c62445f170c9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sun, 14 Dec 2025 22:45:08 +0000 Subject: [PATCH 1/5] minor changes Signed-off-by: Kyle Sayers --- src/llmcompressor/datasets/utils.py | 4 +-- .../pipelines/sequential/ast_helpers.py | 35 ++++++++++--------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/llmcompressor/datasets/utils.py b/src/llmcompressor/datasets/utils.py index 0d5fceca8..511e0c5ef 100644 --- a/src/llmcompressor/datasets/utils.py +++ b/src/llmcompressor/datasets/utils.py @@ -238,8 +238,8 @@ def _make_sampler(args: DatasetArguments, dataset: Dataset) -> Sampler: def data_collator_with_truncation( - features: list[dict[str, Any]], return_tensors: str = "pt" -) -> dict[str, Any]: + features: list[dict], return_tensors: str = "pt" +) -> dict: for key in ("input_ids", "labels", "attention_mask"): if any(key not in feature for feature in features): continue diff --git a/src/llmcompressor/pipelines/sequential/ast_helpers.py b/src/llmcompressor/pipelines/sequential/ast_helpers.py index 65e8d9c61..a27fe4d6a 100644 --- a/src/llmcompressor/pipelines/sequential/ast_helpers.py +++ b/src/llmcompressor/pipelines/sequential/ast_helpers.py @@ -64,20 +64,21 @@ def autowrap_forward(module: torch.nn.Module, ignore: List[str]): # compile new forward function from autowrapped code filename = f"" code = compile(source, filename=filename, mode="exec") - exec(code, namespace) # ensure ns of functions is the same ns as torch.fx.wrap - - # enable better tracebacks if autowrapped code fails - linecache.cache[filename] = ( - len(source), - None, - [line + "\n" for line in source.splitlines()], - filename, - ) - - # patch forward with autowrapped forward - new_forward = namespace["forward"].__get__(module) - with patch_attr(module, "forward", new_forward): - yield + with append_autowrap_source_on_fail(): + exec(code, namespace) # ensure ns of functions is the same ns as torch.fx.wrap + + # enable better tracebacks if autowrapped code fails + linecache.cache[filename] = ( + len(source), + None, + [line + "\n" for line in source.splitlines()], + filename, + ) + + # patch forward with autowrapped forward + new_forward = namespace["forward"].__get__(module) + with patch_attr(module, "forward", new_forward): + yield @contextlib.contextmanager @@ -99,9 +100,9 @@ def append_autowrap_source_on_fail(): for i, line in enumerate(source_lines) ] - message = f"{exception}\n\n" - message += f"\n--- {frame.filename}:{lineno} ---\n" + message = f"--- {frame.filename}:{lineno} ---\n" message += "".join(source_lines) - raise RuntimeError(message) from exception + message += f"\n\n{exception}" + raise RuntimeError(message) raise exception From 1df774d4fa0cc5506485d6ff242e373e93bd6ea3 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sun, 14 Dec 2025 22:56:26 +0000 Subject: [PATCH 2/5] add check for forward implementation Signed-off-by: Kyle Sayers --- .../pipelines/sequential/ast_helpers.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/ast_helpers.py b/src/llmcompressor/pipelines/sequential/ast_helpers.py index a27fe4d6a..cf23ad229 100644 --- a/src/llmcompressor/pipelines/sequential/ast_helpers.py +++ b/src/llmcompressor/pipelines/sequential/ast_helpers.py @@ -45,6 +45,14 @@ def autowrap_forward(module: torch.nn.Module, ignore: List[str]): :param module: module whose forward method should be replaced :param ignore: explicit list of function names to wrap """ + # check forward method is implemented + if module.forward.__name__ == "_forward_unimplemented": + raise ValueError( + "Cannot calibrate model which does not implement `forward` method. Please " + "either implement a forward method on the model, or pass a submodule to " + "`oneshot`. For example, `oneshot(model.thinker, ...)`" + ) + # get source code of module forward source = inspect.getsource(module.forward) source = textwrap.dedent(source) @@ -75,10 +83,10 @@ def autowrap_forward(module: torch.nn.Module, ignore: List[str]): filename, ) - # patch forward with autowrapped forward - new_forward = namespace["forward"].__get__(module) - with patch_attr(module, "forward", new_forward): - yield + # patch forward with autowrapped forward + new_forward = namespace["forward"].__get__(module) + with patch_attr(module, "forward", new_forward): + yield @contextlib.contextmanager @@ -103,6 +111,6 @@ def append_autowrap_source_on_fail(): message = f"--- {frame.filename}:{lineno} ---\n" message += "".join(source_lines) message += f"\n\n{exception}" - raise RuntimeError(message) + raise RuntimeError(message) from exc_tb raise exception From f70647f4c93fea10c70b79eb6def42f37cd492f2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sun, 14 Dec 2025 18:07:17 -0500 Subject: [PATCH 3/5] Update src/llmcompressor/pipelines/sequential/ast_helpers.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/sequential/ast_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/pipelines/sequential/ast_helpers.py b/src/llmcompressor/pipelines/sequential/ast_helpers.py index cf23ad229..b906b027f 100644 --- a/src/llmcompressor/pipelines/sequential/ast_helpers.py +++ b/src/llmcompressor/pipelines/sequential/ast_helpers.py @@ -111,6 +111,6 @@ def append_autowrap_source_on_fail(): message = f"--- {frame.filename}:{lineno} ---\n" message += "".join(source_lines) message += f"\n\n{exception}" - raise RuntimeError(message) from exc_tb + raise RuntimeError(message) from exception raise exception From a07c088c904e8ef5aa339289442a8cfbe33056a5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sun, 14 Dec 2025 18:07:44 -0500 Subject: [PATCH 4/5] Update src/llmcompressor/datasets/utils.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Kyle Sayers --- src/llmcompressor/datasets/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/datasets/utils.py b/src/llmcompressor/datasets/utils.py index 511e0c5ef..0d5fceca8 100644 --- a/src/llmcompressor/datasets/utils.py +++ b/src/llmcompressor/datasets/utils.py @@ -238,8 +238,8 @@ def _make_sampler(args: DatasetArguments, dataset: Dataset) -> Sampler: def data_collator_with_truncation( - features: list[dict], return_tensors: str = "pt" -) -> dict: + features: list[dict[str, Any]], return_tensors: str = "pt" +) -> dict[str, Any]: for key in ("input_ids", "labels", "attention_mask"): if any(key not in feature for feature in features): continue From c9cad383780d25665f7de2f0507a820ca40a5161 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sun, 14 Dec 2025 23:08:38 +0000 Subject: [PATCH 5/5] reduce diff Signed-off-by: Kyle Sayers --- .../pipelines/sequential/ast_helpers.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/ast_helpers.py b/src/llmcompressor/pipelines/sequential/ast_helpers.py index b906b027f..81b33fc0d 100644 --- a/src/llmcompressor/pipelines/sequential/ast_helpers.py +++ b/src/llmcompressor/pipelines/sequential/ast_helpers.py @@ -75,13 +75,13 @@ def autowrap_forward(module: torch.nn.Module, ignore: List[str]): with append_autowrap_source_on_fail(): exec(code, namespace) # ensure ns of functions is the same ns as torch.fx.wrap - # enable better tracebacks if autowrapped code fails - linecache.cache[filename] = ( - len(source), - None, - [line + "\n" for line in source.splitlines()], - filename, - ) + # enable better tracebacks if autowrapped code fails + linecache.cache[filename] = ( + len(source), + None, + [line + "\n" for line in source.splitlines()], + filename, + ) # patch forward with autowrapped forward new_forward = namespace["forward"].__get__(module)