From e8e67a9e5966736a35ca2c2ca05c8d683d321a79 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Mon, 19 Jan 2026 09:05:27 -0800 Subject: [PATCH 1/8] Neat cleanups Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/distillation_model.py | 62 +++++++++++--------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/modelopt/torch/distill/distillation_model.py b/modelopt/torch/distill/distillation_model.py index 930b68560..ed886909f 100644 --- a/modelopt/torch/distill/distillation_model.py +++ b/modelopt/torch/distill/distillation_model.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - """Meta-model wrapper to support knowledge-distillation learning.""" import inspect @@ -45,6 +43,7 @@ def _setup(self): self._register_temp_attribute("_loss_modules", nn.ModuleList()) self._register_temp_attribute("_only_teacher_fwd", False) self._register_temp_attribute("_only_student_fwd", False) + self._register_temp_attribute("_hook_handles", set()) # HACK: set model's forward signature to match student class' original. # Needed for HF `transformers.utils.find_labels` which relies on inspecting class signature. @@ -57,13 +56,13 @@ def _setup(self): def modify( self, - teacher_model: nn.Module, # To be frozen. + teacher_model: nn.Module, criterion: dict[ tuple[ - str, # Student model layer whose output to capture. - str, # Teacher model layer whose output to capture. + str, # Student model layer whose output to capture + str, # Teacher model layer whose output to capture ], - Loss, # Loss fn. + Loss, # Loss function ], loss_balancer: DistillationLossBalancer | None = None, expose_minimal_state_dict: bool = True, @@ -71,9 +70,8 @@ def modify( """Constructor. Args: - teacher_model: A teacher model which this class would encapsulate. - criterion: A dictionary mapping the tuple of student and teacher - model layer names to the loss function to apply to that layer pair. + teacher_model: The teacher model (will be frozen). + criterion: Dictionary mapping (student_layer_name, teacher_layer_name) to loss functions. loss_balancer: Instance of :class:`DistillationLossBalancer ` which reduces distillation and non-distillation losses into a single value using some weighing scheme. @@ -106,22 +104,30 @@ def modify( {m for m in self._layers_to_loss.values() if len(list(m.parameters())) > 0} ) - # Disable grad for teacher + # Disable grad for teacher. self._teacher_model.requires_grad_(False) - # Register hooks for intermediate outputs from teacher models and the student model. - # HACK: For inexplicable reasons, sometimes a model will have hooks remain after - # `ato.restore()` so we check if they are present accidentally first. + # Use hooks to caputure relevant activation tensors for loss computation. + self._register_hooks() + + def _register_hooks(self): + """Register hooks for intermediate tensors from teacher models and the student model.""" for student_layer, teacher_layer in self._layers_to_loss: setattr(student_layer, "_intermediate_output", None) - if student_output_capture_fwd_hook not in student_layer._forward_hooks.values(): - student_layer.register_forward_hook(student_output_capture_fwd_hook) + handle_s = student_layer.register_forward_hook(student_output_capture_fwd_hook) setattr(teacher_layer, "_intermediate_output", None) - if teacher_output_capture_fwd_hook not in teacher_layer._forward_hooks.values(): - teacher_layer.register_forward_hook(teacher_output_capture_fwd_hook) + handle_t = teacher_layer.register_forward_hook(teacher_output_capture_fwd_hook) + self._hook_handles.extend([handle_s, handle_t]) + + def export(self): + """Export the distillation model.""" + for handle in self._hook_handles: + handle.remove() + self._hook_handles.clear() + return super().export() @property - def teacher_model(self) -> nn.ModuleList: + def teacher_model(self) -> nn.Module: """Fetch the teacher model.""" return self._teacher_model @@ -148,7 +154,7 @@ def hide_teacher_model(self, enable=True): @contextmanager def hide_loss_modules(self, enable=True): - """Context manager to temporarily hide teacher model from the model.""" + """Context manager to temporarily hide loss modules from the model.""" loss_modules = self._loss_modules if enable: self._loss_modules = nn.ModuleList() @@ -169,7 +175,7 @@ def only_teacher_forward(self, enable=True): @contextmanager def only_student_forward(self, enable=True): - """Context manager to temporarily disable forward passes on the student model.""" + """Context manager to temporarily run forward passes only on the student model.""" if enable: self._only_student_fwd = True try: @@ -245,15 +251,13 @@ def compute_kd_loss( Args: student_loss: Original loss computed from the student's output. - loss_reduction_fn: Callable to be called on each loss tensor prior to balancing. Useful for - loss-masking situations where the callable changes arguments each iteration. + loss_reduction_fn: Callable to be called on each loss tensor prior to balancing. + Useful for loss-masking situations where the callable changes arguments each iteration. skip_balancer: Whether or not to use loss balancer to reduce the loss dict into a scalar. **loss_fn_kwargs: Additional keyword arguments to be passed to the loss function, if needed. - This facilitates losses that require extras, such as labels for ``mtd.MFTLoss``. Returns: - If reduce is True, the scalar total loss weighted between ``student_loss`` and the distillation losses. - If reduce is False, a dict of student model output loss and layer-wise distillation losses. + A dict of losses if skip_balancer is True, else the scalar total loss. """ if self._loss_balancer is None: assert student_loss is None, "Cannot pass in student loss without using Loss Balancer." @@ -288,9 +292,9 @@ def compute_kd_loss( return loss_total -def student_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any): # pylint: disable=redefined-builtin +def student_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any): """A hook to capture layer output.""" - # NOTE: Defined externally to allow pickling. + # NOTE: Defined externally to allow pickling during DDP initialization. if getattr(module, "_only_teacher_fwd", False): return # Might be hooked on entire model fwd @@ -303,9 +307,9 @@ def student_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any): module._intermediate_output = output -def teacher_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any): # pylint: disable=redefined-builtin +def teacher_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any): """A hook to capture layer output.""" - # NOTE: Defined externally to allow pickling. + # NOTE: Defined externally to allow pickling during DDP initialization. if module._intermediate_output is not None: # NOTE: cannot tell if train or eval since teacher is always eval From 576b9a75c77cd72bde1be79d12f8a1dc4c6ae0cd Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Tue, 20 Jan 2026 10:13:27 -0800 Subject: [PATCH 2/8] New bypass mode and model Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/__init__.py | 1 + .../distill/bypass_distillation_model.py | 82 +++++++++++++++++++ modelopt/torch/distill/config.py | 19 ++++- modelopt/torch/distill/distillation_model.py | 2 +- modelopt/torch/distill/mode.py | 37 ++++++++- 5 files changed, 136 insertions(+), 5 deletions(-) create mode 100644 modelopt/torch/distill/bypass_distillation_model.py diff --git a/modelopt/torch/distill/__init__.py b/modelopt/torch/distill/__init__.py index a09aa6b8e..49b314522 100644 --- a/modelopt/torch/distill/__init__.py +++ b/modelopt/torch/distill/__init__.py @@ -16,6 +16,7 @@ """Distillation API subpackage for torch.""" from . import mode +from .bypass_distillation_model import * from .config import * from .distillation import * from .distillation_model import * diff --git a/modelopt/torch/distill/bypass_distillation_model.py b/modelopt/torch/distill/bypass_distillation_model.py new file mode 100644 index 000000000..ee9e3bbf1 --- /dev/null +++ b/modelopt/torch/distill/bypass_distillation_model.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Meta-model wrapper to support bypass-enabled knowledge-distillation learning.""" + +import warnings +from typing import Any + +import torch.nn as nn + +from .distillation_model import DistillationModel, student_output_capture_fwd_hook + +__all__ = ["BypassDistillationModel"] + + +class BypassDistillationModel(DistillationModel): + """Meta-model wrapper to support bypass-enabled knowledge-distillation learning.""" + + def _register_hooks(self): + """Register hooks for intermediate tensors from teacher models and the student model.""" + for student_layer, teacher_layer in self._layers_to_loss: + setattr(student_layer, "_teacher_layer", [teacher_layer]) + handle_s1 = student_layer.register_forward_pre_hook(student_input_bypass_fwd_hook) + setattr(student_layer, "_intermediate_output", None) + handle_s2 = student_layer.register_forward_hook(student_output_capture_fwd_hook) + setattr(teacher_layer, "_intermediate_input", None) + setattr(teacher_layer, "_intermediate_output", None) + handle_t = teacher_layer.register_forward_hook(teacher_input_output_capture_fwd_hook) + self._hook_handles.update([handle_s1, handle_s2, handle_t]) + + def export(self): + """Export the distillation model.""" + for student_layer, _ in self._layers_to_loss: + delattr(student_layer, "_teacher_layer") + return super().export() + + +def student_input_bypass_fwd_hook(module: nn.Module, input: Any): + """A hook to inject teacher input into corresponding student layer.""" + # NOTE: Defined externally to allow pickling during DDP initialization. + + if getattr(module, "_only_teacher_fwd", False): + return input # Might be hooked on entire model fwd + + teacher_layer = module._teacher_layer[0] + teacher_input = teacher_layer._intermediate_input + if teacher_input is None: + warnings.warn( + f"Teacher's Module `{type(teacher_layer).__name__}` has no intermediate input stored." + " This is expected when the `only_student_forward` context manager is in use." + ) + return input + + teacher_layer._intermediate_input = None # reset + return teacher_input + + +def teacher_input_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any): + """A hook to capture layer input and output.""" + # NOTE: Defined externally to allow pickling during DDP initialization. + + if module._intermediate_output is not None: + # NOTE: cannot tell if train or eval since teacher is always eval + warnings.warn( + f"Teacher's Module `{type(module).__name__}` already has an intermediate output stored." + " This is expected when `DistillationModel.compute_kd_loss` is not called in eval mode." + ) + + module._intermediate_input = input + module._intermediate_output = output diff --git a/modelopt/torch/distill/config.py b/modelopt/torch/distill/config.py index cfdb3ccb6..f8ef5a883 100644 --- a/modelopt/torch/distill/config.py +++ b/modelopt/torch/distill/config.py @@ -26,7 +26,7 @@ from .loss_balancers import DistillationLossBalancer -__all__ = ["KDLossConfig"] +__all__ = ["BypassKDConfig", "ExportStudentConfig", "KDLossConfig"] Criterion = Union[Loss, dict[tuple[str, str], Loss]] # noqa: UP007 @@ -120,6 +120,23 @@ def _strict_validate(self) -> None: ) +class BypassKDConfig(KDLossConfig): + """Configuration for the Bypass Knowledge-Distillation mode. + + This mode is used to distill knowledge from a teacher model to a student model using bypassing. + """ + + @pydantic.field_validator("criterion") + @classmethod + def format_criterion(cls, criterion: Criterion | None) -> dict[tuple[str, str], Loss]: + """Ensure criterion is a mapping from layer names to loss (potentially entire module).""" + if not isinstance(criterion, dict): + raise ValueError("Bypass Distillation mode requires explicit criterion pairs.") + if any(key == ("", "") for key in criterion): + raise ValueError("Bypass Distillation mode does not support output-only distillation.") + return criterion + + class ExportStudentConfig(ModeloptBaseConfig): """Configuration for the export_student mode. diff --git a/modelopt/torch/distill/distillation_model.py b/modelopt/torch/distill/distillation_model.py index ed886909f..fa344385a 100644 --- a/modelopt/torch/distill/distillation_model.py +++ b/modelopt/torch/distill/distillation_model.py @@ -117,7 +117,7 @@ def _register_hooks(self): handle_s = student_layer.register_forward_hook(student_output_capture_fwd_hook) setattr(teacher_layer, "_intermediate_output", None) handle_t = teacher_layer.register_forward_hook(teacher_output_capture_fwd_hook) - self._hook_handles.extend([handle_s, handle_t]) + self._hook_handles.update([handle_s, handle_t]) def export(self): """Export the distillation model.""" diff --git a/modelopt/torch/distill/mode.py b/modelopt/torch/distill/mode.py index 75ea751f4..6840b91c8 100644 --- a/modelopt/torch/distill/mode.py +++ b/modelopt/torch/distill/mode.py @@ -36,7 +36,8 @@ ) from modelopt.torch.utils import init_model_from_model_like, unwrap_model -from .config import ExportStudentConfig, KDLossConfig +from .bypass_distillation_model import BypassDistillationModel +from .config import BypassKDConfig, ExportStudentConfig, KDLossConfig from .distillation_model import DistillationModel from .registry import DistillationDMRegistry @@ -86,6 +87,29 @@ def save_mode_in_state(self) -> bool: return False +@DistillModeRegistry.register_mode +class BypassKDModeDescriptor(KnowledgeDistillationModeDescriptor): + """Class to describe the Bypass Knowledge-Distillation mode. + + The properties of this mode can be inspected via the source code. + """ + + @property + def name(self) -> str: + """Returns the value (str representation) of the mode.""" + return "bypass_kd" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Specifies the config class for the mode.""" + return BypassKDConfig + + @property + def convert(self) -> ConvertEntrypoint: + """The mode's entrypoint for converting a model.""" + return _convert_for_bypass + + @DistillModeRegistry.register_mode class ExportStudentModeDescriptor(ModeDescriptor): """Class to describe the specific Export mode to be used with Knowledge Distillation. @@ -124,7 +148,9 @@ def save_mode_in_state(self) -> bool: return False -def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType: +def _convert_for_kd( + model: nn.Module, config: KDLossConfig, model_cls: type[nn.Module] = DistillationModel +) -> ConvertReturnType: """Function for converting a model to a distillation meta-model. This is the only utility needed to use the ``modelopt.torch.distill`` API directly. @@ -159,7 +185,7 @@ def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType # initialize distillation model original_cls = type(student) if original_cls not in DistillationDMRegistry: - DistillationDMRegistry.register({original_cls: "student_class"})(DistillationModel) + DistillationDMRegistry.register({original_cls: "student_class"})(model_cls) # TODO (lucasl): look into ways to avoid registering every class manually # (e.g. by just registering nn.Module and disable the "forward" check for the inherited class check @@ -174,6 +200,11 @@ def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType return distillation_model, metadata +def _convert_for_bypass(model: nn.Module, config: BypassKDConfig) -> ConvertReturnType: + """Function for converting a model to a bypass distillation meta-model.""" + return _convert_for_kd(model, config, model_cls=BypassDistillationModel) + + def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict): """Function for resetting the state's config.""" config.teacher_model = nn.Module From 2173ec7b3f246e9025a815863adbcf2c89bc6b3c Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Wed, 21 Jan 2026 06:55:43 -0800 Subject: [PATCH 3/8] Add tests and fix bugs Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/mode.py | 15 -- tests/unit/torch/distill/test_bypass.py | 226 +++++++++++++++++++++++ tests/unit/torch/distill/test_distill.py | 10 +- 3 files changed, 227 insertions(+), 24 deletions(-) create mode 100644 tests/unit/torch/distill/test_bypass.py diff --git a/modelopt/torch/distill/mode.py b/modelopt/torch/distill/mode.py index 6840b91c8..ae36094ab 100644 --- a/modelopt/torch/distill/mode.py +++ b/modelopt/torch/distill/mode.py @@ -21,17 +21,14 @@ import warnings import torch.nn as nn -from torch.nn.modules.loss import _Loss as Loss from modelopt.torch.opt.config import ModeloptBaseConfig from modelopt.torch.opt.conversion import ModeloptStateManager from modelopt.torch.opt.mode import ( ConvertEntrypoint, ConvertReturnType, - MetadataDict, ModeDescriptor, RestoreEntrypoint, - UpdateEntrypoint, _ModeRegistryCls, ) from modelopt.torch.utils import init_model_from_model_like, unwrap_model @@ -76,11 +73,6 @@ def restore(self) -> RestoreEntrypoint: """The mode's entrypoint for restoring a model.""" raise NotImplementedError(f"{self.name} mode does not support restore.") - @property - def update_for_new_mode(self) -> UpdateEntrypoint: - """The mode's entrypoint for updating the models state for adding new mode.""" - return _reset_kd_state_config - @property def save_mode_in_state(self) -> bool: """Whether the mode should be saved into the modelopt state.""" @@ -205,13 +197,6 @@ def _convert_for_bypass(model: nn.Module, config: BypassKDConfig) -> ConvertRetu return _convert_for_kd(model, config, model_cls=BypassDistillationModel) -def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict): - """Function for resetting the state's config.""" - config.teacher_model = nn.Module - config.criterion = Loss() - config.loss_balancer = None - - def _export_student(model: nn.Module, config: ExportStudentConfig) -> ConvertReturnType: """Export a ``DistillationModel`` to its inner student model, including modelopt state transfer.""" if not isinstance(model, DistillationModel): diff --git a/tests/unit/torch/distill/test_bypass.py b/tests/unit/torch/distill/test_bypass.py new file mode 100644 index 000000000..57de35667 --- /dev/null +++ b/tests/unit/torch/distill/test_bypass.py @@ -0,0 +1,226 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +import pytest +import torch +from _test_utils.torch.vision_models import get_tiny_mobilenet_and_input + +import modelopt.torch.distill as mtd +import modelopt.torch.opt as mto + + +def get_input_tensor(): + """Dummy input tensor.""" + return torch.rand(2, 3, 112, 112) + + +def tiny_mobilenet(): + return get_tiny_mobilenet_and_input()[0] + + +@pytest.fixture +def bypass_distillation_model(): + student = tiny_mobilenet().train() + config = { + "teacher_model": tiny_mobilenet(), + "criterion": { + ("features.2", "features.2"): torch.nn.MSELoss(), + }, + "loss_balancer": mtd.StaticLossBalancer(), + } + bypass_model = mtd.convert(student, mode=[("bypass_kd", config)]) + + return bypass_model + + +def test_bypass_hooks_registration(bypass_distillation_model): + """Test that bypass-specific hooks are registered correctly.""" + # Check that student layers have _teacher_layer attribute + for student_layer, teacher_layer in bypass_distillation_model._layers_to_loss: + assert hasattr(student_layer, "_teacher_layer") + assert student_layer._teacher_layer[0] is teacher_layer + assert hasattr(student_layer, "_intermediate_output") + + # Check that teacher layers have both input and output capture attributes + assert hasattr(teacher_layer, "_intermediate_input") + assert hasattr(teacher_layer, "_intermediate_output") + + +def test_bypass_forward_pass(bypass_distillation_model): + """Test that forward pass works and captures both teacher inputs and outputs.""" + bypass_distillation_model.train() + input_tensor = get_input_tensor() + + bypass_distillation_model(input_tensor) + + # Check that teacher intermediate inputs and outputs are captured + for student_layer, teacher_layer in bypass_distillation_model._layers_to_loss: + assert teacher_layer._intermediate_input is None + assert teacher_layer._intermediate_output is not None + assert student_layer._intermediate_output is not None + + +def test_bypass_input_injection(bypass_distillation_model): + """Test that teacher inputs are injected into student layers during bypass.""" + bypass_distillation_model.train() + input_tensor = get_input_tensor() + + # Perform forward pass + bypass_distillation_model(input_tensor) + + # Verify that teacher inputs were captured (they should be reset after injection) + # After forward, teacher inputs should have been consumed by student layers + for student_layer, teacher_layer in bypass_distillation_model._layers_to_loss: + # After full forward pass, teacher_layer._intermediate_input should be None + # because it gets consumed by the student bypass hook + assert teacher_layer._intermediate_input is None + + +def test_bypass_loss_computation(bypass_distillation_model): + """Test that loss computation works with bypass distillation.""" + bypass_distillation_model.train() + input_tensor = get_input_tensor() + + output = bypass_distillation_model(input_tensor) + loss = bypass_distillation_model.compute_kd_loss(student_loss=output.mean()) + + assert isinstance(loss, torch.Tensor) + assert loss.numel() == 1 + assert loss.requires_grad + + +def test_bypass_only_student_forward(bypass_distillation_model): + """Test that only_student_forward context manager works with bypass.""" + bypass_distillation_model.train() + input_tensor = get_input_tensor() + + # When using only_student_forward, teacher inputs should not be captured + with warnings.catch_warnings(record=True) as w: + with bypass_distillation_model.only_student_forward(): + bypass_distillation_model(input_tensor) + + # Should get warning about missing teacher input + warning_messages = [str(warning.message) for warning in w] + assert any("has no intermediate input stored" in msg for msg in warning_messages) + + # Verify teacher didn't run + for student_layer, teacher_layer in bypass_distillation_model._layers_to_loss: + assert teacher_layer._intermediate_input is None + assert teacher_layer._intermediate_output is None + assert student_layer._intermediate_output is not None + + +def test_bypass_only_teacher_forward(bypass_distillation_model): + """Test that only_teacher_forward context manager works with bypass.""" + bypass_distillation_model.train() + input_tensor = get_input_tensor() + + with bypass_distillation_model.only_teacher_forward(): + bypass_distillation_model(input_tensor) + + # Verify teacher ran and student didn't + for student_layer, teacher_layer in bypass_distillation_model._layers_to_loss: + assert teacher_layer._intermediate_input is not None + assert teacher_layer._intermediate_output is not None + assert student_layer._intermediate_output is None + + +def test_bypass_export(bypass_distillation_model): + """Test that export correctly cleans up bypass-specific attributes.""" + # Check that _teacher_layer exists before export + for student_layer, _ in bypass_distillation_model._layers_to_loss: + assert hasattr(student_layer, "_teacher_layer") + + # Export the model + exported_model = mtd.export(bypass_distillation_model) + + # Check that _teacher_layer is removed after export + for student_layer in exported_model.modules(): + assert not hasattr(student_layer, "_teacher_layer") + + assert not hasattr(exported_model, "_teacher_model") + assert not isinstance(exported_model, mtd.BypassDistillationModel) + + +def test_bypass_save_restore(bypass_distillation_model, tmp_path): + """Test that save/restore works correctly with bypass distillation.""" + mto.save(bypass_distillation_model, tmp_path / "ckpt.pt") + + new_student = tiny_mobilenet() + restored_model = mto.restore(new_student, tmp_path / "ckpt.pt") + + # Ensure state is not actually restored (expected behavior from test_distill.py) + manager = mto.ModeloptStateManager(restored_model) + assert not manager.has_state + assert isinstance(restored_model, type(new_student)) + + +def test_bypass_multiloss(): + """Test bypass distillation with multiple loss functions.""" + student = tiny_mobilenet().train() + config = { + "teacher_model": tiny_mobilenet(), + "criterion": { + ("features.1", "features.1"): torch.nn.MSELoss(), + ("features.3", "features.3"): torch.nn.MSELoss(), + }, + "loss_balancer": mtd.StaticLossBalancer([0.5, 0.5]), + } + bypass_model = mtd.convert(student, mode=[("bypass_kd", config)]) + + # Verify hooks are registered for all layers + assert len(bypass_model._layers_to_loss) == 2 + + # Test forward pass + output = bypass_model(get_input_tensor()) + loss = bypass_model.compute_kd_loss(student_loss=output.mean()) + + assert isinstance(loss, torch.Tensor) + assert loss.numel() == 1 + + +def test_bypass_gradient_flow(): + """Test that gradients flow correctly through bypass distillation.""" + student = tiny_mobilenet().train() + config = { + "teacher_model": tiny_mobilenet(), + "criterion": { + ("features.2", "features.2"): torch.nn.MSELoss(), + }, + "loss_balancer": None, + } + bypass_model = mtd.convert(student, mode=[("bypass_kd", config)]) + + optimizer = torch.optim.SGD(bypass_model.parameters(), lr=0.01) + + # Get initial parameter values + initial_params = [p.clone() for p in bypass_model.parameters() if p.requires_grad] + + # Training step + optimizer.zero_grad() + bypass_model(get_input_tensor()) + loss = bypass_model.compute_kd_loss() + loss.backward() + optimizer.step() + + # Check that at least some parameters changed + current_params = [p for p in bypass_model.parameters() if p.requires_grad] + param_changed = any( + not torch.allclose(initial, current) + for initial, current in zip(initial_params, current_params) + ) + assert param_changed, "No parameters were updated during training" diff --git a/tests/unit/torch/distill/test_distill.py b/tests/unit/torch/distill/test_distill.py index 10241f076..69dec86b7 100644 --- a/tests/unit/torch/distill/test_distill.py +++ b/tests/unit/torch/distill/test_distill.py @@ -20,7 +20,6 @@ import torch import torch.nn as nn from _test_utils.torch.vision_models import get_tiny_mobilenet_and_input -from torch.nn.modules.loss import _Loss as Loss from torchvision.models import alexnet import modelopt.torch.distill as mtd @@ -37,7 +36,7 @@ def tiny_mobilenet(): def tiny_alexnet(): - return alexnet(num_classes=10) # Same class as tiny_mobilenet + return alexnet(num_classes=10) # same num classes as tiny_mobilenet @pytest.fixture @@ -168,13 +167,6 @@ def test_distillation_export(distillation_model, tmp_path): assert not hasattr(model_exported, "_teacher_model") assert hasattr(model_exported, mto.ModeloptStateManager._state_key) - # Test if kd_loss config has been cleaned up - manager = mto.ModeloptStateManager(model_exported) - cfg = manager._state[-2][1]["config"] - assert cfg["teacher_model"] == nn.Module - assert isinstance(next(iter(cfg["criterion"].values())), Loss) - assert cfg["loss_balancer"] is None - mto.save(model_exported, tmp_path / "ckpt.pt") new_student = tiny_mobilenet() new_student_restored = mto.restore(new_student, tmp_path / "ckpt.pt") From 59d12e6c80f94dc0708751141288796d989cdd17 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Wed, 21 Jan 2026 08:03:52 -0800 Subject: [PATCH 4/8] Refine gradient test Signed-off-by: Asha Anoosheh --- tests/unit/torch/distill/test_bypass.py | 30 ++++++++++++++++--------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/tests/unit/torch/distill/test_bypass.py b/tests/unit/torch/distill/test_bypass.py index 57de35667..291f7733f 100644 --- a/tests/unit/torch/distill/test_bypass.py +++ b/tests/unit/torch/distill/test_bypass.py @@ -205,22 +205,30 @@ def test_bypass_gradient_flow(): } bypass_model = mtd.convert(student, mode=[("bypass_kd", config)]) - optimizer = torch.optim.SGD(bypass_model.parameters(), lr=0.01) - - # Get initial parameter values - initial_params = [p.clone() for p in bypass_model.parameters() if p.requires_grad] + # Save param snapshots by module + param_snapshots = { + name: p.clone() for name, p in bypass_model.named_parameters() if p.requires_grad + } - # Training step + # Forward and backward + optimizer = torch.optim.SGD(bypass_model.parameters(), lr=0.5) optimizer.zero_grad() bypass_model(get_input_tensor()) loss = bypass_model.compute_kd_loss() loss.backward() optimizer.step() - # Check that at least some parameters changed - current_params = [p for p in bypass_model.parameters() if p.requires_grad] - param_changed = any( - not torch.allclose(initial, current) - for initial, current in zip(initial_params, current_params) + # Check: parameters in only the target layer(s) are changed + updated_any = False + for name, param in bypass_model.named_parameters(): + if not param.requires_grad: + continue + changed = not torch.allclose(param, param_snapshots[name]) + if "features.2" in name: + assert changed, f"'{name}' parameters did not change!" + updated_any = True + else: + assert not changed, f"Parameters in unrelated layer '{name}' changed!" + assert updated_any, ( + "No parameters were updated in 'features.2' or related layers during training" ) - assert param_changed, "No parameters were updated during training" From cfd9d66c72b93f0962ba9f8e5985007a2d0d2faa Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Wed, 21 Jan 2026 08:08:57 -0800 Subject: [PATCH 5/8] Expand bypass model docstring Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/bypass_distillation_model.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/distill/bypass_distillation_model.py b/modelopt/torch/distill/bypass_distillation_model.py index ee9e3bbf1..ccf6c64a8 100644 --- a/modelopt/torch/distill/bypass_distillation_model.py +++ b/modelopt/torch/distill/bypass_distillation_model.py @@ -26,7 +26,13 @@ class BypassDistillationModel(DistillationModel): - """Meta-model wrapper to support bypass-enabled knowledge-distillation learning.""" + """Meta-model wrapper to support bypass-enabled knowledge-distillation learning. + + The BypassDistillationModel is a subclass of the DistillationModel that injects teacher inputs + into the corresponding student layers. This accomodates the case where the student model is the + teacher with specific submodules replaced, which now need to be trained to mimic the original + submodule in the teacher. + """ def _register_hooks(self): """Register hooks for intermediate tensors from teacher models and the student model.""" From 63e24c5c054a52507032ada890a9b4a8d42c5f17 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Thu, 22 Jan 2026 07:06:40 -0800 Subject: [PATCH 6/8] Rename Bypass to Layerwise Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/__init__.py | 2 +- modelopt/torch/distill/config.py | 14 ++- ...del.py => layerwise_distillation_model.py} | 10 +- modelopt/torch/distill/mode.py | 20 ++-- .../{test_bypass.py => test_layerwise.py} | 110 +++++++++--------- 5 files changed, 79 insertions(+), 77 deletions(-) rename modelopt/torch/distill/{bypass_distillation_model.py => layerwise_distillation_model.py} (90%) rename tests/unit/torch/distill/{test_bypass.py => test_layerwise.py} (63%) diff --git a/modelopt/torch/distill/__init__.py b/modelopt/torch/distill/__init__.py index 49b314522..dad15dcc6 100644 --- a/modelopt/torch/distill/__init__.py +++ b/modelopt/torch/distill/__init__.py @@ -16,10 +16,10 @@ """Distillation API subpackage for torch.""" from . import mode -from .bypass_distillation_model import * from .config import * from .distillation import * from .distillation_model import * +from .layerwise_distillation_model import * from .loss_balancers import * from .losses import * from .registry import * diff --git a/modelopt/torch/distill/config.py b/modelopt/torch/distill/config.py index f8ef5a883..74ef15300 100644 --- a/modelopt/torch/distill/config.py +++ b/modelopt/torch/distill/config.py @@ -26,7 +26,7 @@ from .loss_balancers import DistillationLossBalancer -__all__ = ["BypassKDConfig", "ExportStudentConfig", "KDLossConfig"] +__all__ = ["ExportStudentConfig", "KDLossConfig", "LayerwiseKDConfig"] Criterion = Union[Loss, dict[tuple[str, str], Loss]] # noqa: UP007 @@ -120,10 +120,10 @@ def _strict_validate(self) -> None: ) -class BypassKDConfig(KDLossConfig): - """Configuration for the Bypass Knowledge-Distillation mode. +class LayerwiseKDConfig(KDLossConfig): + """Configuration for the Layerwise Knowledge-Distillation mode. - This mode is used to distill knowledge from a teacher model to a student model using bypassing. + This mode is used to distill knowledge from a teacher model to a student model using layerwise distillation. """ @pydantic.field_validator("criterion") @@ -131,9 +131,11 @@ class BypassKDConfig(KDLossConfig): def format_criterion(cls, criterion: Criterion | None) -> dict[tuple[str, str], Loss]: """Ensure criterion is a mapping from layer names to loss (potentially entire module).""" if not isinstance(criterion, dict): - raise ValueError("Bypass Distillation mode requires explicit criterion pairs.") + raise ValueError("Layerwise Distillation mode requires explicit criterion pairs.") if any(key == ("", "") for key in criterion): - raise ValueError("Bypass Distillation mode does not support output-only distillation.") + raise ValueError( + "Layerwise Distillation mode does not support output-only distillation." + ) return criterion diff --git a/modelopt/torch/distill/bypass_distillation_model.py b/modelopt/torch/distill/layerwise_distillation_model.py similarity index 90% rename from modelopt/torch/distill/bypass_distillation_model.py rename to modelopt/torch/distill/layerwise_distillation_model.py index ccf6c64a8..9e91a3d85 100644 --- a/modelopt/torch/distill/bypass_distillation_model.py +++ b/modelopt/torch/distill/layerwise_distillation_model.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Meta-model wrapper to support bypass-enabled knowledge-distillation learning.""" +"""Meta-model wrapper to support layerwise-enabled knowledge-distillation learning.""" import warnings from typing import Any @@ -22,13 +22,13 @@ from .distillation_model import DistillationModel, student_output_capture_fwd_hook -__all__ = ["BypassDistillationModel"] +__all__ = ["LayerwiseDistillationModel"] -class BypassDistillationModel(DistillationModel): - """Meta-model wrapper to support bypass-enabled knowledge-distillation learning. +class LayerwiseDistillationModel(DistillationModel): + """Meta-model wrapper to support layerwise-enabled knowledge-distillation learning. - The BypassDistillationModel is a subclass of the DistillationModel that injects teacher inputs + The LayerwiseDistillationModel is a subclass of the DistillationModel that injects teacher inputs into the corresponding student layers. This accomodates the case where the student model is the teacher with specific submodules replaced, which now need to be trained to mimic the original submodule in the teacher. diff --git a/modelopt/torch/distill/mode.py b/modelopt/torch/distill/mode.py index ae36094ab..15467c761 100644 --- a/modelopt/torch/distill/mode.py +++ b/modelopt/torch/distill/mode.py @@ -33,9 +33,9 @@ ) from modelopt.torch.utils import init_model_from_model_like, unwrap_model -from .bypass_distillation_model import BypassDistillationModel -from .config import BypassKDConfig, ExportStudentConfig, KDLossConfig +from .config import ExportStudentConfig, KDLossConfig, LayerwiseKDConfig from .distillation_model import DistillationModel +from .layerwise_distillation_model import LayerwiseDistillationModel from .registry import DistillationDMRegistry DistillModeRegistry = _ModeRegistryCls("distill") @@ -80,8 +80,8 @@ def save_mode_in_state(self) -> bool: @DistillModeRegistry.register_mode -class BypassKDModeDescriptor(KnowledgeDistillationModeDescriptor): - """Class to describe the Bypass Knowledge-Distillation mode. +class LayerwiseKDModeDescriptor(KnowledgeDistillationModeDescriptor): + """Class to describe the Layerwise Knowledge-Distillation mode. The properties of this mode can be inspected via the source code. """ @@ -89,17 +89,17 @@ class BypassKDModeDescriptor(KnowledgeDistillationModeDescriptor): @property def name(self) -> str: """Returns the value (str representation) of the mode.""" - return "bypass_kd" + return "layerwise_kd" @property def config_class(self) -> type[ModeloptBaseConfig]: """Specifies the config class for the mode.""" - return BypassKDConfig + return LayerwiseKDConfig @property def convert(self) -> ConvertEntrypoint: """The mode's entrypoint for converting a model.""" - return _convert_for_bypass + return _convert_for_layerwise @DistillModeRegistry.register_mode @@ -192,9 +192,9 @@ def _convert_for_kd( return distillation_model, metadata -def _convert_for_bypass(model: nn.Module, config: BypassKDConfig) -> ConvertReturnType: - """Function for converting a model to a bypass distillation meta-model.""" - return _convert_for_kd(model, config, model_cls=BypassDistillationModel) +def _convert_for_layerwise(model: nn.Module, config: LayerwiseKDConfig) -> ConvertReturnType: + """Function for converting a model to a layerwise distillation meta-model.""" + return _convert_for_kd(model, config, model_cls=LayerwiseDistillationModel) def _export_student(model: nn.Module, config: ExportStudentConfig) -> ConvertReturnType: diff --git a/tests/unit/torch/distill/test_bypass.py b/tests/unit/torch/distill/test_layerwise.py similarity index 63% rename from tests/unit/torch/distill/test_bypass.py rename to tests/unit/torch/distill/test_layerwise.py index 291f7733f..aea1dd635 100644 --- a/tests/unit/torch/distill/test_bypass.py +++ b/tests/unit/torch/distill/test_layerwise.py @@ -33,7 +33,7 @@ def tiny_mobilenet(): @pytest.fixture -def bypass_distillation_model(): +def layerwise_distillation_model(): student = tiny_mobilenet().train() config = { "teacher_model": tiny_mobilenet(), @@ -42,15 +42,15 @@ def bypass_distillation_model(): }, "loss_balancer": mtd.StaticLossBalancer(), } - bypass_model = mtd.convert(student, mode=[("bypass_kd", config)]) + layerwise_model = mtd.convert(student, mode=[("layerwise_kd", config)]) - return bypass_model + return layerwise_model -def test_bypass_hooks_registration(bypass_distillation_model): - """Test that bypass-specific hooks are registered correctly.""" +def test_layerwise_hooks_registration(layerwise_distillation_model): + """Test that layerwise-specific hooks are registered correctly.""" # Check that student layers have _teacher_layer attribute - for student_layer, teacher_layer in bypass_distillation_model._layers_to_loss: + for student_layer, teacher_layer in layerwise_distillation_model._layers_to_loss: assert hasattr(student_layer, "_teacher_layer") assert student_layer._teacher_layer[0] is teacher_layer assert hasattr(student_layer, "_intermediate_output") @@ -60,105 +60,105 @@ def test_bypass_hooks_registration(bypass_distillation_model): assert hasattr(teacher_layer, "_intermediate_output") -def test_bypass_forward_pass(bypass_distillation_model): +def test_layerwise_forward_pass(layerwise_distillation_model): """Test that forward pass works and captures both teacher inputs and outputs.""" - bypass_distillation_model.train() + layerwise_distillation_model.train() input_tensor = get_input_tensor() - bypass_distillation_model(input_tensor) + layerwise_distillation_model(input_tensor) # Check that teacher intermediate inputs and outputs are captured - for student_layer, teacher_layer in bypass_distillation_model._layers_to_loss: + for student_layer, teacher_layer in layerwise_distillation_model._layers_to_loss: assert teacher_layer._intermediate_input is None assert teacher_layer._intermediate_output is not None assert student_layer._intermediate_output is not None -def test_bypass_input_injection(bypass_distillation_model): - """Test that teacher inputs are injected into student layers during bypass.""" - bypass_distillation_model.train() +def test_layerwise_input_injection(layerwise_distillation_model): + """Test that teacher inputs are injected into student layers during layerwise distillation.""" + layerwise_distillation_model.train() input_tensor = get_input_tensor() # Perform forward pass - bypass_distillation_model(input_tensor) + layerwise_distillation_model(input_tensor) # Verify that teacher inputs were captured (they should be reset after injection) # After forward, teacher inputs should have been consumed by student layers - for student_layer, teacher_layer in bypass_distillation_model._layers_to_loss: + for student_layer, teacher_layer in layerwise_distillation_model._layers_to_loss: # After full forward pass, teacher_layer._intermediate_input should be None - # because it gets consumed by the student bypass hook + # because it gets consumed by the student layerwise hook assert teacher_layer._intermediate_input is None -def test_bypass_loss_computation(bypass_distillation_model): - """Test that loss computation works with bypass distillation.""" - bypass_distillation_model.train() +def test_layerwise_loss_computation(layerwise_distillation_model): + """Test that loss computation works with layerwise distillation.""" + layerwise_distillation_model.train() input_tensor = get_input_tensor() - output = bypass_distillation_model(input_tensor) - loss = bypass_distillation_model.compute_kd_loss(student_loss=output.mean()) + output = layerwise_distillation_model(input_tensor) + loss = layerwise_distillation_model.compute_kd_loss(student_loss=output.mean()) assert isinstance(loss, torch.Tensor) assert loss.numel() == 1 assert loss.requires_grad -def test_bypass_only_student_forward(bypass_distillation_model): - """Test that only_student_forward context manager works with bypass.""" - bypass_distillation_model.train() +def test_layerwise_only_student_forward(layerwise_distillation_model): + """Test that only_student_forward context manager works with layerwise distillation.""" + layerwise_distillation_model.train() input_tensor = get_input_tensor() # When using only_student_forward, teacher inputs should not be captured with warnings.catch_warnings(record=True) as w: - with bypass_distillation_model.only_student_forward(): - bypass_distillation_model(input_tensor) + with layerwise_distillation_model.only_student_forward(): + layerwise_distillation_model(input_tensor) # Should get warning about missing teacher input warning_messages = [str(warning.message) for warning in w] assert any("has no intermediate input stored" in msg for msg in warning_messages) # Verify teacher didn't run - for student_layer, teacher_layer in bypass_distillation_model._layers_to_loss: + for student_layer, teacher_layer in layerwise_distillation_model._layers_to_loss: assert teacher_layer._intermediate_input is None assert teacher_layer._intermediate_output is None assert student_layer._intermediate_output is not None -def test_bypass_only_teacher_forward(bypass_distillation_model): - """Test that only_teacher_forward context manager works with bypass.""" - bypass_distillation_model.train() +def test_layerwise_only_teacher_forward(layerwise_distillation_model): + """Test that only_teacher_forward context manager works with layerwise distillation.""" + layerwise_distillation_model.train() input_tensor = get_input_tensor() - with bypass_distillation_model.only_teacher_forward(): - bypass_distillation_model(input_tensor) + with layerwise_distillation_model.only_teacher_forward(): + layerwise_distillation_model(input_tensor) # Verify teacher ran and student didn't - for student_layer, teacher_layer in bypass_distillation_model._layers_to_loss: + for student_layer, teacher_layer in layerwise_distillation_model._layers_to_loss: assert teacher_layer._intermediate_input is not None assert teacher_layer._intermediate_output is not None assert student_layer._intermediate_output is None -def test_bypass_export(bypass_distillation_model): - """Test that export correctly cleans up bypass-specific attributes.""" +def test_layerwise_export(layerwise_distillation_model): + """Test that export correctly cleans up layerwise-specific attributes.""" # Check that _teacher_layer exists before export - for student_layer, _ in bypass_distillation_model._layers_to_loss: + for student_layer, _ in layerwise_distillation_model._layers_to_loss: assert hasattr(student_layer, "_teacher_layer") # Export the model - exported_model = mtd.export(bypass_distillation_model) + exported_model = mtd.export(layerwise_distillation_model) # Check that _teacher_layer is removed after export for student_layer in exported_model.modules(): assert not hasattr(student_layer, "_teacher_layer") assert not hasattr(exported_model, "_teacher_model") - assert not isinstance(exported_model, mtd.BypassDistillationModel) + assert not isinstance(exported_model, mtd.LayerwiseDistillationModel) -def test_bypass_save_restore(bypass_distillation_model, tmp_path): - """Test that save/restore works correctly with bypass distillation.""" - mto.save(bypass_distillation_model, tmp_path / "ckpt.pt") +def test_layerwise_save_restore(layerwise_distillation_model, tmp_path): + """Test that save/restore works correctly with layerwise distillation.""" + mto.save(layerwise_distillation_model, tmp_path / "ckpt.pt") new_student = tiny_mobilenet() restored_model = mto.restore(new_student, tmp_path / "ckpt.pt") @@ -169,8 +169,8 @@ def test_bypass_save_restore(bypass_distillation_model, tmp_path): assert isinstance(restored_model, type(new_student)) -def test_bypass_multiloss(): - """Test bypass distillation with multiple loss functions.""" +def test_layerwise_multiloss(): + """Test layerwise distillation with multiple loss functions.""" student = tiny_mobilenet().train() config = { "teacher_model": tiny_mobilenet(), @@ -180,21 +180,21 @@ def test_bypass_multiloss(): }, "loss_balancer": mtd.StaticLossBalancer([0.5, 0.5]), } - bypass_model = mtd.convert(student, mode=[("bypass_kd", config)]) + layerwise_model = mtd.convert(student, mode=[("layerwise_kd", config)]) # Verify hooks are registered for all layers - assert len(bypass_model._layers_to_loss) == 2 + assert len(layerwise_model._layers_to_loss) == 2 # Test forward pass - output = bypass_model(get_input_tensor()) - loss = bypass_model.compute_kd_loss(student_loss=output.mean()) + output = layerwise_model(get_input_tensor()) + loss = layerwise_model.compute_kd_loss(student_loss=output.mean()) assert isinstance(loss, torch.Tensor) assert loss.numel() == 1 -def test_bypass_gradient_flow(): - """Test that gradients flow correctly through bypass distillation.""" +def test_layerwise_gradient_flow(): + """Test that gradients flow correctly through layerwise distillation.""" student = tiny_mobilenet().train() config = { "teacher_model": tiny_mobilenet(), @@ -203,24 +203,24 @@ def test_bypass_gradient_flow(): }, "loss_balancer": None, } - bypass_model = mtd.convert(student, mode=[("bypass_kd", config)]) + layerwise_model = mtd.convert(student, mode=[("layerwise_kd", config)]) # Save param snapshots by module param_snapshots = { - name: p.clone() for name, p in bypass_model.named_parameters() if p.requires_grad + name: p.clone() for name, p in layerwise_model.named_parameters() if p.requires_grad } # Forward and backward - optimizer = torch.optim.SGD(bypass_model.parameters(), lr=0.5) + optimizer = torch.optim.SGD(layerwise_model.parameters(), lr=0.5) optimizer.zero_grad() - bypass_model(get_input_tensor()) - loss = bypass_model.compute_kd_loss() + layerwise_model(get_input_tensor()) + loss = layerwise_model.compute_kd_loss() loss.backward() optimizer.step() # Check: parameters in only the target layer(s) are changed updated_any = False - for name, param in bypass_model.named_parameters(): + for name, param in layerwise_model.named_parameters(): if not param.requires_grad: continue changed = not torch.allclose(param, param_snapshots[name]) From 86b9d770feccb575e2339d388d9176b2b312efa6 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Mon, 26 Jan 2026 12:51:37 -0800 Subject: [PATCH 7/8] Make second kd registry Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/mode.py | 21 +++++++++++++++------ modelopt/torch/distill/registry.py | 5 ++++- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/modelopt/torch/distill/mode.py b/modelopt/torch/distill/mode.py index 15467c761..18ccfd9bb 100644 --- a/modelopt/torch/distill/mode.py +++ b/modelopt/torch/distill/mode.py @@ -24,6 +24,7 @@ from modelopt.torch.opt.config import ModeloptBaseConfig from modelopt.torch.opt.conversion import ModeloptStateManager +from modelopt.torch.opt.dynamic import _DMRegistryCls from modelopt.torch.opt.mode import ( ConvertEntrypoint, ConvertReturnType, @@ -36,7 +37,7 @@ from .config import ExportStudentConfig, KDLossConfig, LayerwiseKDConfig from .distillation_model import DistillationModel from .layerwise_distillation_model import LayerwiseDistillationModel -from .registry import DistillationDMRegistry +from .registry import DistillationDMRegistry, LayerwiseDistillationDMRegistry DistillModeRegistry = _ModeRegistryCls("distill") @@ -141,7 +142,10 @@ def save_mode_in_state(self) -> bool: def _convert_for_kd( - model: nn.Module, config: KDLossConfig, model_cls: type[nn.Module] = DistillationModel + model: nn.Module, + config: KDLossConfig, + model_cls: type[nn.Module] = DistillationModel, + model_registry: _DMRegistryCls = DistillationDMRegistry, ) -> ConvertReturnType: """Function for converting a model to a distillation meta-model. @@ -176,12 +180,12 @@ def _convert_for_kd( # initialize distillation model original_cls = type(student) - if original_cls not in DistillationDMRegistry: - DistillationDMRegistry.register({original_cls: "student_class"})(model_cls) + if original_cls not in model_registry: + model_registry.register({original_cls: "student_class"})(model_cls) # TODO (lucasl): look into ways to avoid registering every class manually # (e.g. by just registering nn.Module and disable the "forward" check for the inherited class check - distillation_model = DistillationDMRegistry.convert(student) + distillation_model = model_registry.convert(student) distillation_model.modify( **{**config, "teacher_model": teacher} # overwrite with instantiated teacher ) @@ -194,7 +198,12 @@ def _convert_for_kd( def _convert_for_layerwise(model: nn.Module, config: LayerwiseKDConfig) -> ConvertReturnType: """Function for converting a model to a layerwise distillation meta-model.""" - return _convert_for_kd(model, config, model_cls=LayerwiseDistillationModel) + return _convert_for_kd( + model, + config, + model_cls=LayerwiseDistillationModel, + model_registry=LayerwiseDistillationDMRegistry, + ) def _export_student(model: nn.Module, config: ExportStudentConfig) -> ConvertReturnType: diff --git a/modelopt/torch/distill/registry.py b/modelopt/torch/distill/registry.py index 905378cc7..69ea47308 100644 --- a/modelopt/torch/distill/registry.py +++ b/modelopt/torch/distill/registry.py @@ -17,7 +17,10 @@ from modelopt.torch.opt.dynamic import _DMRegistryCls -__all__ = ["DistillationDMRegistry"] +__all__ = ["DistillationDMRegistry", "LayerwiseDistillationDMRegistry"] DistillationDMRegistry = _DMRegistryCls(prefix="Distill") # global instance for the registry + +# Need separate one due to registration override issues when using single registry for both. +LayerwiseDistillationDMRegistry = _DMRegistryCls(prefix="LayerwiseDistill") From 8d3b02b9e4268680885b72f7b17c9883397a9ab8 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Mon, 26 Jan 2026 13:12:00 -0800 Subject: [PATCH 8/8] Freeze unused layers and lm_head Signed-off-by: Asha Anoosheh --- .../distill/layerwise_distillation_model.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/modelopt/torch/distill/layerwise_distillation_model.py b/modelopt/torch/distill/layerwise_distillation_model.py index 9e91a3d85..e8cbef99f 100644 --- a/modelopt/torch/distill/layerwise_distillation_model.py +++ b/modelopt/torch/distill/layerwise_distillation_model.py @@ -34,6 +34,25 @@ class LayerwiseDistillationModel(DistillationModel): submodule in the teacher. """ + def modify(self, *args, **kwargs): + """Modify the distillation model.""" + super().modify(*args, **kwargs) + + # Freeze student layers except those in criterion. + self.requires_grad_(False) + for student_layer, _ in self._layers_to_loss: + student_layer.requires_grad_(True) + + # Make lm heads (if we have them) no-ops to save compute. + if hasattr(self, "lm_head"): + self._lm_head = self.lm_head + self.lm_head = nn.Identity() + if hasattr(self._teacher_model, "lm_head"): + self._teacher_model._lm_head = self._teacher_model.lm_head + self._teacher_model.lm_head = nn.Identity() + + return self + def _register_hooks(self): """Register hooks for intermediate tensors from teacher models and the student model.""" for student_layer, teacher_layer in self._layers_to_loss: @@ -50,6 +69,12 @@ def export(self): """Export the distillation model.""" for student_layer, _ in self._layers_to_loss: delattr(student_layer, "_teacher_layer") + + if hasattr(self, "_lm_head"): + self.lm_head = self._lm_head + if hasattr(self._teacher_model, "_lm_head"): + self._teacher_model.lm_head = self._teacher_model._lm_head + return super().export()