From 5fb142417d907a7fb9bcb9879a25da1f1c0582f0 Mon Sep 17 00:00:00 2001
From: Yuxuan Jiang
Date: Sun, 4 Jan 2026 00:25:58 -0500
Subject: [PATCH 01/15] [WIP] refactor: move proxy_wrapper into the instrumentor

---
 .github/workflows/eval-overhead-e2e.yml          |  2 --
 traincheck/collect_trace.py                      | 15 +--------------
 traincheck/instrumentor/dumper.py                |  6 +++---
 .../proxy_wrapper/Changelog.md                   |  2 +-
 .../{ => instrumentor}/proxy_wrapper/README.md   |  0
 .../{ => instrumentor}/proxy_wrapper/__init__.py |  4 ++--
 .../{ => instrumentor}/proxy_wrapper/dumper.py   |  4 ++--
 .../{ => instrumentor}/proxy_wrapper/hash.py     |  0
 .../{ => instrumentor}/proxy_wrapper/proxy.py    |  6 +++---
 .../proxy_wrapper/proxy_basics.py                |  2 +-
 .../proxy_wrapper/proxy_config.py                |  0
 .../proxy_wrapper/proxy_handler.py               |  0
 .../proxy_wrapper/proxy_methods.py               |  0
 .../proxy_wrapper/proxy_observer.py              |  6 +++---
 .../proxy_wrapper/proxy_registry.py              |  0
 .../{ => instrumentor}/proxy_wrapper/subclass.py |  2 +-
 .../proxy_wrapper/torch_proxy.py                 |  2 +-
 .../{ => instrumentor}/proxy_wrapper/utils.py    |  2 +-
 traincheck/instrumentor/replace_functions.py     |  2 +-
 traincheck/instrumentor/source_file.py           | 12 ++++++------
 traincheck/instrumentor/tracer.py                | 16 ++++++++--------
 .../graph_generator/call_graph_parser.py         |  2 +-
 22 files changed, 35 insertions(+), 50 deletions(-)
 rename traincheck/{ => instrumentor}/proxy_wrapper/Changelog.md (99%)
 rename traincheck/{ => instrumentor}/proxy_wrapper/README.md (100%)
 rename traincheck/{ => instrumentor}/proxy_wrapper/__init__.py (63%)
 rename traincheck/{ => instrumentor}/proxy_wrapper/dumper.py (93%)
 rename traincheck/{ => instrumentor}/proxy_wrapper/hash.py (100%)
 rename traincheck/{ => instrumentor}/proxy_wrapper/proxy.py (98%)
 rename traincheck/{ => instrumentor}/proxy_wrapper/proxy_basics.py (98%)
 rename traincheck/{ => instrumentor}/proxy_wrapper/proxy_config.py (100%)
 rename traincheck/{ => instrumentor}/proxy_wrapper/proxy_handler.py (100%)
 rename traincheck/{ => instrumentor}/proxy_wrapper/proxy_methods.py (100%)
 rename traincheck/{ => instrumentor}/proxy_wrapper/proxy_observer.py (89%)
 rename traincheck/{ => instrumentor}/proxy_wrapper/proxy_registry.py (100%)
 rename traincheck/{ => instrumentor}/proxy_wrapper/subclass.py (98%)
 rename traincheck/{ => instrumentor}/proxy_wrapper/torch_proxy.py (92%)
 rename traincheck/{ => instrumentor}/proxy_wrapper/utils.py (52%)

diff --git a/.github/workflows/eval-overhead-e2e.yml b/.github/workflows/eval-overhead-e2e.yml
index 90f88b0f..acdfa12a 100644
--- a/.github/workflows/eval-overhead-e2e.yml
+++ b/.github/workflows/eval-overhead-e2e.yml
@@ -6,13 +6,11 @@ on:
     paths:
       - '.github/workflows/**'
      - 'traincheck/instrumentor/**'
-      - 'traincheck/proxy_wrapper/**'
       - 'traincheck/collect_trace.py'
   pull_request:
     paths:
       - '.github/workflows/**'
       - 'traincheck/instrumentor/**'
-      - 'traincheck/proxy_wrapper/**'
       - 'traincheck/collect_trace.py'
diff --git a/traincheck/collect_trace.py b/traincheck/collect_trace.py
index 854735f4..4f9eb7e6 100644
--- a/traincheck/collect_trace.py
+++ b/traincheck/collect_trace.py
@@ -7,7 +7,7 @@
 
 import traincheck.config.config as config
 import traincheck.instrumentor as instrumentor
-import traincheck.proxy_wrapper.proxy_config as proxy_config
+import traincheck.instrumentor.proxy_wrapper.proxy_config as proxy_config
 import traincheck.runner as runner
 from traincheck.config.config import InstrOpt
 from traincheck.invariant.base_cls import (
@@ -157,19 +157,6 @@ def get_model_tracker_instr_opts(invariants: list[Invariant]) -> str | None:
     return tracker_type
 
-def get_disable_proxy_dumping(invariants: list[Invariant]) -> bool:
-    """
-    Get disable proxy dumping options for checking
-
-    Always return True if an APIContain invariant requested proxy tracking
-
-    We cannot disable automatic variable dumping if only consistency relations but no APIContain
-    require variable states, as then no APIs will trigger state dumps.
-    However, the var tracker should be sampler if there's no APIContain anyway
-    """
-    return True
-
-
 def dump_env(args, output_dir: str):
     with open(os.path.join(output_dir, "env_dump.txt"), "w") as f:
         f.write("Arguments:\n")
diff --git a/traincheck/instrumentor/dumper.py b/traincheck/instrumentor/dumper.py
index 04935e8a..cbc745f5 100644
--- a/traincheck/instrumentor/dumper.py
+++ b/traincheck/instrumentor/dumper.py
@@ -17,9 +17,9 @@
 )
 
 # if torch.cuda.is_available():
-from traincheck.proxy_wrapper.hash import tensor_hash
-from traincheck.proxy_wrapper.proxy_basics import is_fake_tensor
-from traincheck.proxy_wrapper.proxy_config import (
+from traincheck.instrumentor.proxy_wrapper.hash import tensor_hash
+from traincheck.instrumentor.proxy_wrapper.proxy_basics import is_fake_tensor
+from traincheck.instrumentor.proxy_wrapper.proxy_config import (
     attribute_black_list,
     primitive_types,
     proxy_attribute,
diff --git a/traincheck/proxy_wrapper/Changelog.md b/traincheck/instrumentor/proxy_wrapper/Changelog.md
similarity index 99%
rename from traincheck/proxy_wrapper/Changelog.md
rename to traincheck/instrumentor/proxy_wrapper/Changelog.md
index ecc55135..a90152a1 100644
--- a/traincheck/proxy_wrapper/Changelog.md
+++ b/traincheck/instrumentor/proxy_wrapper/Changelog.md
@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning].
 
 ### Added
 
-- Maintain global registry to proxied objects (to access the vars, use `from traincheck.proxy_wrapper.proxy import get_registered_object`)
+- Maintain global registry to proxied objects (to access the vars, use `from traincheck.instrumentor.proxy_wrapper.proxy import get_registered_object`)
 - Bypass tensor stats/hash computation if it has already been calculated
 
 ### Fixed
diff --git a/traincheck/proxy_wrapper/README.md b/traincheck/instrumentor/proxy_wrapper/README.md
similarity index 100%
rename from traincheck/proxy_wrapper/README.md
rename to traincheck/instrumentor/proxy_wrapper/README.md
diff --git a/traincheck/proxy_wrapper/__init__.py b/traincheck/instrumentor/proxy_wrapper/__init__.py
similarity index 63%
rename from traincheck/proxy_wrapper/__init__.py
rename to traincheck/instrumentor/proxy_wrapper/__init__.py
index d3bf780a..de648b4d 100644
--- a/traincheck/proxy_wrapper/__init__.py
+++ b/traincheck/instrumentor/proxy_wrapper/__init__.py
@@ -1,4 +1,4 @@
 # This import is necessary to make the observer utility inside torch_proxy.py executed before the instrumented code. This would ensure the observer function is successfully registred before the instrumented code is executed.
-import traincheck.proxy_wrapper.proxy_config  # noqa
-import traincheck.proxy_wrapper.torch_proxy  # noqa
+import traincheck.instrumentor.proxy_wrapper.proxy_config  # noqa
+import traincheck.instrumentor.proxy_wrapper.torch_proxy  # noqa
diff --git a/traincheck/proxy_wrapper/dumper.py b/traincheck/instrumentor/proxy_wrapper/dumper.py
similarity index 93%
rename from traincheck/proxy_wrapper/dumper.py
rename to traincheck/instrumentor/proxy_wrapper/dumper.py
index cd7cf3d9..242a9ece 100644
--- a/traincheck/proxy_wrapper/dumper.py
+++ b/traincheck/instrumentor/proxy_wrapper/dumper.py
@@ -2,10 +2,10 @@
 from typing import Dict
 
 from traincheck.instrumentor.dumper import convert_var_to_dict
+from traincheck.instrumentor.proxy_wrapper.proxy_basics import is_proxied
+from traincheck.instrumentor.proxy_wrapper.proxy_config import primitive_types
 from traincheck.instrumentor.tracer import TraceLineType
 from traincheck.instrumentor.tracer import get_meta_vars as tracer_get_meta_vars
-from traincheck.proxy_wrapper.proxy_basics import is_proxied
-from traincheck.proxy_wrapper.proxy_config import primitive_types
 
 
 class Singleton(type):
diff --git a/traincheck/proxy_wrapper/hash.py b/traincheck/instrumentor/proxy_wrapper/hash.py
similarity index 100%
rename from traincheck/proxy_wrapper/hash.py
rename to traincheck/instrumentor/proxy_wrapper/hash.py
diff --git a/traincheck/proxy_wrapper/proxy.py b/traincheck/instrumentor/proxy_wrapper/proxy.py
similarity index 98%
rename from traincheck/proxy_wrapper/proxy.py
rename to traincheck/instrumentor/proxy_wrapper/proxy.py
index 68a839d9..f9deb9ee 100644
--- a/traincheck/proxy_wrapper/proxy.py
+++ b/traincheck/instrumentor/proxy_wrapper/proxy.py
@@ -9,9 +9,9 @@
 import torch
 
 import traincheck.config.config as general_config
-import traincheck.proxy_wrapper.proxy_config as proxy_config  # HACK: cannot directly import config variables as then they would be local variables
-import traincheck.proxy_wrapper.proxy_methods as proxy_methods
-from traincheck.proxy_wrapper.dumper import dump_attributes, get_meta_vars
+import traincheck.instrumentor.proxy_wrapper.proxy_config as proxy_config  # HACK: cannot directly import config variables as then they would be local variables
+import traincheck.instrumentor.proxy_wrapper.proxy_methods as proxy_methods
+from traincheck.instrumentor.proxy_wrapper.dumper import dump_attributes, get_meta_vars
 from traincheck.utils import get_timestamp_ns, typename
 
 from .dumper import json_dumper as dumper
diff --git a/traincheck/proxy_wrapper/proxy_basics.py b/traincheck/instrumentor/proxy_wrapper/proxy_basics.py
similarity index 98%
rename from traincheck/proxy_wrapper/proxy_basics.py
rename to traincheck/instrumentor/proxy_wrapper/proxy_basics.py
index dd3014bb..9fd299d4 100644
--- a/traincheck/proxy_wrapper/proxy_basics.py
+++ b/traincheck/instrumentor/proxy_wrapper/proxy_basics.py
@@ -123,7 +123,7 @@ def visit_FunctionDef(self, node):
         self.generic_visit(node)
         # Inject code right after the def statement
         inject_code = """
-from traincheck.proxy_wrapper.proxy_basics import type_handle_traincheck_proxy
+from traincheck.instrumentor.proxy_wrapper.proxy_basics import type_handle_traincheck_proxy
 """
         inject_node = ast.parse(inject_code).body
         node.body = inject_node + node.body
diff --git a/traincheck/proxy_wrapper/proxy_config.py b/traincheck/instrumentor/proxy_wrapper/proxy_config.py
similarity index 100%
rename from traincheck/proxy_wrapper/proxy_config.py
rename to traincheck/instrumentor/proxy_wrapper/proxy_config.py
diff --git a/traincheck/proxy_wrapper/proxy_handler.py b/traincheck/instrumentor/proxy_wrapper/proxy_handler.py
similarity index 100%
rename from traincheck/proxy_wrapper/proxy_handler.py
rename to traincheck/instrumentor/proxy_wrapper/proxy_handler.py
diff --git a/traincheck/proxy_wrapper/proxy_methods.py b/traincheck/instrumentor/proxy_wrapper/proxy_methods.py
similarity index 100%
rename from traincheck/proxy_wrapper/proxy_methods.py
rename to traincheck/instrumentor/proxy_wrapper/proxy_methods.py
diff --git a/traincheck/proxy_wrapper/proxy_observer.py b/traincheck/instrumentor/proxy_wrapper/proxy_observer.py
similarity index 89%
rename from traincheck/proxy_wrapper/proxy_observer.py
rename to traincheck/instrumentor/proxy_wrapper/proxy_observer.py
index 5316fed6..3dce191b 100644
--- a/traincheck/proxy_wrapper/proxy_observer.py
+++ b/traincheck/instrumentor/proxy_wrapper/proxy_observer.py
@@ -2,12 +2,12 @@
 import typing
 
 from traincheck.config.config import should_disable_proxy_dumping
-from traincheck.proxy_wrapper.subclass import ProxyParameter
+from traincheck.instrumentor.proxy_wrapper.subclass import ProxyParameter
 from traincheck.utils import typename
 
 if typing.TYPE_CHECKING:
-    from traincheck.proxy_wrapper.proxy import Proxy
-    from traincheck.proxy_wrapper.subclass import ProxyParameter
+    from traincheck.instrumentor.proxy_wrapper.proxy import Proxy
+    from traincheck.instrumentor.proxy_wrapper.subclass import ProxyParameter
 
 from .proxy_basics import is_proxied, is_proxyparameter, unproxy_func
 
diff --git a/traincheck/proxy_wrapper/proxy_registry.py b/traincheck/instrumentor/proxy_wrapper/proxy_registry.py
similarity index 100%
rename from traincheck/proxy_wrapper/proxy_registry.py
rename to traincheck/instrumentor/proxy_wrapper/proxy_registry.py
diff --git a/traincheck/proxy_wrapper/subclass.py b/traincheck/instrumentor/proxy_wrapper/subclass.py
similarity index 98%
rename from traincheck/proxy_wrapper/subclass.py
rename to traincheck/instrumentor/proxy_wrapper/subclass.py
index 19a0ebf0..4de68009 100644
--- a/traincheck/proxy_wrapper/subclass.py
+++ b/traincheck/instrumentor/proxy_wrapper/subclass.py
@@ -7,8 +7,8 @@
 
 from traincheck.config.config import should_disable_proxy_dumping
 from traincheck.instrumentor.dumper import dump_trace_VAR
+from traincheck.instrumentor.proxy_wrapper.dumper import dump_attributes, get_meta_vars
 from traincheck.instrumentor.tracer import TraceLineType
-from traincheck.proxy_wrapper.dumper import dump_attributes, get_meta_vars
 from traincheck.utils import get_timestamp_ns
 
 from .proxy_basics import is_fake_tensor
diff --git a/traincheck/proxy_wrapper/torch_proxy.py b/traincheck/instrumentor/proxy_wrapper/torch_proxy.py
similarity index 92%
rename from traincheck/proxy_wrapper/torch_proxy.py
rename to traincheck/instrumentor/proxy_wrapper/torch_proxy.py
index 09e14736..241c6800 100644
--- a/traincheck/proxy_wrapper/torch_proxy.py
+++ b/traincheck/instrumentor/proxy_wrapper/torch_proxy.py
@@ -6,7 +6,7 @@
     pass
 from torch._C._distributed_c10d import ProcessGroup
 
-from traincheck.proxy_wrapper.proxy_basics import unproxy_func
+from traincheck.instrumentor.proxy_wrapper.proxy_basics import unproxy_func
 
 #################################################
 ### Proxied Torch functions
diff --git a/traincheck/proxy_wrapper/utils.py b/traincheck/instrumentor/proxy_wrapper/utils.py
similarity index 52%
rename from traincheck/proxy_wrapper/utils.py
rename to traincheck/instrumentor/proxy_wrapper/utils.py
index e736d360..48f6cd8b 100644
--- a/traincheck/proxy_wrapper/utils.py
+++ b/traincheck/instrumentor/proxy_wrapper/utils.py
@@ -1,4 +1,4 @@
-from traincheck.proxy_wrapper.proxy_config import debug_mode
+from traincheck.instrumentor.proxy_wrapper.proxy_config import debug_mode
 
 
 def print_debug(message_func):
diff --git a/traincheck/instrumentor/replace_functions.py b/traincheck/instrumentor/replace_functions.py
index 2d343519..0fe4980b 100644
--- a/traincheck/instrumentor/replace_functions.py
+++ b/traincheck/instrumentor/replace_functions.py
@@ -2,7 +2,7 @@
 
 import torch.optim.optimizer as optimizer_
 
-from traincheck.proxy_wrapper.proxy_basics import adapt_func_for_proxy
+from traincheck.instrumentor.proxy_wrapper.proxy_basics import adapt_func_for_proxy
 from traincheck.utils import typename
 
 
diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py
index d0279fff..47c0ab4c 100644
--- a/traincheck/instrumentor/source_file.py
+++ b/traincheck/instrumentor/source_file.py
@@ -367,34 +367,34 @@ def instrument_model_tracker_proxy(
 
     if proxy_basic_config:
         if "proxy_log_dir" not in proxy_basic_config:
-            from traincheck.proxy_wrapper.proxy_config import proxy_log_dir
+            from traincheck.instrumentor.proxy_wrapper.proxy_config import proxy_log_dir
 
             proxy_basic_config["proxy_log_dir"] = proxy_log_dir
         proxy_start_code += f"""
-import traincheck.proxy_wrapper.proxy_config as proxy_config
+import traincheck.instrumentor.proxy_wrapper.proxy_config as proxy_config
 proxy_config.__dict__.update({proxy_basic_config})
 """
     if tensor_dump_format:
         proxy_start_code += f"""
-from traincheck.proxy_wrapper.proxy_config import tensor_dump_format
+from traincheck.instrumentor.proxy_wrapper.proxy_config import tensor_dump_format
 tensor_dump_format.update({tensor_dump_format})
 """
     if model_tracker_style == "proxy":
         proxy_start_code += """
-from traincheck.proxy_wrapper.proxy import Proxy
+from traincheck.instrumentor.proxy_wrapper.proxy import Proxy
 """
     else:
         proxy_start_code += """
-from traincheck.proxy_wrapper.subclass import proxy_parameter
+from traincheck.instrumentor.proxy_wrapper.subclass import proxy_parameter
 """
 
     if auto_observer_config["enable_auto_observer"]:
         auto_observer_code = """
 import glob
 import importlib
-from traincheck.proxy_wrapper.proxy_config import auto_observer_config
+from traincheck.instrumentor.proxy_wrapper.proxy_config import auto_observer_config
 spec = importlib.util.find_spec('traincheck')
 if spec and spec.origin:
     traincheck_folder = os.path.dirname(spec.origin)
diff --git a/traincheck/instrumentor/tracer.py b/traincheck/instrumentor/tracer.py
index feb63f4f..d518a4f8 100644
--- a/traincheck/instrumentor/tracer.py
+++ b/traincheck/instrumentor/tracer.py
@@ -25,17 +25,17 @@
     get_instrumentation_logger_for_process,
     var_to_serializable,
 )
-from traincheck.instrumentor.replace_functions import (
-    funcs_to_be_replaced,
-    is_funcs_to_be_unproxied,
-)
-from traincheck.proxy_wrapper.proxy_basics import (
+from traincheck.instrumentor.proxy_wrapper.proxy_basics import (
     is_proxied,
     is_proxyparameter,
     unproxy_func,
 )
-from traincheck.proxy_wrapper.proxy_config import enable_C_level_observer
-from traincheck.proxy_wrapper.proxy_registry import get_global_registry
+from traincheck.instrumentor.proxy_wrapper.proxy_config import enable_C_level_observer
+from traincheck.instrumentor.proxy_wrapper.proxy_registry import get_global_registry
+from traincheck.instrumentor.replace_functions import (
+    funcs_to_be_replaced,
+    is_funcs_to_be_unproxied,
+)
 from traincheck.utils import get_timestamp_ns, get_unique_id, typename
 
 _instancemethod_t = type(torch._C._distributed_c10d.ProcessGroup.broadcast)
@@ -464,7 +464,7 @@ def find_proxy_in_args(args):
 
     if handle_proxy:
         if enable_C_level_observer and is_builtin:
-            from traincheck.proxy_wrapper.proxy_observer import (
+            from traincheck.instrumentor.proxy_wrapper.proxy_observer import (
                 add_observer_to_func,  # import here to avoid circular import
             )
 
diff --git a/traincheck/static_analyzer/graph_generator/call_graph_parser.py b/traincheck/static_analyzer/graph_generator/call_graph_parser.py
index 1b66be2a..5ce9d258 100644
--- a/traincheck/static_analyzer/graph_generator/call_graph_parser.py
+++ b/traincheck/static_analyzer/graph_generator/call_graph_parser.py
@@ -4,7 +4,7 @@
 import os
 import re
 
-from traincheck.proxy_wrapper.proxy_observer import add_observer_to_func
+from traincheck.instrumentor.proxy_wrapper.proxy_observer import add_observer_to_func
 
 
 def unparse_module(module_name, level=0):

From 0c0ebf8cf5612fd73af556e008b26f4db6a3525f Mon Sep 17 00:00:00 2001
From: Yuxuan Jiang
Date: Sun, 4 Jan 2026 18:29:03 -0500
Subject: [PATCH 02/15] WIP: strengthen instrumentation logic

---
 docs/assets/code/mnist.py                        |   6 +-
 .../traincheck-collect/mnist-config/mnist.py     |   6 +-
 traincheck/developer/annotations.py              |   4 +-
 traincheck/instrumentor/VFProxy.py               |  60 ---
 traincheck/instrumentor/__init__.py              |   2 +-
 traincheck/instrumentor/caches.py                |   7 +-
 .../proxy_wrapper/proxy_basics.py                |  15 +-
 .../proxy_wrapper/proxy_observer.py              |   2 +-
 .../proxy_wrapper/proxy_registry.py              |  22 +-
 traincheck/instrumentor/proxy_wrapper/subclass.py |   6 +-
 traincheck/instrumentor/tracer.py                | 351 +++----------
 11 files changed, 87 insertions(+), 394 deletions(-)
 delete mode 100644 traincheck/instrumentor/VFProxy.py

diff --git a/docs/assets/code/mnist.py b/docs/assets/code/mnist.py
index 58c0254a..6952ed8a 100644
--- a/docs/assets/code/mnist.py
+++ b/docs/assets/code/mnist.py
@@ -8,9 +8,9 @@
 from torchvision import datasets, transforms
 
 from traincheck import annotate_stage
-from traincheck.instrumentor import meta_vars
+from traincheck.instrumentor import META_VARS
 
-meta_vars["step"] = -1
+META_VARS["step"] = -1
 
 
 class Net(nn.Module):
@@ -43,7 +43,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
     annotate_stage("training")  # ML_DAIKON: stage annotation
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
-        meta_vars["step"] += 1
+        META_VARS["step"] += 1
         data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
         output = model(data)
diff --git a/docs/assets/examples/traincheck-collect/mnist-config/mnist.py b/docs/assets/examples/traincheck-collect/mnist-config/mnist.py
index 58c0254a..6952ed8a 100644
--- a/docs/assets/examples/traincheck-collect/mnist-config/mnist.py
+++ b/docs/assets/examples/traincheck-collect/mnist-config/mnist.py
@@ -8,9 +8,9 @@
 from torchvision import datasets, transforms
 
 from traincheck import annotate_stage
-from traincheck.instrumentor import meta_vars
+from traincheck.instrumentor import META_VARS
 
-meta_vars["step"] = -1
+META_VARS["step"] = -1
 
 
 class Net(nn.Module):
@@ -43,7 +43,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
     annotate_stage("training")  # ML_DAIKON: stage annotation
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
-        meta_vars["step"] += 1
+        META_VARS["step"] += 1
         data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
         output = model(data)
diff --git a/traincheck/developer/annotations.py b/traincheck/developer/annotations.py
index b89b1827..ccc6b0a5 100644
--- a/traincheck/developer/annotations.py
+++ b/traincheck/developer/annotations.py
@@ -1,6 +1,6 @@
 import traincheck.instrumentor.tracer as tracer
 from traincheck.config.config import ALL_STAGE_NAMES
-from traincheck.instrumentor import meta_vars
+from traincheck.instrumentor import META_VARS
 
 
 def annotate_stage(stage_name: str):
@@ -16,7 +16,7 @@ def annotate_stage(stage_name: str):
         stage_name in ALL_STAGE_NAMES
     ), f"Invalid stage name: {stage_name}, valid ones are {ALL_STAGE_NAMES}"
 
-    meta_vars["stage"] = stage_name
+    META_VARS["stage"] = stage_name
 
 
 def annotate_answer_start_token_ids(
diff --git a/traincheck/instrumentor/VFProxy.py b/traincheck/instrumentor/VFProxy.py
deleted file mode 100644
index aeb9bb72..00000000
--- a/traincheck/instrumentor/VFProxy.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import functools
-
-
-def is_proxied(obj):
-    try:
-        if obj is not None and "is_traincheck_proxied_obj" in obj.__dict__:
-            return True
-    except Exception:
-        return False
-    return False
-
-
-def unproxy_arg(arg, inspect_torch_module=False):
-
-    if is_proxied(arg):
-        return unproxy_arg(arg._obj, inspect_torch_module)
-    elif type(arg) in [list]:
-        return [unproxy_arg(element, inspect_torch_module) for element in arg]
-    elif type(arg) in [tuple]:
-        return tuple(unproxy_arg(element, inspect_torch_module) for element in arg)
-    # if it is a torch module, unproxy all its named children
-    elif inspect_torch_module:
-        import torch
-
-        if isinstance(arg, torch.nn.Module):
-            for name, module in arg.named_children():
-                arg._modules[name] = unproxy_arg(module, inspect_torch_module)
-            # handle named_parameters
-            for name, param in arg.named_parameters():
-                arg._parameters[name] = unproxy_arg(param, inspect_torch_module)
-            return arg
-
-        return arg
-    else:
-        return arg
-
-
-# Proxy class to wrap the torch._VF module
-class VFProxy:
-    def __init__(self, vf_module):
-        self._vf_module = vf_module
-
-    def __getattr__(self, name):
-        attr = getattr(self._vf_module, name)
-
-        if callable(attr):
-            return self.unproxy_func(attr)
-        else:
-            return attr
-
-    def unproxy_func(self, func, inspect_torch_module=False):
-        original_func = func
-
-        @functools.wraps(original_func)
-        def wrapper(*args, **kwargs):
-            args = [unproxy_arg(arg, inspect_torch_module) for arg in args]
-            kwargs = {k: unproxy_arg(v) for k, v in kwargs.items()}
-            return original_func(*args, **kwargs)
-
-        return wrapper
diff --git a/traincheck/instrumentor/__init__.py b/traincheck/instrumentor/__init__.py
index 7e6012e8..fb5ae22c 100644
--- a/traincheck/instrumentor/__init__.py
+++ b/traincheck/instrumentor/__init__.py
@@ -1,4 +1,4 @@
-from .caches import meta_vars  # noqa: F401
+from .caches import META_VARS  # noqa: F401
 from .source_file import *  # noqa: F403
 from .tracer import *  # noqa: F403
 from .tracer import VarSampler  # noqa: F401
diff --git a/traincheck/instrumentor/caches.py b/traincheck/instrumentor/caches.py
index cf07b195..288525c0 100644
--- a/traincheck/instrumentor/caches.py
+++ b/traincheck/instrumentor/caches.py
@@ -1,8 +1,3 @@
-from collections import defaultdict
-
-from traincheck.instrumentor.types import PTID
-
-cache_meta_vars: dict[PTID, dict[str, dict]] = defaultdict(lambda: defaultdict(dict))
-meta_vars: dict[str, object] = {
+META_VARS: dict[str, object] = {
     "step": 0,
 }
diff --git a/traincheck/instrumentor/proxy_wrapper/proxy_basics.py b/traincheck/instrumentor/proxy_wrapper/proxy_basics.py
index 9fd299d4..a1cc1aab 100644
--- a/traincheck/instrumentor/proxy_wrapper/proxy_basics.py
+++ b/traincheck/instrumentor/proxy_wrapper/proxy_basics.py
@@ -93,24 +93,23 @@ def unproxy_arg(arg, inspect_torch_module=False):
     return arg
 
 
+def unproxy_args_kwargs(args, kwargs, inspect_torch_module=False):
+    args = [unproxy_arg(arg, inspect_torch_module) for arg in args]
+    kwargs = {k: unproxy_arg(v) for k, v in kwargs.items()}
+    return args, kwargs
+
+
 def unproxy_func(func, inspect_torch_module=False):
     original_func = func
 
     @functools.wraps(original_func)
     def wrapper(*args, **kwargs):
-        args = [unproxy_arg(arg, inspect_torch_module) for arg in args]
-        kwargs = {k: unproxy_arg(v) for k, v in kwargs.items()}
+        args, kwargs = unproxy_args_kwargs(args, kwargs, inspect_torch_module)
        return original_func(*args, **kwargs)
 
     return wrapper
 
 
-def unproxy_args_kwargs(args, kwargs, inspect_torch_module=False):
-    args = [unproxy_arg(arg, inspect_torch_module) for arg in args]
-    kwargs = {k: unproxy_arg(v) for k, v in kwargs.items()}
-    return args, kwargs
-
-
 def type_handle_traincheck_proxy(x):
     if hasattr(x, "is_traincheck_proxied_obj"):
         return type(x._obj)
diff --git a/traincheck/instrumentor/proxy_wrapper/proxy_observer.py b/traincheck/instrumentor/proxy_wrapper/proxy_observer.py
index 3dce191b..a333000a 100644
--- a/traincheck/instrumentor/proxy_wrapper/proxy_observer.py
+++ b/traincheck/instrumentor/proxy_wrapper/proxy_observer.py
@@ -61,7 +61,7 @@ def wrapper(*args, **kwargs):
         result = processed_function(*args, **kwargs)
 
         # post observe
-        for i, var in enumerate(proxied_vars):
+        for var in proxied_vars:
             observe_proxy_var(
                 var,
                 "post_observe",
diff --git a/traincheck/instrumentor/proxy_wrapper/proxy_registry.py b/traincheck/instrumentor/proxy_wrapper/proxy_registry.py
index 83954cac..b8fbde11 100644
--- a/traincheck/instrumentor/proxy_wrapper/proxy_registry.py
+++ b/traincheck/instrumentor/proxy_wrapper/proxy_registry.py
@@ -15,11 +15,15 @@ def __init__(self, proxy: "Proxy", stale: bool):
         self.stale = stale
 
 
-class ProxyRegistry:
-    """A helper class managing all proxy variables being tracked and allow for controlled dumps of
+class VarRegistry:
+    """A helper class managing all variables being tracked and allows for controlled dumps of
     the variable states.
 
     A variable is uniquely identified by its "name"
 
+    When a variable is added to the registry, it is marked as "not stale".
+    When a variable is dumped through `dump_sample` or `dump_modified`, it is marked as "stale".
+    A variable is only dumped through `dump_modified` if it is not stale.
+
     """
 
     def __init__(self):
@@ -29,20 +33,24 @@ def __init__(self):
     def add_var(self, var: "Proxy", var_name: str):
         """Add a new proxy variable to the registry"""
         with self.registry_lock:
-            self.registry[var_name] = RegistryEntry(proxy=var, stale=False)
+            if var_name in self.registry:
+                self.registry[var_name].proxy = var
+                self.registry[var_name].stale = False
+            else:
+                self.registry[var_name] = RegistryEntry(proxy=var, stale=False)
 
     def dump_sample(self, dump_loc=None):
         """A complete dump of all present proxy objects
 
         Calling this API mark all proxy objects as stale which
-        will affect the `dump_only_modified` API.
+        will affect the `dump_modified` API.
""" with self.registry_lock: - for var_name, entry in self.registry.items(): + for _, entry in self.registry.items(): entry.stale = True entry.proxy.dump_trace(phase="sample", dump_loc=dump_loc) - def dump_only_modified(self, dump_loc=None, dump_config=None): + def dump_modified(self, dump_loc=None, dump_config=None): """Dump only the proxy variables that might be modified since last dump args: @@ -81,7 +89,7 @@ def dump_only_modified(self, dump_loc=None, dump_config=None): # Global dictionary to store registered objects -global_registry = ProxyRegistry() +global_registry = VarRegistry() def get_global_registry(): diff --git a/traincheck/instrumentor/proxy_wrapper/subclass.py b/traincheck/instrumentor/proxy_wrapper/subclass.py index 4de68009..c9fa3973 100644 --- a/traincheck/instrumentor/proxy_wrapper/subclass.py +++ b/traincheck/instrumentor/proxy_wrapper/subclass.py @@ -12,8 +12,8 @@ from traincheck.utils import get_timestamp_ns from .proxy_basics import is_fake_tensor +from .proxy_registry import get_global_registry -# from .proxy_registry import get_global_registry # from .utils import print_debug @@ -165,9 +165,9 @@ def update_timestamp(self): # Proxy.var_dict[self.__dict__["var_name"]].last_update_timestamp = current_time def register_object(self): - # get_global_registry().add_var(self, self.__dict__["var_name"]) + get_global_registry().add_var(self, self.__dict__["var_name"]) # TODO: implement the registry, we will need to make sure the registerred timestamp is updated and is consistent with the timestamp in the object - pass + # pass def dump_trace(self, phase, dump_loc): # print(f"parameter: {self.var_name}, phase = {phase}, dump_loc = {dump_loc}") diff --git a/traincheck/instrumentor/tracer.py b/traincheck/instrumentor/tracer.py index d518a4f8..76fc7db1 100644 --- a/traincheck/instrumentor/tracer.py +++ b/traincheck/instrumentor/tracer.py @@ -17,7 +17,7 @@ WRAP_WITHOUT_DUMP, WRAP_WITHOUT_DUMP_WHITELIST, ) -from traincheck.instrumentor.caches import meta_vars +from traincheck.instrumentor.caches import META_VARS from traincheck.instrumentor.dumper import ( convert_var_to_dict, dump_trace_API, @@ -28,7 +28,7 @@ from traincheck.instrumentor.proxy_wrapper.proxy_basics import ( is_proxied, is_proxyparameter, - unproxy_func, + unproxy_args_kwargs, ) from traincheck.instrumentor.proxy_wrapper.proxy_config import enable_C_level_observer from traincheck.instrumentor.proxy_wrapper.proxy_registry import get_global_registry @@ -81,24 +81,7 @@ def is_c_level_function(original_function): def get_meta_vars() -> dict: """Deprecated: use meta_vars directly""" - return meta_vars - - -def increment_step_if_needed(func_obj, func_name, is_bound_method, args): - """Increment the global step if - - the function is torch.optim.Optimizer.step""" - if not is_bound_method: - return - - obj = args[0] - - if func_name.endswith(".step"): - # if the function is a bound method and the object is an instance of torch.optim.Optimizer - if isinstance(obj, torch.optim.Optimizer): - meta_vars[ - "step" - ] += 1 # TODO: what if the users have annotated their own step function? 
-            return True
+    return META_VARS
 
 
 def to_dict_args_kwargs(args, kwargs, dump_args_config=None) -> dict:
@@ -140,224 +123,7 @@ def to_dict_return_value(result) -> dict | list[dict]:
     return result_dict
 
 
-def global_wrapper_subclass(
-    original_function: Callable,
-    original_function_name: str,
-    is_bound_method: bool,
-    scan_proxy_in_args: bool,
-    dump_stack_trace: bool,
-    dump_args: bool,
-    dump_args_config,
-    dump_ret: bool,
-    dump_ret_config,
-    *args,
-    **kwargs,
-):
-    """Instrumentation for APIs
-
-    Pre-call Phase
-    1. Log the pre-call information
-
-    Call Phase
-    1. Calls the original function
-    2. If an exception is raised, log the exception and re-raise it
-
-    Post-call Phase
-    1. Log the post-call information
-    """
-
-    global DISABLE_WRAPPER
-    global PROCESS_ID
-
-    if DISABLE_WRAPPER:
-        return original_function(*args, **kwargs)
-
-    if COLLECT_OVERHEAD_METRICS:
-        ENTER_PERF_TIME = time.perf_counter()
-
-    func_call_id = get_unique_id()
-    process_id, thread_id = get_process_thread_id()
-    increment_step_if_needed(
-        original_function, original_function_name, is_bound_method, args
-    )
-
-    pre_meta_vars = get_meta_vars()
-
-    if IS_INSTRUMENTING:
-        return original_function(
-            *args, **kwargs
-        )  # don't instrument while instrumenting
-
-    pre_record = {
-        "func_call_id": func_call_id,
-        "thread_id": thread_id,
-        "process_id": process_id,
-        "meta_vars": pre_meta_vars,
-        "type": TraceLineType.FUNC_CALL_PRE,
-        "function": original_function_name,
-        "is_bound_method": is_bound_method,
-        "obj_id": None if not is_bound_method else id(args[0]),
-    }
-
-    if dump_stack_trace:
-        pre_record["stack_trace"] = traceback.format_stack()
-
-    if scan_proxy_in_args:
-        proxy_in_args = []
-
-        def find_proxy_in_args(args):
-            for i, arg in enumerate(args):
-                if is_proxied(arg) or is_proxyparameter(arg):
-                    proxy_in_args.append(arg)
-                elif type(arg) in [list, tuple]:
-                    find_proxy_in_args(arg)
-                elif isinstance(arg, types.GeneratorType) and not isinstance(
-                    arg, tuple
-                ):
-                    arg_list = list(arg)
-                    args[i] = iter(arg_list)
-                    find_proxy_in_args(arg_list)
-
-        args = list(args)  # type: ignore[assignment]
-        find_proxy_in_args(args)
-        args = tuple(args)
-
-        if proxy_in_args:
-            if "proxy_obj_names" not in pre_record:
-                pre_record["proxy_obj_names"] = []
-            for proxy in proxy_in_args:
-                if is_proxyparameter(proxy):
-                    pre_record["proxy_obj_names"].append(
-                        [proxy.__dict__["var_name"], "Parameter"]
-                    )
-                else:
-                    pre_record["proxy_obj_names"].append(
-                        [proxy.__dict__["var_name"], type(proxy._obj).__name__]
-                    )
-
-    if dump_args:
-        dict_args_kwargs = to_dict_args_kwargs(args, kwargs, dump_args_config)
-        pre_record["args"] = dict_args_kwargs["args"]
-        pre_record["kwargs"] = dict_args_kwargs["kwargs"]
-    dump_trace_API(pre_record)
-
-    try:
-        if COLLECT_OVERHEAD_METRICS:
-            ORIG_ENTER_PERF_TIME = time.perf_counter()
-        result = original_function(*args, **kwargs)
-        if COLLECT_OVERHEAD_METRICS:
-            ORIG_EXIT_PERF_TIME = time.perf_counter()
-    except Exception as e:
-        if COLLECT_OVERHEAD_METRICS:
-            ORIG_EXIT_PERF_TIME = time.perf_counter()
-
-        dump_trace_API(
-            {
-                "func_call_id": func_call_id,
-                "thread_id": thread_id,
-                "process_id": process_id,
-                "meta_vars": pre_meta_vars,
-                "type": TraceLineType.FUNC_CALL_POST_EXCEPTION,
-                "function": original_function_name,
-                "exception": typename(e, is_runtime=True),
-                "exception_msg": str(e),
-                "is_bound_method": is_bound_method,
-                "obj_id": None if not is_bound_method else id(args[0]),
-            },
-        )
-
-        if COLLECT_OVERHEAD_METRICS:
-            EXIT_PERF_TIME = time.perf_counter()
-            print(
-                f"WRAPPER TIME: {original_function_name},{ORIG_EXIT_PERF_TIME - ORIG_ENTER_PERF_TIME},{EXIT_PERF_TIME - ENTER_PERF_TIME}"
-            )
-        raise e
-
-    post_record = {
-        "func_call_id": func_call_id,
-        "thread_id": thread_id,
-        "process_id": process_id,
-        "meta_vars": pre_meta_vars,
-        "type": TraceLineType.FUNC_CALL_POST,
-        "function": original_function_name,
-        "is_bound_method": is_bound_method,
-        "obj_id": None if not is_bound_method else id(args[0]),
-    }
-
-    result_to_dump = result
-
-    # if the current function name is transformers.generate, then we will dump the response tokens only, let's see.
-    # a concrete name: "transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration.generate"
-    # we want a pattern that abstracts the specific model name
-    pattern = "transformers.models.*.generate"
-    # find matches in the pattern
-    import re
-
-    if (
-        GENERATE_START_TOKEN_ID is not None
-        and re.match(pattern, original_function_name)
-        and isinstance(result, torch.Tensor)
-    ):
-        print(f"Found match for {original_function_name}")
-        # the first dimension is the batch size, and each corresponds to a separate response, let's try to match the batch size with the start token ids first
-        response_starting_indices = []
-        for i in range(result.size(0)):
-            # try to find the match of the start token ids in the response
-            response = result[i]
-            # Find all indices where the start_token_id matches
-            matches = (response == GENERATE_START_TOKEN_ID).nonzero(as_tuple=True)[0]
-            indexes = matches.tolist()
-            if len(indexes) == 0:
-                # No occurrences found
-                print(
-                    f"start_token_id ({GENERATE_START_TOKEN_ID}) not found in response {i}"
-                )
-                start_index = -1  # Handle case where token is not found
-            elif len(indexes) > 1:
-                # Multiple occurrences found, raise an error
-                raise ValueError(
-                    f"Multiple occurrences of start_token_id ({GENERATE_START_TOKEN_ID}) found in response {i}: {matches.tolist()}"
-                )
-            else:
-                # Single occurrence found, get the index
-                start_index = indexes[0]
-                if not GENERATE_START_TOKEN_ID_INCLUDE_START_TOKEN:
-                    start_index += 1
-
-            response_starting_indices.append(start_index)
-
-        # compute the length of each response
-        response_lengths = []
-        for i in range(result.size(0)):
-            response = result[i]
-            start_index = response_starting_indices[i]
-            if start_index == -1:
-                response_lengths.append(0)
-            else:
-                response_lengths.append(response.size(0) - start_index)
-
-        result_to_dump = result.detach()
-        setattr(
-            result_to_dump,
-            "_ML_DAIKON_RESPONSE_STARTING_INDICES",
-            response_starting_indices,
-        )
-        setattr(result_to_dump, "_ML_DAIKON_RESPONSE_LENGTHS", response_lengths)
-
-        print(response_starting_indices)
-        print(response_lengths)
-
-    if dump_ret:
-        post_record["return_values"] = to_dict_return_value(result_to_dump)
-    dump_trace_API(post_record)
-
-    if COLLECT_OVERHEAD_METRICS:
-        EXIT_PERF_TIME = time.perf_counter()
-        print(
-            f"WRAPPER TIME: {original_function_name},{ORIG_EXIT_PERF_TIME - ORIG_ENTER_PERF_TIME},{EXIT_PERF_TIME - ENTER_PERF_TIME}"
-        )
-    return result
-
-
-def global_wrapper_proxy(
+def function_wrapper(
     original_function: Callable,
     original_function_name: str,
     is_bound_method: bool,
@@ -374,18 +140,13 @@ def global_wrapper_proxy(
     *args,
     **kwargs,
 ):
-    """Instrumentation for APIs with proxy-specific handling."""
-
-    # if "step" in original_function_name and not "scheduler" in original_function_name:
-    #     print("step function called" + original_function_name)
-    #     print(trigger_proxy_state_dump)
-    #     print(proxy_state_dump_config)
-    #     exit(1)
+    """Instrumentation for functions"""
 
     global DISABLE_WRAPPER
     global PROCESS_ID
 
     if DISABLE_WRAPPER:
+        # TODO: all meta vars update should be done outside the function_wrapper (e.g. step increment) by applying a separate wrapper
         return original_function(*args, **kwargs)
 
     if COLLECT_OVERHEAD_METRICS:
@@ -393,9 +154,6 @@ def global_wrapper_proxy(
 
     func_call_id = get_unique_id()
     process_id, thread_id = get_process_thread_id()
-    increment_step_if_needed(
-        original_function, original_function_name, is_bound_method, args
-    )
 
     pre_meta_vars = get_meta_vars()
 
@@ -458,7 +216,7 @@ def find_proxy_in_args(args):
 
     if handle_proxy and trigger_proxy_state_dump:
         """Mimicking the behavior the observer wrapper: pre-observe"""
-        get_global_registry().dump_only_modified(
+        get_global_registry().dump_modified(
             dump_loc=original_function_name, dump_config=proxy_state_dump_config
         )
 
@@ -470,12 +228,10 @@ def find_proxy_in_args(args):
 
             original_function = add_observer_to_func(original_function, unproxy=True)
         elif is_funcs_to_be_unproxied(original_function):
-            original_function = unproxy_func(
-                original_function, inspect_torch_module=True
-            )
+            args, kwargs = unproxy_args_kwargs(args, kwargs)
         elif is_builtin:
             # proxy objects being passed to backend will cause seg fault: TODO: replace with unproxy func
-            original_function = unproxy_func(original_function)
+            args, kwargs = unproxy_args_kwargs(args, kwargs)
 
     try:
         if COLLECT_OVERHEAD_METRICS:
@@ -489,7 +245,7 @@ def find_proxy_in_args(args):
 
     if handle_proxy and trigger_proxy_state_dump:
         """Mimicking the behavior the observer wrapper: post-observe"""
-        get_global_registry().dump_only_modified(
+        get_global_registry().dump_modified(
             dump_loc=original_function_name, dump_config=proxy_state_dump_config
         )
 
@@ -516,7 +272,7 @@ def find_proxy_in_args(args):
         raise e
 
     if handle_proxy and trigger_proxy_state_dump:
-        get_global_registry().dump_only_modified(
+        get_global_registry().dump_modified(
             dump_loc=original_function_name, dump_config=proxy_state_dump_config
         )
 
@@ -606,17 +362,12 @@ def find_proxy_in_args(args):
 
 
 def core_wrapper_proxy(original_function, is_builtin, handle_proxy, *args, **kwargs):
-    """Same as global_wrapper_proxy but without logging.
-
-    We use this wrapper on functions that are not helpful for invariant inference,
-    but still need proxy-safe handling.
- """ + """Core wrapper that only handles unproxying for built-in functions.""" global DISABLE_WRAPPER if DISABLE_WRAPPER: return original_function(*args, **kwargs) - if handle_proxy and is_builtin: - original_function = unproxy_func(original_function) + args, kwargs = unproxy_args_kwargs(args, kwargs) return original_function(*args, **kwargs) @@ -636,40 +387,34 @@ def wrapper( ): is_builtin = is_c_level_function(original_function) original_function_name = typename(original_function) + increment_step = False + if original_function_name.endswith(".step") and isinstance( + original_function.__self__, torch.optim.Optimizer + ): + increment_step = True + # determine statically whether to dump the trace if not disable_dump: METRIC_INSTRUMENTED_FUNC_LIST["dump"].append(original_function_name) @functools.wraps(original_function) def wrapped(*args, **kwargs): - if handle_proxy: - return global_wrapper_proxy( - original_function, - original_function_name, - is_bound_method, - is_builtin, - scan_proxy_in_args, - dump_stack_trace, - dump_args, - dump_args_config, - dump_ret, - dump_ret_config, - handle_proxy, - trigger_proxy_state_dump, - proxy_state_dump_config, - *args, - **kwargs, - ) - return global_wrapper_subclass( - original_function, - original_function_name, - is_bound_method, - scan_proxy_in_args, - dump_stack_trace, - dump_args, - dump_args_config, - dump_ret, - dump_ret_config, + if increment_step: + META_VARS["step"] += 1 + return function_wrapper( + original_function=original_function, + original_function_name=original_function_name, + is_bound_method=is_bound_method, + is_builtin=is_builtin, + scan_proxy_in_args=scan_proxy_in_args, + dump_stack_trace=dump_stack_trace, + dump_args=dump_args, + dump_args_config=dump_args_config, + dump_ret=dump_ret, + dump_ret_config=dump_ret_config, + handle_proxy=handle_proxy, + trigger_proxy_state_dump=trigger_proxy_state_dump, + proxy_state_dump_config=proxy_state_dump_config, *args, **kwargs, ) @@ -680,11 +425,15 @@ def wrapped(*args, **kwargs): @functools.wraps(original_function) def wrapped(*args, **kwargs): + if increment_step: + META_VARS["step"] += 1 return core_wrapper_proxy( original_function, is_builtin, handle_proxy, *args, **kwargs ) else: + if increment_step: + META_VARS["step"] += 1 return original_function wrapped._traincheck_original_function = original_function @@ -1076,7 +825,18 @@ def get_wrapped_function(self, func_obj: Callable) -> Callable: else config.MODEL_TRACKER_STYLE ) used_proxy = tracker_style == "proxy" - if self.instr_opts is not None: + if self.instr_opts is None: + # inference stage instrumentation + return wrapper( + func_obj, + is_bound_method=is_API_bound_method(func_obj), + scan_proxy_in_args=self.scan_proxy_in_args, + disable_dump=self.should_disable_dump(func_obj), + dump_stack_trace=self.API_dump_stack_trace, + handle_proxy=used_proxy, + ) + else: + # checking stage instrumentation func_name = typename(func_obj) if func_name not in self.instr_opts.funcs_instr_opts: return wrapper( @@ -1113,15 +873,6 @@ def get_wrapped_function(self, func_obj: Callable) -> Callable: proxy_state_dump_config=func_instr_opt["var_types_to_track"], ) - return wrapper( - func_obj, - is_bound_method=is_API_bound_method(func_obj), - scan_proxy_in_args=self.scan_proxy_in_args, - disable_dump=self.should_disable_dump(func_obj), - dump_stack_trace=self.API_dump_stack_trace, - handle_proxy=used_proxy, - ) - def _instrument_module( self, pymodule: types.ModuleType | type, From 18b6a8be7e78bb0102eba26e60fc7a1f746ee873 Mon Sep 17 00:00:00 2001 
From: Yuxuan Jiang
Date: Mon, 5 Jan 2026 15:11:15 -0500
Subject: [PATCH 03/15] remove unused proxy config; rename ML_DAIKON to
 TRAINCHECK

---
 docs/5-min-tutorial.md                        |  4 +--
 docs/ae-eval-s5.1-silent-issue-detection.md   | 32 ++++++++---------
 docs/assets/code/mnist.py                     | 16 ++++-----
 .../traincheck-collect/mnist-config/mnist.py  | 16 ++++-----
 traincheck/collect_trace.py                   |  9 +----
 traincheck/instrumentor/dumper.py             | 22 ++++++------
 .../instrumentor/proxy_wrapper/proxy.py       |  2 +-
 .../proxy_wrapper/proxy_config.py             |  2 --
 traincheck/instrumentor/source_file.py        |  6 ++--
 traincheck/instrumentor/tracer.py             | 35 ++++++++-----------
 .../invariant/consistency_transient_vars.py   |  2 +-
 11 files changed, 65 insertions(+), 81 deletions(-)

diff --git a/docs/5-min-tutorial.md b/docs/5-min-tutorial.md
index 5caf26a5..27fe4004 100644
--- a/docs/5-min-tutorial.md
+++ b/docs/5-min-tutorial.md
@@ -246,7 +246,7 @@ For example, the "`optimizer.zero_grad` did **not** reset `.grad` from non-zero
     "var_type": NaN,
     "mode": NaN,
     "dump_loc": NaN,
-    "attributes._ML_DAIKON_data_ID": NaN,
+    "attributes._TRAINCHECK_data_ID": NaN,
     "attributes.data": NaN,
     "attributes.dtype": NaN,
     "attributes.grad": NaN,
@@ -274,7 +274,7 @@ For example, the "`optimizer.zero_grad` did **not** reset `.grad` from non-zero
     "attributes.requires_grad": NaN,
     "attributes.retains_grad": NaN,
     "attributes.shape": NaN,
-    "attributes._ML_DAIKON_grad_ID": NaN,
+    "attributes._TRAINCHECK_grad_ID": NaN,
     "exception": NaN,
     "exception_msg": NaN,
     "proxy_obj_names": NaN
diff --git a/docs/ae-eval-s5.1-silent-issue-detection.md b/docs/ae-eval-s5.1-silent-issue-detection.md
index b6065c5d..e0499b57 100644
--- a/docs/ae-eval-s5.1-silent-issue-detection.md
+++ b/docs/ae-eval-s5.1-silent-issue-detection.md
@@ -145,9 +145,9 @@ diff --color -r checker_output/trace_pytorch-104336/failed.log reference_checker
 > "process_id": 9591,
 > "thread_id": 140324043503424,
 86c86
-< "attributes._ML_DAIKON_data_ID": 140704882109040,
+< "attributes._TRAINCHECK_data_ID": 140704882109040,
 ---
-> "attributes._ML_DAIKON_data_ID": 140317529048544,
+> "attributes._TRAINCHECK_data_ID": 140317529048544,
 116,117c116,117
 < "time": 2437523672783,
 < "meta_vars._DATA_PARALLEL_RANK": 4.0,
@@ -161,9 +161,9 @@ diff --color -r checker_output/trace_pytorch-104336/failed.log reference_checker
 > "process_id": 9747,
 > "thread_id": 140028492969792,
 128c128
-< "attributes._ML_DAIKON_data_ID": 140043703504144,
+< "attributes._TRAINCHECK_data_ID": 140043703504144,
 ---
-> "attributes._ML_DAIKON_data_ID": 140021978318304,
+> "attributes._TRAINCHECK_data_ID": 140021978318304,
 158,159c158,159
 < "time": 2437502499438,
 < "meta_vars._DATA_PARALLEL_RANK": 2.0,
@@ -182,9 +182,9 @@ diff --color -r checker_output/trace_pytorch-115607/failed.log reference_checker
 < "exception_msg": NaN,
 < "proxy_obj_names": NaN,
 113c110,113
-< "attributes._ML_DAIKON_grad_ID": NaN
+< "attributes._TRAINCHECK_grad_ID": NaN
 ---
-> "attributes._ML_DAIKON_grad_ID": NaN,
+> "attributes._TRAINCHECK_grad_ID": NaN,
 > "exception": NaN,
 > "exception_msg": NaN,
 > "proxy_obj_names": NaN
@@ -193,9 +193,9 @@ diff --color -r checker_output/trace_pytorch-115607/failed.log reference_checker
 < "exception_msg": NaN,
 < "proxy_obj_names": NaN,
 215c212,215
-< "attributes._ML_DAIKON_grad_ID": NaN
+< "attributes._TRAINCHECK_grad_ID": NaN
 ---
-> "attributes._ML_DAIKON_grad_ID": NaN,
+> "attributes._TRAINCHECK_grad_ID": NaN,
 > "exception": NaN,
 > "exception_msg": NaN,
 > "proxy_obj_names": NaN
@@ -210,9 +210,9 @@ diff --color -r checker_output/trace_pytorch-115607/failed.log reference_checker
 < "exception_msg": NaN,
 < "proxy_obj_names": NaN,
 331c328,331
-< "attributes._ML_DAIKON_grad_ID": NaN
+< "attributes._TRAINCHECK_grad_ID": NaN
 ---
-> "attributes._ML_DAIKON_grad_ID": NaN,
+> "attributes._TRAINCHECK_grad_ID": NaN,
 > "exception": NaN,
 > "exception_msg": NaN,
 > "proxy_obj_names": NaN
@@ -247,10 +247,10 @@ diff --color -r checker_output/trace_pytorch-51800/failed.log reference_checker_
 > "time": 19876858668088743,
 > "meta_vars.step": 0,
 89c70,89
-< "attributes._ML_DAIKON_grad_ID": NaN
+< "attributes._TRAINCHECK_grad_ID": NaN
 ---
 > "type": "function_call (pre)",
-> "attributes._ML_DAIKON_grad_ID": NaN,
+> "attributes._TRAINCHECK_grad_ID": NaN,
 > "func_call_id": "b39a4a81b2c24473ba916ab1832fbf12_19876858668012869",
 > "function": "torch.nn.modules.module.Module.eval",
 > "is_bound_method": true,
@@ -290,9 +290,9 @@ diff --color -r checker_output/trace_x-jxmnop-ddp-out-of-sync/failed.log referen
 ---
 > "meta_vars._DATA_PARALLEL_RANK": "1",
 87c87
-< "attributes._ML_DAIKON_data_ID": 140656561409856,
+< "attributes._TRAINCHECK_data_ID": 140656561409856,
 ---
-> "attributes._ML_DAIKON_data_ID": 140621279056480,
+> "attributes._TRAINCHECK_data_ID": 140621279056480,
 117c117
 < "time": 123297988837864,
 ---
@@ -308,9 +308,9 @@ diff --color -r checker_output/trace_x-jxmnop-ddp-out-of-sync/failed.log referen
 ---
 > "meta_vars._DATA_PARALLEL_RANK": "0",
 129c129
-< "attributes._ML_DAIKON_data_ID": 140621279058160,
+< "attributes._TRAINCHECK_data_ID": 140621279058160,
 ---
-> "attributes._ML_DAIKON_data_ID": 140656561411776,
+> "attributes._TRAINCHECK_data_ID": 140656561411776,
 159c159
 < "time": 123299970638648,
 ---
diff --git a/docs/assets/code/mnist.py b/docs/assets/code/mnist.py
index 6952ed8a..b5a34a5b 100644
--- a/docs/assets/code/mnist.py
+++ b/docs/assets/code/mnist.py
@@ -40,7 +40,7 @@ def forward(self, x):
 
 
 def train(args, model, device, train_loader, optimizer, epoch):
-    annotate_stage("training")  # ML_DAIKON: stage annotation
+    annotate_stage("training")  # TRAINCHECK: stage annotation
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
         META_VARS["step"] += 1
@@ -63,13 +63,13 @@ def train(args, model, device, train_loader, optimizer, epoch):
 
             if args.dry_run:
                 break
-        # ML_DAIKON: break after 100 batches
+        # TRAINCHECK: break after 100 batches
         if batch_idx == 50:
             break
 
 
 def test(model, device, test_loader):
-    annotate_stage("testing")  # ML_DAIKON: stage annotation
+    annotate_stage("testing")  # TRAINCHECK: stage annotation
     model.eval()
     test_loss = 0
     correct = 0
@@ -87,7 +87,7 @@ def test(model, device, test_loader):
             correct += pred.eq(target.view_as(pred)).sum().item()
             data_idx += 1
 
-            # ML_DAIKON: break after 10 batches
+            # TRAINCHECK: break after 10 batches
             if data_idx == 10:
                 break
 
@@ -174,7 +174,7 @@ def main():
     )
     args = parser.parse_args()
 
-    annotate_stage("init")  # ML_DAIKON: stage annotation
+    annotate_stage("init")  # TRAINCHECK: stage annotation
 
     use_cuda = not args.no_cuda and torch.cuda.is_available()
     use_mps = not args.no_mps and torch.backends.mps.is_available()
@@ -191,7 +191,7 @@ def main():
     test_kwargs = {"batch_size": args.test_batch_size}
     if use_cuda:
         cuda_kwargs = {"num_workers": 2, "pin_memory": True, "shuffle": True}
-        # ML_DAIKON: set num_workers to 0 to avoid dataloader related invariants
+        # TRAINCHECK: set num_workers to 0 to avoid dataloader related invariants
         # cuda_kwargs = {'num_workers': 0, 'pin_memory': True, 'shuffle': True}
         train_kwargs.update(cuda_kwargs)
         test_kwargs.update(cuda_kwargs)
@@ -212,11 +212,11 @@ def main():
         train(args, model, device, train_loader, optimizer, epoch)
         test(model, device, test_loader)
-        annotate_stage("training")  # ML_DAIKON: stage annotation
+        annotate_stage("training")  # TRAINCHECK: stage annotation
         scheduler.step()
 
     if args.save_model:
-        annotate_stage("checkpointing")  # ML_DAIKON: stage annotation
+        annotate_stage("checkpointing")  # TRAINCHECK: stage annotation
         torch.save(model.state_dict(), "mnist_cnn.pt")
diff --git a/docs/assets/examples/traincheck-collect/mnist-config/mnist.py b/docs/assets/examples/traincheck-collect/mnist-config/mnist.py
index 6952ed8a..b5a34a5b 100644
--- a/docs/assets/examples/traincheck-collect/mnist-config/mnist.py
+++ b/docs/assets/examples/traincheck-collect/mnist-config/mnist.py
@@ -40,7 +40,7 @@ def forward(self, x):
 
 
 def train(args, model, device, train_loader, optimizer, epoch):
-    annotate_stage("training")  # ML_DAIKON: stage annotation
+    annotate_stage("training")  # TRAINCHECK: stage annotation
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
         META_VARS["step"] += 1
@@ -63,13 +63,13 @@ def train(args, model, device, train_loader, optimizer, epoch):
 
             if args.dry_run:
                 break
-        # ML_DAIKON: break after 100 batches
+        # TRAINCHECK: break after 100 batches
         if batch_idx == 50:
             break
 
 
 def test(model, device, test_loader):
-    annotate_stage("testing")  # ML_DAIKON: stage annotation
+    annotate_stage("testing")  # TRAINCHECK: stage annotation
     model.eval()
     test_loss = 0
     correct = 0
@@ -87,7 +87,7 @@ def test(model, device, test_loader):
             correct += pred.eq(target.view_as(pred)).sum().item()
             data_idx += 1
 
-            # ML_DAIKON: break after 10 batches
+            # TRAINCHECK: break after 10 batches
             if data_idx == 10:
                 break
 
@@ -174,7 +174,7 @@ def main():
     )
     args = parser.parse_args()
 
-    annotate_stage("init")  # ML_DAIKON: stage annotation
+    annotate_stage("init")  # TRAINCHECK: stage annotation
 
     use_cuda = not args.no_cuda and torch.cuda.is_available()
     use_mps = not args.no_mps and torch.backends.mps.is_available()
@@ -191,7 +191,7 @@ def main():
     test_kwargs = {"batch_size": args.test_batch_size}
     if use_cuda:
         cuda_kwargs = {"num_workers": 2, "pin_memory": True, "shuffle": True}
-        # ML_DAIKON: set num_workers to 0 to avoid dataloader related invariants
+        # TRAINCHECK: set num_workers to 0 to avoid dataloader related invariants
         # cuda_kwargs = {'num_workers': 0, 'pin_memory': True, 'shuffle': True}
         train_kwargs.update(cuda_kwargs)
         test_kwargs.update(cuda_kwargs)
@@ -212,11 +212,11 @@ def main():
         train(args, model, device, train_loader, optimizer, epoch)
         test(model, device, test_loader)
-        annotate_stage("training")  # ML_DAIKON: stage annotation
+        annotate_stage("training")  # TRAINCHECK: stage annotation
         scheduler.step()
 
     if args.save_model:
-        annotate_stage("checkpointing")  # ML_DAIKON: stage annotation
+        annotate_stage("checkpointing")  # TRAINCHECK: stage annotation
         torch.save(model.state_dict(), "mnist_cnn.pt")
diff --git a/traincheck/collect_trace.py b/traincheck/collect_trace.py
index 4f9eb7e6..6c4c1285 100644
--- a/traincheck/collect_trace.py
+++ b/traincheck/collect_trace.py
@@ -347,12 +347,6 @@ def main():
         default="hash",
         help="The format for dumping tensors. Choose from 'hash'(default), 'stats' or 'full'.",
     )
-    parser.add_argument(
-        "--enable-C-level-observer",
-        type=bool,
-        default=proxy_config.enable_C_level_observer,
-        help="Enable the observer at the C level",
-    )
     parser.add_argument(
         "--no-auto-var-instr",
         action="store_true",
@@ -386,7 +380,7 @@ def main():
     # set up logging
     if args.debug_mode:
         logging.basicConfig(level=logging.DEBUG)
-        os.environ["ML_DAIKON_DEBUG"] = "1"
+        os.environ["TRAINCHECK_DEBUG"] = "1"
     else:
         logging.basicConfig(level=logging.INFO)
 
@@ -406,7 +400,6 @@ def main():
     proxy_basic_config: dict[str, int | bool | str] = {}
     for configs in [
         "debug_mode",
-        "enable_C_level_observer",
     ]:
         if getattr(proxy_config, configs) != getattr(args, configs):
             proxy_basic_config[configs] = getattr(args, configs)
diff --git a/traincheck/instrumentor/dumper.py b/traincheck/instrumentor/dumper.py
index cbc745f5..a7471053 100644
--- a/traincheck/instrumentor/dumper.py
+++ b/traincheck/instrumentor/dumper.py
@@ -27,7 +27,7 @@
 )
 from traincheck.utils import get_timestamp_ns, typename, typename_compile
 
-DEBUG = os.environ.get("ML_DAIKON_DEBUG", False)
+DEBUG = os.environ.get("TRAINCHECK_DEBUG", False)
 
 THREAD_DATA = threading.local()
 IS_CUDA_AVAILABLE = torch.cuda.is_available()
@@ -129,10 +129,10 @@ def get_trace_API_dumper_queue():
     pid = os.getpid()
     tid = threading.get_ident()
 
-    output_dir = os.getenv("ML_DAIKON_OUTPUT_DIR")
+    output_dir = os.getenv("TRAINCHECK_OUTPUT_DIR")
     assert (
         output_dir is not None
-    ), "ML_DAIKON_OUTPUT_DIR is not set, examine the instrumented code to see if os.environ['ML_DAIKON_OUTPUT_DIR'] is set in the main function"
+    ), "TRAINCHECK_OUTPUT_DIR is not set, examine the instrumented code to see if os.environ['TRAINCHECK_OUTPUT_DIR'] is set in the main function"
 
     trace_queue = Queue()
     trace_file_name = f"trace_API_{pid}_{tid}.log"
@@ -161,10 +161,10 @@ def get_trace_VAR_dumper_queue():
     pid = os.getpid()
     tid = threading.current_thread().ident
 
-    output_dir = os.getenv("ML_DAIKON_OUTPUT_DIR")
+    output_dir = os.getenv("TRAINCHECK_OUTPUT_DIR")
     assert (
         output_dir is not None
-    ), "ML_DAIKON_OUTPUT_DIR is not set, examine the instrumented code to see if os.environ['ML_DAIKON_OUTPUT_DIR'] is set in the main function"
+    ), "TRAINCHECK_OUTPUT_DIR is not set, examine the instrumented code to see if os.environ['TRAINCHECK_OUTPUT_DIR'] is set in the main function"
 
     trace_queue = Queue()
     trace_file_name = f"trace_VAR_{pid}_{tid}.log"
@@ -249,10 +249,10 @@ def dump_trace_VAR(trace: dict):
 def get_instrumentation_logger_for_process():
     pid = os.getpid()
 
-    output_dir = os.getenv("ML_DAIKON_OUTPUT_DIR")
+    output_dir = os.getenv("TRAINCHECK_OUTPUT_DIR")
     assert (
         output_dir is not None
-    ), "ML_DAIKON_OUTPUT_DIR is not set, examine the instrumented code to see if os.environ['ML_DAIKON_OUTPUT_DIR'] is set in the main function"
+    ), "TRAINCHECK_OUTPUT_DIR is not set, examine the instrumented code to see if os.environ['TRAINCHECK_OUTPUT_DIR'] is set in the main function"
 
     if pid in instrumentation_loggers:
         return instrumentation_loggers[pid]
@@ -369,7 +369,7 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict
         if (
             isinstance(attr_name, str)
             and attr_name.startswith("_")
-            and not attr_name.startswith("_ML_DAIKON")
+            and not attr_name.startswith("_TRAINCHECK")
         ):
             continue
 
@@ -405,12 +405,12 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict
             result[attr_name] = attr
 
         elif isinstance(attr, torch.Tensor):
-            result[f"_ML_DAIKON_{attr_name}_ID"] = id(attr)
+            result[f"_TRAINCHECK_{attr_name}_ID"] = id(attr)
             if include_tensor_data:
                 result[attr_name] = dump_tensor(attr)
 
         elif isinstance(attr, torch.nn.parameter.Parameter):
-            result[f"_ML_DAIKON_{attr_name}_ID"] = id(attr)
+            result[f"_TRAINCHECK_{attr_name}_ID"] = id(attr)
             if include_tensor_data:
                 result[attr_name] = dump_tensor(attr.data)
 
@@ -430,7 +430,7 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict
             result[attr_name] = str(attr)
         elif isinstance(attr, torch.Size):
             result[attr_name] = tuple(attr)
-        elif "_ML_DAIKON" in attr_name:
+        elif "_TRAINCHECK" in attr_name:
             # should always be serializable, so blindly assign here.
             result[attr_name] = attr
 
diff --git a/traincheck/instrumentor/proxy_wrapper/proxy.py b/traincheck/instrumentor/proxy_wrapper/proxy.py
index f9deb9ee..3d48bbe5 100644
--- a/traincheck/instrumentor/proxy_wrapper/proxy.py
+++ b/traincheck/instrumentor/proxy_wrapper/proxy.py
@@ -102,7 +102,7 @@ class Proxy:
     var_dict: Dict[str, ProxyObjInfo] = {}
     loglevel = logging.INFO
     jsondumper = dumper(
-        os.path.join(os.getenv("ML_DAIKON_OUTPUT_DIR", "."), "proxy_log.json")  # type: ignore
+        os.path.join(os.getenv("TRAINCHECK_OUTPUT_DIR", "."), "proxy_log.json")  # type: ignore
     )
 
     @staticmethod
diff --git a/traincheck/instrumentor/proxy_wrapper/proxy_config.py b/traincheck/instrumentor/proxy_wrapper/proxy_config.py
index 66ce6d7c..c90d5d91 100644
--- a/traincheck/instrumentor/proxy_wrapper/proxy_config.py
+++ b/traincheck/instrumentor/proxy_wrapper/proxy_config.py
@@ -25,8 +25,6 @@
     "observe_then_unproxy": True,  # observe the function call and then unproxy the arguments
 }
 
-enable_C_level_observer = False  # enable the observer at the C level (This would potentially lead to a lot of overhead since we need to observe and dump all proxied object at the C level function call, try to use auto observer with proper depth could reduce the overhead)
-
 primitive_types = {
     types.NoneType,
     int,
diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py
index 47c0ab4c..7eb4baad 100644
--- a/traincheck/instrumentor/source_file.py
+++ b/traincheck/instrumentor/source_file.py
@@ -833,13 +833,13 @@ def instrument_file(
     # logging configs
     logging_start_code = f"""
 import os
-os.environ['ML_DAIKON_OUTPUT_DIR'] = "{output_dir}"
+os.environ['TRAINCHECK_OUTPUT_DIR'] = "{output_dir}"
 """
 
     debug_hook_code = """
 from traincheck.utils import register_custom_excepthook
-if os.environ.get("ML_DAIKON_DEBUG") == "1":
-    print("ML_DAIKON_DEBUG is set to 1, registering custom excepthook")
+if os.environ.get("TRAINCHECK_DEBUG") == "1":
+    print("TRAINCHECK_DEBUG is set to 1, registering custom excepthook")
     register_custom_excepthook(True)
 """
 
diff --git a/traincheck/instrumentor/tracer.py b/traincheck/instrumentor/tracer.py
index 76fc7db1..58bfa9f1 100644
--- a/traincheck/instrumentor/tracer.py
+++ b/traincheck/instrumentor/tracer.py
@@ -30,7 +30,6 @@
     is_proxyparameter,
     unproxy_args_kwargs,
 )
-from traincheck.instrumentor.proxy_wrapper.proxy_config import enable_C_level_observer
 from traincheck.instrumentor.proxy_wrapper.proxy_registry import get_global_registry
 from traincheck.instrumentor.replace_functions import (
     funcs_to_be_replaced,
@@ -127,7 +126,6 @@ def function_wrapper(
     original_function: Callable,
     original_function_name: str,
     is_bound_method: bool,
-    is_builtin: bool,
     scan_proxy_in_args: bool,
     dump_stack_trace: bool,
     dump_args: bool,
@@ -137,6 +135,7 @@ def function_wrapper(
     handle_proxy: bool,
     trigger_proxy_state_dump: bool,
     proxy_state_dump_config: dict,
+    need_unproxy_args_kwargs: bool,
     *args,
     **kwargs,
 ):
@@ -158,9 +157,9 @@ def function_wrapper(
     pre_meta_vars = get_meta_vars()
 
     if IS_INSTRUMENTING:
-        return original_function(
-            *args, **kwargs
-        )  # don't instrument while instrumenting
+        # during instrumentation, skip the dumping to avoid infinite recursion
+        # and interference with the import system
+        return original_function(*args, **kwargs)
 
     pre_record = {
         "func_call_id": func_call_id,
@@ -177,6 +176,7 @@ def function_wrapper(
         pre_record["stack_trace"] = traceback.format_stack()
 
     if scan_proxy_in_args:
+        # TODO: can be optimized: use static or dynamic analysis to determine which args/kwargs to scan
         proxy_in_args = []
 
         def find_proxy_in_args(args):
@@ -214,24 +214,14 @@ def find_proxy_in_args(args):
         pre_record["kwargs"] = dict_args_kwargs["kwargs"]
     dump_trace_API(pre_record)
 
-    if handle_proxy and trigger_proxy_state_dump:
+    if trigger_proxy_state_dump:
         """Mimicking the behavior the observer wrapper: pre-observe"""
         get_global_registry().dump_modified(
             dump_loc=original_function_name, dump_config=proxy_state_dump_config
         )
 
-    if handle_proxy:
-        if enable_C_level_observer and is_builtin:
-            from traincheck.instrumentor.proxy_wrapper.proxy_observer import (
-                add_observer_to_func,  # import here to avoid circular import
-            )
-
-            original_function = add_observer_to_func(original_function, unproxy=True)
-        elif is_funcs_to_be_unproxied(original_function):
-            args, kwargs = unproxy_args_kwargs(args, kwargs)
-        elif is_builtin:
-            # proxy objects being passed to backend will cause seg fault: TODO: replace with unproxy func
-            args, kwargs = unproxy_args_kwargs(args, kwargs)
+    if need_unproxy_args_kwargs:
+        args, kwargs = unproxy_args_kwargs(args, kwargs)
 
     try:
         if COLLECT_OVERHEAD_METRICS:
@@ -342,10 +332,10 @@ def find_proxy_in_args(args):
         result_to_dump = result.detach()
         setattr(
             result_to_dump,
-            "_ML_DAIKON_RESPONSE_STARTING_INDICES",
+            "_TRAINCHECK_RESPONSE_STARTING_INDICES",
             response_starting_indices,
         )
-        setattr(result_to_dump, "_ML_DAIKON_RESPONSE_LENGTHS", response_lengths)
+        setattr(result_to_dump, "_TRAINCHECK_RESPONSE_LENGTHS", response_lengths)
 
         print(response_starting_indices)
         print(response_lengths)
@@ -386,6 +376,9 @@ def wrapper(
     proxy_state_dump_config=None,
 ):
     is_builtin = is_c_level_function(original_function)
+    need_unproxy_args_kwargs = handle_proxy and (
+        is_builtin or is_funcs_to_be_unproxied(original_function)
+    )
     original_function_name = typename(original_function)
     increment_step = False
     if original_function_name.endswith(".step") and isinstance(
@@ -405,7 +398,6 @@ def wrapped(*args, **kwargs):
                 original_function=original_function,
                 original_function_name=original_function_name,
                 is_bound_method=is_bound_method,
-                is_builtin=is_builtin,
                 scan_proxy_in_args=scan_proxy_in_args,
                 dump_stack_trace=dump_stack_trace,
                 dump_args=dump_args,
@@ -415,6 +407,7 @@ def wrapped(*args, **kwargs):
                 handle_proxy=handle_proxy,
                 trigger_proxy_state_dump=trigger_proxy_state_dump,
                 proxy_state_dump_config=proxy_state_dump_config,
+                need_unproxy_args_kwargs=need_unproxy_args_kwargs,
                 *args,
                 **kwargs,
             )
diff --git a/traincheck/invariant/consistency_transient_vars.py b/traincheck/invariant/consistency_transient_vars.py
index b0640dff..485800e9 100644
--- a/traincheck/invariant/consistency_transient_vars.py
+++ b/traincheck/invariant/consistency_transient_vars.py
@@ -34,7 +34,7 @@
 
 TENSOR_PATTERN = r"torch\..*Tensor"
 PARAMETER_KEYWORD = "Parameter"
-ATTR_SKIP = "_ML_DAIKON_data_ID"
+ATTR_SKIP = "_TRAINCHECK_data_ID"
 
 # _CACHE_PATH = "func_with_tensors.pkl"
 

From
730463d71a67cd36132828c5be89ebe394b0af21 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Tue, 6 Jan 2026 17:21:11 -0500 Subject: [PATCH 04/15] subclass selective instrumentation impl --- .../instrumentor/proxy_wrapper/subclass.py | 2 - traincheck/instrumentor/tracer.py | 41 ++++++++----------- 2 files changed, 18 insertions(+), 25 deletions(-) diff --git a/traincheck/instrumentor/proxy_wrapper/subclass.py b/traincheck/instrumentor/proxy_wrapper/subclass.py index c9fa3973..acb64a18 100644 --- a/traincheck/instrumentor/proxy_wrapper/subclass.py +++ b/traincheck/instrumentor/proxy_wrapper/subclass.py @@ -166,8 +166,6 @@ def update_timestamp(self): def register_object(self): get_global_registry().add_var(self, self.__dict__["var_name"]) - # TODO: implement the registry, we will need to make sure the registerred timestamp is updated and is consistent with the timestamp in the object - # pass def dump_trace(self, phase, dump_loc): # print(f"parameter: {self.var_name}, phase = {phase}, dump_loc = {dump_loc}") diff --git a/traincheck/instrumentor/tracer.py b/traincheck/instrumentor/tracer.py index 58bfa9f1..d3204636 100644 --- a/traincheck/instrumentor/tracer.py +++ b/traincheck/instrumentor/tracer.py @@ -132,9 +132,8 @@ def function_wrapper( dump_args_config, dump_ret: bool, dump_ret_config, - handle_proxy: bool, - trigger_proxy_state_dump: bool, - proxy_state_dump_config: dict, + trigger_var_dump: bool, + var_dump_config: dict, need_unproxy_args_kwargs: bool, *args, **kwargs, @@ -176,7 +175,6 @@ def function_wrapper( pre_record["stack_trace"] = traceback.format_stack() if scan_proxy_in_args: - # TODO: can be optimized: use static or dynamic analysis to determine which args/kwargs to scan proxy_in_args = [] def find_proxy_in_args(args): @@ -214,10 +212,10 @@ def find_proxy_in_args(args): pre_record["kwargs"] = dict_args_kwargs["kwargs"] dump_trace_API(pre_record) - if trigger_proxy_state_dump: + if trigger_var_dump: """Mimicking the behavior the observer wrapper: pre-observe""" get_global_registry().dump_modified( - dump_loc=original_function_name, dump_config=proxy_state_dump_config + dump_loc=original_function_name, dump_config=var_dump_config ) if need_unproxy_args_kwargs: @@ -233,10 +231,10 @@ def find_proxy_in_args(args): if COLLECT_OVERHEAD_METRICS: ORIG_EXIT_PERF_TIME = time.perf_counter() - if handle_proxy and trigger_proxy_state_dump: + if trigger_var_dump: """Mimicking the behavior the observer wrapper: post-observe""" get_global_registry().dump_modified( - dump_loc=original_function_name, dump_config=proxy_state_dump_config + dump_loc=original_function_name, dump_config=var_dump_config ) dump_trace_API( @@ -261,9 +259,9 @@ def find_proxy_in_args(args): ) raise e - if handle_proxy and trigger_proxy_state_dump: + if trigger_var_dump: get_global_registry().dump_modified( - dump_loc=original_function_name, dump_config=proxy_state_dump_config + dump_loc=original_function_name, dump_config=var_dump_config ) post_record = { @@ -351,7 +349,7 @@ def find_proxy_in_args(args): return result -def core_wrapper_proxy(original_function, is_builtin, handle_proxy, *args, **kwargs): +def core_wrapper_proxy(original_function, *args, **kwargs): """Core wrapper that only handles unproxying for built-in functions.""" global DISABLE_WRAPPER if DISABLE_WRAPPER: @@ -372,8 +370,8 @@ def wrapper( dump_ret=True, dump_ret_config=None, handle_proxy=True, - trigger_proxy_state_dump=False, - proxy_state_dump_config=None, + trigger_var_dump=False, + var_dump_config=None, ): is_builtin = 
is_c_level_function(original_function) need_unproxy_args_kwargs = handle_proxy and ( @@ -404,9 +402,8 @@ def wrapped(*args, **kwargs): dump_args_config=dump_args_config, dump_ret=dump_ret, dump_ret_config=dump_ret_config, - handle_proxy=handle_proxy, - trigger_proxy_state_dump=trigger_proxy_state_dump, - proxy_state_dump_config=proxy_state_dump_config, + trigger_var_dump=trigger_var_dump, + var_dump_config=var_dump_config, need_unproxy_args_kwargs=need_unproxy_args_kwargs, *args, **kwargs, @@ -414,15 +411,13 @@ def wrapped(*args, **kwargs): else: METRIC_INSTRUMENTED_FUNC_LIST["no_dump"].append(original_function_name) - if handle_proxy: + if need_unproxy_args_kwargs: @functools.wraps(original_function) def wrapped(*args, **kwargs): if increment_step: META_VARS["step"] += 1 - return core_wrapper_proxy( - original_function, is_builtin, handle_proxy, *args, **kwargs - ) + return core_wrapper_proxy(original_function, *args, **kwargs) else: if increment_step: @@ -817,7 +812,7 @@ def get_wrapped_function(self, func_obj: Callable) -> Callable: if self.instr_opts is not None else config.MODEL_TRACKER_STYLE ) - used_proxy = tracker_style == "proxy" + used_proxy = tracker_style == "proxy" # TODO: refactor this: if self.instr_opts is None: # inference stage instrumentation return wrapper( @@ -861,9 +856,9 @@ def get_wrapped_function(self, func_obj: Callable) -> Callable: else None ), handle_proxy=used_proxy, - trigger_proxy_state_dump=self.instr_opts.disable_proxy_dumping + trigger_var_dump=self.instr_opts.disable_proxy_dumping and len(func_instr_opt["var_types_to_track"]) > 0, - proxy_state_dump_config=func_instr_opt["var_types_to_track"], + var_dump_config=func_instr_opt["var_types_to_track"], ) def _instrument_module( From 060872d232428fa769fe21643564b126e8d6aa08 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Wed, 7 Jan 2026 14:10:57 -0500 Subject: [PATCH 05/15] fix instrumentation logic to get the parent class of a method definition --- traincheck/instrumentor/tracer.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/traincheck/instrumentor/tracer.py b/traincheck/instrumentor/tracer.py index d3204636..8783e7a0 100644 --- a/traincheck/instrumentor/tracer.py +++ b/traincheck/instrumentor/tracer.py @@ -83,6 +83,23 @@ def get_meta_vars() -> dict: return META_VARS +def get_owner_class(func): + # Works for unbound functions defined on a class. + qualname = getattr(func, "__qualname__", "") + if "." 
not in qualname:
+        return None  # not a class method
+    owner_path = qualname.rsplit(".", 1)[0]  # e.g., "Optimizer"
+    mod = inspect.getmodule(func)
+    if mod is None:
+        mod = importlib.import_module(func.__module__)
+    owner = mod
+    for part in owner_path.split("."):
+        owner = getattr(owner, part, None)
+        if owner is None:
+            return None
+    return owner
+
+
 def to_dict_args_kwargs(args, kwargs, dump_args_config=None) -> dict:
     global DISABLE_WRAPPER
     DISABLE_WRAPPER = True
@@ -379,11 +396,10 @@ def wrapper(
     )
     original_function_name = typename(original_function)
     increment_step = False
-    if original_function_name.endswith(".step") and isinstance(
-        original_function.__self__, torch.optim.Optimizer
-    ):
-        increment_step = True
-
+    if original_function_name.endswith(".step"):
+        owner = get_owner_class(original_function)
+        if isinstance(owner, torch.optim.Optimizer):
+            increment_step = True
     # determine statically whether to dump the trace
     if not disable_dump:
         METRIC_INSTRUMENTED_FUNC_LIST["dump"].append(original_function_name)

From b14e0d4da5a3871843c6773cd5e003ad66d3e2e8 Mon Sep 17 00:00:00 2001
From: Yuxuan Jiang
Date: Wed, 7 Jan 2026 14:10:57 -0500
Subject: [PATCH 06/15] fix: only use positional arguments for function_wrapper

---
 traincheck/instrumentor/tracer.py | 31 ++++++++++++++++++-------------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/traincheck/instrumentor/tracer.py b/traincheck/instrumentor/tracer.py
index 8783e7a0..c8988779 100644
--- a/traincheck/instrumentor/tracer.py
+++ b/traincheck/instrumentor/tracer.py
@@ -155,7 +155,12 @@ def function_wrapper(
     *args,
     **kwargs,
 ):
-    """Instrumentation for Function"""
+    """Instrumentation for Function
+
+    When using this wrapper, pass the control parameters positionally; everything captured by *args/**kwargs is forwarded to the original function.
+    If you use keyword arguments for the control parameters, you may see errors like:
+    TypeError: function_wrapper() got multiple values for argument 'arg_name'
+    """
     global DISABLE_WRAPPER
     global PROCESS_ID
 
@@ -409,18 +414,18 @@ def wrapped(*args, **kwargs):
             if increment_step:
                 META_VARS["step"] += 1
             return function_wrapper(
-                original_function=original_function,
-                original_function_name=original_function_name,
-                is_bound_method=is_bound_method,
-                scan_proxy_in_args=scan_proxy_in_args,
-                dump_stack_trace=dump_stack_trace,
-                dump_args=dump_args,
-                dump_args_config=dump_args_config,
-                dump_ret=dump_ret,
-                dump_ret_config=dump_ret_config,
-                trigger_var_dump=trigger_var_dump,
-                var_dump_config=var_dump_config,
-                need_unproxy_args_kwargs=need_unproxy_args_kwargs,
+                original_function,
+                original_function_name,
+                is_bound_method,
+                scan_proxy_in_args,
+                dump_stack_trace,
+                dump_args,
+                dump_args_config,
+                dump_ret,
+                dump_ret_config,
+                trigger_var_dump,
+                var_dump_config,
+                need_unproxy_args_kwargs,
                 *args,
                 **kwargs,
             )

From 2034dfdabcdbfb473df9a98648db1d5d19a4e4ff Mon Sep 17 00:00:00 2001
From: Yuxuan Jiang
Date: Wed, 7 Jan 2026 17:57:16 -0500
Subject: [PATCH 07/15] fix: respect configured tracker type during selective instrumentation

---
 traincheck/collect_trace.py | 41 +++++++++++++++++++++++++++++--------
 traincheck/trace/types.py   | 16 +++++++--------
 2 files changed, 40 insertions(+), 17 deletions(-)

diff --git a/traincheck/collect_trace.py b/traincheck/collect_trace.py
index 6c4c1285..d68783be 100644
--- a/traincheck/collect_trace.py
+++ b/traincheck/collect_trace.py
@@ -137,24 +137,46 @@ def merge(a: dict, b: dict, path=[]):
     return func_instr_opts
 
 
-def get_model_tracker_instr_opts(invariants: list[Invariant]) -> str | None:
+def get_model_tracker_instr_opts(
+    invariants: list[Invariant], config_tracker_style: str
+) -> str | None:
     """
     Get model tracker instrumentation options
     """
-    tracker_type = None
+    logger = logging.getLogger(__name__)
+    need_immediate_var_tracking = False
+    need_var_tracking = False
     for inv in invariants:
         if inv.relation == APIContainRelation:
             for param in inv.params:
                 if isinstance(param, (VarNameParam, VarTypeParam)):
-                    tracker_type = "proxy"
+                    need_var_tracking = True
+                    need_immediate_var_tracking = True
                     break
 
-        if tracker_type is None and inv.relation == ConsistencyRelation:
-            tracker_type = "sampler"
+        if not need_var_tracking and inv.relation == ConsistencyRelation:
+            need_immediate_var_tracking = False
+            need_var_tracking = True
 
-        if tracker_type == "proxy":
+        if need_var_tracking and need_immediate_var_tracking:
             break
-    return tracker_type
+
+    if need_immediate_var_tracking:
+        if config_tracker_style in ["proxy", "subclass"]:
+            return config_tracker_style
+        else:
+            logger.warning(
+                f"Model tracker style {config_tracker_style} is not suitable for immediate variable tracking; falling back to 'subclass'."
+            )
+            return "subclass"
+    elif need_var_tracking:
+        if config_tracker_style != "sampler":
+            logger.warning(
+                f"Model tracker style {config_tracker_style} is not suitable for non-immediate variable tracking; falling back to 'sampler'."
+ ) + return "sampler" + + return None def dump_env(args, output_dir: str): @@ -435,9 +457,12 @@ def main(): if args.invariants: # selective instrumentation if invariants are provided, only funcs_to_instr will be instrumented with trace collection invariants = read_inv_file(args.invariants) + instr_opts = InstrOpt( func_instr_opts=get_per_func_instr_opts(invariants), - model_tracker_style=get_model_tracker_instr_opts(invariants), + model_tracker_style=get_model_tracker_instr_opts( + invariants, args.model_tracker_style + ), disable_proxy_dumping=True, ) models_to_track = ( diff --git a/traincheck/trace/types.py b/traincheck/trace/types.py index 25c61ccd..1dc1bf64 100644 --- a/traincheck/trace/types.py +++ b/traincheck/trace/types.py @@ -154,17 +154,15 @@ def __init__(self, func_name: str, pre_record: dict, post_record: dict): ) # TODO: use the Arguments class to replace self.args and self.kwargs - self.args: dict[str, dict[str, dict[str, object]]] = pre_record[ - "args" - ] # lists of [type -> attr_name -> value] - self.kwargs: dict[str, dict[str, object]] = pre_record[ - "kwargs" - ] # key --> attr_name -> value + self.args: dict[str, dict[str, dict[str, object]]] = pre_record.get( + "args", {} + ) # lists of [type -> attr_name -> value] + self.kwargs: dict[str, dict[str, object]] = pre_record.get("kwargs", {}) self.return_values: ( dict[str, dict[str, object]] | list[dict[str, dict[str, object]]] - ) = post_record[ - "return_values" - ] # key --> attr_name -> value + ) = post_record.get( + "return_values", {} + ) # key --> attr_name -> value def __str__(self): return f"FuncCallEvent: {self.func_name}" From ac5202fbaed5ba0c061eb179ed0b1c4a8195a5ec Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Wed, 7 Jan 2026 17:58:52 -0500 Subject: [PATCH 08/15] fix: unify registry implementation for proxy and subclass --- .../proxy_wrapper/proxy_registry.py | 32 +++++++++---------- .../instrumentor/proxy_wrapper/subclass.py | 11 +++++-- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/traincheck/instrumentor/proxy_wrapper/proxy_registry.py b/traincheck/instrumentor/proxy_wrapper/proxy_registry.py index b8fbde11..abadf420 100644 --- a/traincheck/instrumentor/proxy_wrapper/proxy_registry.py +++ b/traincheck/instrumentor/proxy_wrapper/proxy_registry.py @@ -1,17 +1,13 @@ import threading -import typing - -from traincheck.utils import typename - -if typing.TYPE_CHECKING: - from .proxy import Proxy class RegistryEntry: - """A class to store the proxy object and its associated metadata""" + """A class to store the tracked object and its associated metadata""" - def __init__(self, proxy: "Proxy", stale: bool): - self.proxy = proxy + def __init__(self, obj, var_name, var_type, stale): + self.var = obj + self.var_name = var_name + self.var_type = var_type self.stale = stale @@ -30,14 +26,18 @@ def __init__(self): self.registry: dict[str, RegistryEntry] = {} self.registry_lock = threading.Lock() - def add_var(self, var: "Proxy", var_name: str): + def add_var(self, var, var_name: str, var_type: str): """Add a new proxy variable to the registry""" with self.registry_lock: if var_name in self.registry: - self.registry[var_name].proxy = var + self.registry[var_name].var = var + self.registry[var_name].var_name = var_name + self.registry[var_name].var_type = var_type self.registry[var_name].stale = False else: - self.registry[var_name] = RegistryEntry(proxy=var, stale=False) + self.registry[var_name] = RegistryEntry( + var, var_name, var_type, stale=False + ) def dump_sample(self, dump_loc=None): """A 
complete dump of all present proxy objects @@ -48,7 +48,7 @@ def dump_sample(self, dump_loc=None): with self.registry_lock: for _, entry in self.registry.items(): entry.stale = True - entry.proxy.dump_trace(phase="sample", dump_loc=dump_loc) + entry.var.dump_trace(phase="sample", dump_loc=dump_loc) def dump_modified(self, dump_loc=None, dump_config=None): """Dump only the proxy variables that might be modified since last dump @@ -73,8 +73,8 @@ def dump_modified(self, dump_loc=None, dump_config=None): """ to_dump_types = set(dump_config.keys()) with self.registry_lock: - for var_name, entry in self.registry.items(): - var_type = typename(entry.proxy._obj, is_runtime=True) + for _, entry in self.registry.items(): + var_type = entry.var_type if var_type not in to_dump_types: continue @@ -82,7 +82,7 @@ def dump_modified(self, dump_loc=None, dump_config=None): continue entry.stale = True - entry.proxy.dump_trace(phase="selective-sample", dump_loc=dump_loc) + entry.var.dump_trace(phase="selective-sample", dump_loc=dump_loc) if not dump_config[var_type]["dump_unchanged"]: # remove the var from to_dump_types so that we don't dump the same type twice to_dump_types.remove(var_type) diff --git a/traincheck/instrumentor/proxy_wrapper/subclass.py b/traincheck/instrumentor/proxy_wrapper/subclass.py index acb64a18..f1b9dada 100644 --- a/traincheck/instrumentor/proxy_wrapper/subclass.py +++ b/traincheck/instrumentor/proxy_wrapper/subclass.py @@ -9,7 +9,7 @@ from traincheck.instrumentor.dumper import dump_trace_VAR from traincheck.instrumentor.proxy_wrapper.dumper import dump_attributes, get_meta_vars from traincheck.instrumentor.tracer import TraceLineType -from traincheck.utils import get_timestamp_ns +from traincheck.utils import get_timestamp_ns, typename from .proxy_basics import is_fake_tensor from .proxy_registry import get_global_registry @@ -37,6 +37,7 @@ def __new__( # TODO # recurse=False, var_name="", + var_type="", should_dump_trace=True, from_call=False, from_iter=False, @@ -97,6 +98,7 @@ def __init__( # TODO # recurse=False, var_name="", + var_type="", should_dump_trace=True, from_call=False, from_iter=False, @@ -116,6 +118,7 @@ def __init__( # TODO # self.__dict__["recurse"] = recurse self.__dict__["var_name"] = var_name + self.__dict__["var_type"] = var_type # TODO # self.__dict__["old_value"] = None # self.__dict__["old_meta_vars"] = None @@ -165,7 +168,9 @@ def update_timestamp(self): # Proxy.var_dict[self.__dict__["var_name"]].last_update_timestamp = current_time def register_object(self): - get_global_registry().add_var(self, self.__dict__["var_name"]) + get_global_registry().add_var( + self, self.__dict__["var_name"], self.__dict__["var_type"] + ) def dump_trace(self, phase, dump_loc): # print(f"parameter: {self.var_name}, phase = {phase}, dump_loc = {dump_loc}") @@ -216,11 +221,13 @@ def proxy_parameter( if in_dynamo(): return for name, t in list(module.named_parameters(recurse=False)): + var_type = typename(t, is_runtime=True) module._parameters[name] = ProxyParameter( t, logdir, log_level, parent_name + "." 
+ name, + var_type, should_dump_trace, from_call, from_iter, From 33832757f7e2d129b0e775ace68769549c31162a Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Thu, 8 Jan 2026 20:43:49 -0500 Subject: [PATCH 09/15] fix: step incrementing logic --- traincheck/instrumentor/tracer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/traincheck/instrumentor/tracer.py b/traincheck/instrumentor/tracer.py index c8988779..86280e2d 100644 --- a/traincheck/instrumentor/tracer.py +++ b/traincheck/instrumentor/tracer.py @@ -403,7 +403,7 @@ def wrapper( increment_step = False if original_function_name.endswith(".step"): owner = get_owner_class(original_function) - if isinstance(owner, torch.optim.Optimizer): + if issubclass(owner, torch.optim.Optimizer): increment_step = True # determine statically whether to dump the trace if not disable_dump: @@ -441,9 +441,12 @@ def wrapped(*args, **kwargs): return core_wrapper_proxy(original_function, *args, **kwargs) else: - if increment_step: - META_VARS["step"] += 1 - return original_function + + @functools.wraps(original_function) + def wrapped(*args, **kwargs): + if increment_step: + META_VARS["step"] += 1 + return original_function(*args, **kwargs) wrapped._traincheck_original_function = original_function wrapped._traincheck_instrumented = True From 03137575f7a28dc9a19b76ef3455ecf06f44b1fe Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Thu, 8 Jan 2026 20:44:38 -0500 Subject: [PATCH 10/15] add: richer error msg for unchanged var check in contain relation --- traincheck/invariant/contain_relation.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/traincheck/invariant/contain_relation.py b/traincheck/invariant/contain_relation.py index 7e81aee8..0c4d0e95 100644 --- a/traincheck/invariant/contain_relation.py +++ b/traincheck/invariant/contain_relation.py @@ -1102,6 +1102,9 @@ def static_check_all( # precondition passed inv_triggered = True + logger.info( + f"Performing unchanged var check ({skip_var_unchanged_check}) for the invariant: {inv.text_description} for the parent function: {parent_func_name} at {parent_func_call_id}" + ) if not skip_var_unchanged_check: assert isinstance( child_param, VarTypeParam @@ -1114,23 +1117,25 @@ def static_check_all( len(unchanged_var_ids) > 0 ), f"Internal error: can_func_be_bound_method returned True but no unchanged vars found for the parent function: {parent_func_name} at {parent_func_call_id}: {parent_pre_record['time']} at {trace.get_time_precentage(parent_pre_record['time'])}" # get the var change events for the unchanged vars - unchanged_var_states = [ - trace.get_var_raw_event_before_time( - var_id, parent_pre_record["time"] + unchanged_vars = [ + ( + var_id, + trace.get_var_raw_event_before_time( + var_id, parent_pre_record["time"] + ), ) for var_id in unchanged_var_ids ] - for unchanged_var_state in unchanged_var_states: + for var_id, unchanged_var_state in unchanged_vars: # verify that no precondition is met for the unchanged vars # MARK: precondition 2 if not preconditions.verify( unchanged_var_state, VAR_GROUP_NAME, trace ): logger.error( - f"INV CHECK ERROR: Precondition met for the unchanged vars for the parent function: {parent_func_name} at {parent_func_call_id}: {parent_pre_record['time']} at {trace.get_time_precentage(parent_pre_record['time'])}" + f"INV CHECK ERROR: Precondition met for the unchanged vars {var_id} for the parent function: {parent_func_name} at {parent_func_call_id}: {parent_pre_record['time']} at 
{trace.get_time_precentage(parent_pre_record['time'])}" ) var_unchanged_check_passed = False - break if (skip_var_unchanged_check and not found_expected_child_event) or ( not skip_var_unchanged_check and not var_unchanged_check_passed From d478e2704d92fe201ba96c0b6f1853fbd81de4c8 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Thu, 8 Jan 2026 20:45:11 -0500 Subject: [PATCH 11/15] fix: subclass registry updating process --- .../instrumentor/proxy_wrapper/proxy_registry.py | 16 +++++++++++----- .../instrumentor/proxy_wrapper/subclass.py | 2 ++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/traincheck/instrumentor/proxy_wrapper/proxy_registry.py b/traincheck/instrumentor/proxy_wrapper/proxy_registry.py index abadf420..2f466260 100644 --- a/traincheck/instrumentor/proxy_wrapper/proxy_registry.py +++ b/traincheck/instrumentor/proxy_wrapper/proxy_registry.py @@ -39,14 +39,18 @@ def add_var(self, var, var_name: str, var_type: str): var, var_name, var_type, stale=False ) - def dump_sample(self, dump_loc=None): + def dump_sample(self, dump_loc=None, dump_config=None): """A complete dump of all present proxy objects Calling this API mark all proxy objects as stale which will affect the `dump_modified` API. """ + to_dump_types = set(dump_config.keys()) with self.registry_lock: for _, entry in self.registry.items(): + var_type = entry.var_type + if var_type not in to_dump_types: + continue entry.stale = True entry.var.dump_trace(phase="sample", dump_loc=dump_loc) @@ -71,21 +75,23 @@ def dump_modified(self, dump_loc=None, dump_config=None): when calling the function, all dumped proxy vars will be marked as stale and will not be dumped next time unless there are new modification attempts to t """ + print("\nDumping from", dump_loc) to_dump_types = set(dump_config.keys()) with self.registry_lock: - for _, entry in self.registry.items(): + for var_name, entry in self.registry.items(): + print(f"var_name: {var_name}") var_type = entry.var_type if var_type not in to_dump_types: + print(" Skipping variable type:", var_type) continue if entry.stale: + print(" Skipping stale variable.") continue entry.stale = True entry.var.dump_trace(phase="selective-sample", dump_loc=dump_loc) - if not dump_config[var_type]["dump_unchanged"]: - # remove the var from to_dump_types so that we don't dump the same type twice - to_dump_types.remove(var_type) + print("Done dumping modified variables.") # Global dictionary to store registered objects diff --git a/traincheck/instrumentor/proxy_wrapper/subclass.py b/traincheck/instrumentor/proxy_wrapper/subclass.py index f1b9dada..c7292592 100644 --- a/traincheck/instrumentor/proxy_wrapper/subclass.py +++ b/traincheck/instrumentor/proxy_wrapper/subclass.py @@ -126,6 +126,7 @@ def __init__( current_time = get_timestamp_ns() self.__dict__["last_update_timestamp"] = current_time + self.register_object() # print(f"init: {self.var_name}") if should_dump_trace and not should_disable_proxy_dumping(): @@ -143,6 +144,7 @@ def __setattr__(self, name, value): # print(f"paremeter: {self.var_name}, name = {name}, value = {value}") super().__setattr__(name, value) self.update_timestamp() + self.register_object() if should_disable_proxy_dumping(): return self.dump_trace( From 6a9f694d8d3b0926c30f6f2836d231d408e96119 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Fri, 9 Jan 2026 23:22:10 -0500 Subject: [PATCH 12/15] fix: remove unproxy scanning for subclass to further reduce overhead --- traincheck/instrumentor/source_file.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git 
a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py index 7eb4baad..4febb6d8 100644 --- a/traincheck/instrumentor/source_file.py +++ b/traincheck/instrumentor/source_file.py @@ -863,6 +863,12 @@ def instrument_file( "subclass", ], f"Invalid model tracker style: {model_tracker_style}, must be one of ['proxy', 'sampler', 'subclass']" if model_tracker_style == "proxy" or model_tracker_style == "subclass": + if model_tracker_style == "subclass": + # adjust the proxy config to disable the proxy-specific configs + print( + "Using subclass model tracker, overriding observe_then_unproxy to False" + ) + adjusted_proxy_config[0]["observe_then_unproxy"] = False instrumented_source = instrument_model_tracker_proxy( instrumented_source, models_to_track, From 491ddb4da07e5a1edb2318e2084c17b6a79feb25 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Fri, 9 Jan 2026 23:23:45 -0500 Subject: [PATCH 13/15] add: refined logging for observer and registry --- traincheck/instrumentor/proxy_wrapper/proxy_observer.py | 7 +++++++ traincheck/instrumentor/proxy_wrapper/proxy_registry.py | 7 +------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/traincheck/instrumentor/proxy_wrapper/proxy_observer.py b/traincheck/instrumentor/proxy_wrapper/proxy_observer.py index a333000a..94116aa4 100644 --- a/traincheck/instrumentor/proxy_wrapper/proxy_observer.py +++ b/traincheck/instrumentor/proxy_wrapper/proxy_observer.py @@ -9,8 +9,12 @@ from traincheck.instrumentor.proxy_wrapper.proxy import Proxy from traincheck.instrumentor.proxy_wrapper.subclass import ProxyParameter +import logging + from .proxy_basics import is_proxied, is_proxyparameter, unproxy_func +logger = logging.getLogger(__name__) + def observe_proxy_var( var: typing.Union["Proxy", "ProxyParameter"], @@ -22,6 +26,9 @@ def observe_proxy_var( var.update_timestamp() if phase == "post_observe": + logger.debug( + f"[ProxyObserver] Observing proxy var after {observe_api_name}: {var.__dict__['var_name']}" + ) var.register_object() if should_disable_proxy_dumping(): diff --git a/traincheck/instrumentor/proxy_wrapper/proxy_registry.py b/traincheck/instrumentor/proxy_wrapper/proxy_registry.py index 2f466260..75c464d8 100644 --- a/traincheck/instrumentor/proxy_wrapper/proxy_registry.py +++ b/traincheck/instrumentor/proxy_wrapper/proxy_registry.py @@ -75,23 +75,18 @@ def dump_modified(self, dump_loc=None, dump_config=None): when calling the function, all dumped proxy vars will be marked as stale and will not be dumped next time unless there are new modification attempts to t """ - print("\nDumping from", dump_loc) to_dump_types = set(dump_config.keys()) with self.registry_lock: - for var_name, entry in self.registry.items(): - print(f"var_name: {var_name}") + for _, entry in self.registry.items(): var_type = entry.var_type if var_type not in to_dump_types: - print(" Skipping variable type:", var_type) continue if entry.stale: - print(" Skipping stale variable.") continue entry.stale = True entry.var.dump_trace(phase="selective-sample", dump_loc=dump_loc) - print("Done dumping modified variables.") # Global dictionary to store registered objects From ab95372712a1cb690204833bf61f2e30ec51987d Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Fri, 9 Jan 2026 23:25:32 -0500 Subject: [PATCH 14/15] fix: selective dumping for the proxy class --- traincheck/instrumentor/proxy_wrapper/proxy.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/traincheck/instrumentor/proxy_wrapper/proxy.py 
b/traincheck/instrumentor/proxy_wrapper/proxy.py index 3d48bbe5..5ce745ab 100644 --- a/traincheck/instrumentor/proxy_wrapper/proxy.py +++ b/traincheck/instrumentor/proxy_wrapper/proxy.py @@ -8,17 +8,16 @@ import torch -import traincheck.config.config as general_config import traincheck.instrumentor.proxy_wrapper.proxy_config as proxy_config # HACK: cannot directly import config variables as then they would be local variables import traincheck.instrumentor.proxy_wrapper.proxy_methods as proxy_methods +from traincheck.config.config import should_disable_proxy_dumping from traincheck.instrumentor.proxy_wrapper.dumper import dump_attributes, get_meta_vars from traincheck.utils import get_timestamp_ns, typename from .dumper import json_dumper as dumper from .proxy_basics import unproxy_arg, unproxy_args_kwargs from .proxy_handler import PROXY_SUPPORT_OBJ_TYPES - -# from .proxy_registry import get_global_registry +from .proxy_registry import get_global_registry from .utils import print_debug @@ -130,7 +129,9 @@ def update_timestamp(self): Proxy.var_dict[self.__dict__["var_name"]].last_update_timestamp = current_time def register_object(self): - # get_global_registry().add_var(self, self.__dict__["var_name"]) + get_global_registry().add_var( + self, self.__dict__["var_name"], self.__dict__["var_type"] + ) # TODO: implement the registry, we will need to make sure the registerred timestamp is updated and is consistent with the timestamp in the object pass @@ -207,6 +208,7 @@ def __init__( self.__dict__["is_traincheck_proxied_obj"] = True self.__dict__["recurse"] = recurse self.__dict__["var_name"] = var_name + self.__dict__["var_type"] = typename(obj, is_runtime=True) self.__dict__["old_value"] = None self.__dict__["old_meta_vars"] = None @@ -226,6 +228,7 @@ def __init__( ] self.__dict__["recurse"] = obj.__dict__["recurse"] self.__dict__["var_name"] = obj.__dict__["var_name"] + self.__dict__["var_type"] = obj.__dict__["var_type"] self.__dict__["logdir"] = obj.__dict__["logdir"] self.__dict__["log_level"] = obj.__dict__["log_level"] self.__dict__["meta_vars"] = obj.__dict__["meta_vars"] @@ -261,7 +264,7 @@ def __init__( if not dump_iter and from_iter: return - if should_dump_trace: + if should_dump_trace and not should_disable_proxy_dumping(): if from_call: phase = "call" @@ -363,7 +366,7 @@ def __setattr__(self, name, value): ), ) - if general_config.should_disable_proxy_dumping(): + if should_disable_proxy_dumping(): # do not dump update traces return None From 53d90cab8c3f03b59aea3f98b39c689e8fd0ecce Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Fri, 9 Jan 2026 23:26:34 -0500 Subject: [PATCH 15/15] add: monkey patch __setattr__ at the module level when using subclass, to ensure submodule assignments are captured --- .../instrumentor/proxy_wrapper/subclass.py | 47 +++++++++++++++++-- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/traincheck/instrumentor/proxy_wrapper/subclass.py b/traincheck/instrumentor/proxy_wrapper/subclass.py index c7292592..5f40f796 100644 --- a/traincheck/instrumentor/proxy_wrapper/subclass.py +++ b/traincheck/instrumentor/proxy_wrapper/subclass.py @@ -1,3 +1,4 @@ +import functools import logging import os import threading @@ -14,7 +15,9 @@ from .proxy_basics import is_fake_tensor from .proxy_registry import get_global_registry -# from .utils import print_debug +SUBCLASS_HOOK_KEY = "_tc_setattr_hook" + +logger = logging.getLogger(__name__) def in_dynamo() -> bool: @@ -126,9 +129,9 @@ def __init__( current_time = get_timestamp_ns() 
self.__dict__["last_update_timestamp"] = current_time + logger.debug(f"[ProxyParameter] Created ProxyParameter: {self.var_name}") self.register_object() - # print(f"init: {self.var_name}") if should_dump_trace and not should_disable_proxy_dumping(): if from_call: phase = "call" @@ -141,7 +144,7 @@ def __init__( self.dump_trace(phase=phase, dump_loc="initing") def __setattr__(self, name, value): - # print(f"paremeter: {self.var_name}, name = {name}, value = {value}") + super().__setattr__(name, value) self.update_timestamp() self.register_object() @@ -175,7 +178,6 @@ def register_object(self): ) def dump_trace(self, phase, dump_loc): - # print(f"parameter: {self.var_name}, phase = {phase}, dump_loc = {dump_loc}") # TODO var_name = self.__dict__["var_name"] # assert var_name is not None # '' is allowed as a var_name (root object) @@ -224,6 +226,9 @@ def proxy_parameter( return for name, t in list(module.named_parameters(recurse=False)): var_type = typename(t, is_runtime=True) + logger.debug( + f"[ProxyParameter] Proxying parameter: {parent_name}.{name} of type {var_type}" + ) module._parameters[name] = ProxyParameter( t, logdir, @@ -244,3 +249,37 @@ def proxy_parameter( from_call, from_iter, ) + + # we need to instrument the __setattr__ of the module to capture parameter updates + def subclass_setattr_hook(self, name, value): + logger.debug( + f"[ProxyParameter] Module __setattr__ called: {parent_name}.{name} = {type(value)}" + ) + if isinstance(value, torch.Tensor) or isinstance(value, torch.nn.Module): + proxy_parameter( + value, + logdir, + log_level, + parent_name + "." + name, + should_dump_trace, + from_call, + from_iter, + ) + + module.__dict__[SUBCLASS_HOOK_KEY] = subclass_setattr_hook + + +# instrument torch.nn.Module's setattr +orig_setattr = torch.nn.Module.__setattr__ + + +@functools.wraps(orig_setattr) +def wrapped_setattr(self, name, value): + hook = getattr(self, SUBCLASS_HOOK_KEY, None) + if hook is not None: + # If hook returns True, skip the original setattr; otherwise continue. + hook(self, name, value) + return orig_setattr(self, name, value) + + +torch.nn.Module.__setattr__ = wrapped_setattr