diff --git a/src/llmcompressor/pipelines/sequential/ast_helpers.py b/src/llmcompressor/pipelines/sequential/ast_helpers.py
index 65e8d9c61..81b33fc0d 100644
--- a/src/llmcompressor/pipelines/sequential/ast_helpers.py
+++ b/src/llmcompressor/pipelines/sequential/ast_helpers.py
@@ -45,6 +45,14 @@ def autowrap_forward(module: torch.nn.Module, ignore: List[str]):
     :param module: module whose forward method should be replaced
     :param ignore: explicit list of function names to wrap
     """
+    # check forward method is implemented
+    if module.forward.__name__ == "_forward_unimplemented":
+        raise ValueError(
+            "Cannot calibrate model which does not implement `forward` method. Please "
+            "either implement a forward method on the model, or pass a submodule to "
+            "`oneshot`. For example, `oneshot(model.thinker, ...)`"
+        )
+
     # get source code of module forward
     source = inspect.getsource(module.forward)
     source = textwrap.dedent(source)
@@ -64,7 +72,8 @@ def autowrap_forward(module: torch.nn.Module, ignore: List[str]):
     # compile new forward function from autowrapped code
     filename = f""
     code = compile(source, filename=filename, mode="exec")
-    exec(code, namespace)  # ensure ns of functions is the same ns as torch.fx.wrap
+    with append_autowrap_source_on_fail():
+        exec(code, namespace)  # ensure ns of functions is the same ns as torch.fx.wrap
 
     # enable better tracebacks if autowrapped code fails
     linecache.cache[filename] = (
@@ -99,9 +108,9 @@ def append_autowrap_source_on_fail():
             for i, line in enumerate(source_lines)
         ]
 
-        message = f"{exception}\n\n"
-        message += f"\n--- {frame.filename}:{lineno} ---\n"
+        message = f"--- {frame.filename}:{lineno} ---\n"
         message += "".join(source_lines)
+        message += f"\n\n{exception}"
         raise RuntimeError(message) from exception
 
     raise exception
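
For illustration, here is a minimal, self-contained sketch of the pattern the second and third hunks introduce: running dynamically compiled code inside a context manager that, on failure, re-raises with the generated source printed before the original exception (the new message ordering). The helper name `_append_source_on_fail` and the toy `forward` source below are hypothetical, not part of llmcompressor's API.

```python
# Minimal sketch (not the library implementation): run dynamically compiled
# code and, on failure, re-raise with the generated source shown first and
# the original exception appended last.
import contextlib


@contextlib.contextmanager
def _append_source_on_fail(filename: str, source: str):
    try:
        yield
    except Exception as exception:
        # number the generated source lines for readability
        numbered = "".join(
            f"{i + 1:4d} {line}\n" for i, line in enumerate(source.splitlines())
        )
        message = f"--- {filename} ---\n{numbered}\n\n{exception}"
        raise RuntimeError(message) from exception


source = "def forward(x):\n    return x.missing_attr\n"
namespace = {}
code = compile(source, filename="<generated_forward>", mode="exec")
with _append_source_on_fail("<generated_forward>", source):
    exec(code, namespace)           # defines forward() in namespace
    namespace["forward"](object())  # raises; the re-raised error includes the source
```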