diff --git a/modelopt/torch/distill/__init__.py b/modelopt/torch/distill/__init__.py
index a09aa6b8e..dad15dcc6 100644
--- a/modelopt/torch/distill/__init__.py
+++ b/modelopt/torch/distill/__init__.py
@@ -19,6 +19,7 @@
 from .config import *
 from .distillation import *
 from .distillation_model import *
+from .layerwise_distillation_model import *
 from .loss_balancers import *
 from .losses import *
 from .registry import *
diff --git a/modelopt/torch/distill/config.py b/modelopt/torch/distill/config.py
index cfdb3ccb6..74ef15300 100644
--- a/modelopt/torch/distill/config.py
+++ b/modelopt/torch/distill/config.py
@@ -26,7 +26,7 @@
 
 from .loss_balancers import DistillationLossBalancer
 
-__all__ = ["KDLossConfig"]
+__all__ = ["ExportStudentConfig", "KDLossConfig", "LayerwiseKDConfig"]
 
 Criterion = Union[Loss, dict[tuple[str, str], Loss]]  # noqa: UP007
 
@@ -120,6 +120,25 @@ def _strict_validate(self) -> None:
             )
 
 
+class LayerwiseKDConfig(KDLossConfig):
+    """Configuration for the Layerwise Knowledge-Distillation mode.
+
+    This mode distills intermediate activations from specific teacher layers into the corresponding student layers.
+    """
+
+    @pydantic.field_validator("criterion")
+    @classmethod
+    def format_criterion(cls, criterion: Criterion | None) -> dict[tuple[str, str], Loss]:
+        """Ensure criterion is a mapping from layer names to loss (potentially entire module)."""
+        if not isinstance(criterion, dict):
+            raise ValueError("Layerwise Distillation mode requires explicit criterion pairs.")
+        if any(key == ("", "") for key in criterion):
+            raise ValueError(
+                "Layerwise Distillation mode does not support output-only distillation."
+            )
+        return criterion
+
+
 class ExportStudentConfig(ModeloptBaseConfig):
     """Configuration for the export_student mode.
 
diff --git a/modelopt/torch/distill/distillation_model.py b/modelopt/torch/distill/distillation_model.py
index 930b68560..fa344385a 100644
--- a/modelopt/torch/distill/distillation_model.py
+++ b/modelopt/torch/distill/distillation_model.py
@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
 """Meta-model wrapper to support knowledge-distillation learning."""
 
 import inspect
@@ -45,6 +43,7 @@ def _setup(self):
         self._register_temp_attribute("_loss_modules", nn.ModuleList())
         self._register_temp_attribute("_only_teacher_fwd", False)
         self._register_temp_attribute("_only_student_fwd", False)
+        self._register_temp_attribute("_hook_handles", set())
 
         # HACK: set model's forward signature to match student class' original.
         # Needed for HF `transformers.utils.find_labels` which relies on inspecting class signature.
@@ -57,13 +56,13 @@ def modify(
         self,
-        teacher_model: nn.Module,  # To be frozen.
+        teacher_model: nn.Module,
         criterion: dict[
             tuple[
-                str,  # Student model layer whose output to capture.
-                str,  # Teacher model layer whose output to capture.
+                str,  # Student model layer whose output to capture
+                str,  # Teacher model layer whose output to capture
             ],
-            Loss,  # Loss fn.
+            Loss,  # Loss function
         ],
         loss_balancer: DistillationLossBalancer | None = None,
         expose_minimal_state_dict: bool = True,
@@ -71,9 +70,8 @@
         """Constructor.
 
         Args:
-            teacher_model: A teacher model which this class would encapsulate.
-            criterion: A dictionary mapping the tuple of student and teacher
-                model layer names to the loss function to apply to that layer pair.
+            teacher_model: The teacher model (will be frozen).
+            criterion: Dictionary mapping (student_layer_name, teacher_layer_name) to loss functions.
             loss_balancer: Instance of :class:`DistillationLossBalancer
                 <modelopt.torch.distill.DistillationLossBalancer>` which reduces distillation and
                 non-distillation losses into a single value using some weighing scheme.
@@ -106,22 +104,30 @@ def modify(
             {m for m in self._layers_to_loss.values() if len(list(m.parameters())) > 0}
         )
 
-        # Disable grad for teacher
+        # Disable grad for teacher.
         self._teacher_model.requires_grad_(False)
 
-        # Register hooks for intermediate outputs from teacher models and the student model.
-        # HACK: For inexplicable reasons, sometimes a model will have hooks remain after
-        # `ato.restore()` so we check if they are present accidentally first.
+        # Use hooks to capture relevant activation tensors for loss computation.
+        self._register_hooks()
+
+    def _register_hooks(self):
+        """Register hooks for intermediate tensors from teacher models and the student model."""
        for student_layer, teacher_layer in self._layers_to_loss:
             setattr(student_layer, "_intermediate_output", None)
-            if student_output_capture_fwd_hook not in student_layer._forward_hooks.values():
-                student_layer.register_forward_hook(student_output_capture_fwd_hook)
+            handle_s = student_layer.register_forward_hook(student_output_capture_fwd_hook)
             setattr(teacher_layer, "_intermediate_output", None)
-            if teacher_output_capture_fwd_hook not in teacher_layer._forward_hooks.values():
-                teacher_layer.register_forward_hook(teacher_output_capture_fwd_hook)
+            handle_t = teacher_layer.register_forward_hook(teacher_output_capture_fwd_hook)
+            self._hook_handles.update([handle_s, handle_t])
+
+    def export(self):
+        """Export the distillation model."""
+        for handle in self._hook_handles:
+            handle.remove()
+        self._hook_handles.clear()
+        return super().export()
 
     @property
-    def teacher_model(self) -> nn.ModuleList:
+    def teacher_model(self) -> nn.Module:
         """Fetch the teacher model."""
         return self._teacher_model
 
@@ -148,7 +154,7 @@ def hide_teacher_model(self, enable=True):
 
     @contextmanager
     def hide_loss_modules(self, enable=True):
-        """Context manager to temporarily hide teacher model from the model."""
+        """Context manager to temporarily hide loss modules from the model."""
         loss_modules = self._loss_modules
         if enable:
             self._loss_modules = nn.ModuleList()
@@ -169,7 +175,7 @@ def only_teacher_forward(self, enable=True):
 
     @contextmanager
     def only_student_forward(self, enable=True):
-        """Context manager to temporarily disable forward passes on the student model."""
+        """Context manager to temporarily run forward passes only on the student model."""
         if enable:
             self._only_student_fwd = True
         try:
@@ -245,15 +251,13 @@ def compute_kd_loss(
 
         Args:
             student_loss: Original loss computed from the student's output.
-            loss_reduction_fn: Callable to be called on each loss tensor prior to balancing. Useful for
-                loss-masking situations where the callable changes arguments each iteration.
+            loss_reduction_fn: Callable to be called on each loss tensor prior to balancing.
+                Useful for loss-masking situations where the callable changes arguments each iteration.
             skip_balancer: Whether or not to use loss balancer to reduce the loss dict into a scalar.
             **loss_fn_kwargs: Additional keyword arguments to be passed to the loss function, if needed.
-                This facilitates losses that require extras, such as labels for ``mtd.MFTLoss``.
 
         Returns:
-            If reduce is True, the scalar total loss weighted between ``student_loss`` and the distillation losses.
-            If reduce is False, a dict of student model output loss and layer-wise distillation losses.
+            A dict of losses if skip_balancer is True, else the scalar total loss.
         """
         if self._loss_balancer is None:
             assert student_loss is None, "Cannot pass in student loss without using Loss Balancer."
@@ -288,9 +292,9 @@ def compute_kd_loss(
         return loss_total
 
 
-def student_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any):  # pylint: disable=redefined-builtin
+def student_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any):
     """A hook to capture layer output."""
-    # NOTE: Defined externally to allow pickling.
+    # NOTE: Defined externally to allow pickling during DDP initialization.
 
     if getattr(module, "_only_teacher_fwd", False):
         return  # Might be hooked on entire model fwd
@@ -303,9 +307,9 @@ def student_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any):
     module._intermediate_output = output
 
 
-def teacher_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any):  # pylint: disable=redefined-builtin
+def teacher_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any):
     """A hook to capture layer output."""
-    # NOTE: Defined externally to allow pickling.
+    # NOTE: Defined externally to allow pickling during DDP initialization.
 
     if module._intermediate_output is not None:
         # NOTE: cannot tell if train or eval since teacher is always eval
diff --git a/modelopt/torch/distill/layerwise_distillation_model.py b/modelopt/torch/distill/layerwise_distillation_model.py
new file mode 100644
index 000000000..e8cbef99f
--- /dev/null
+++ b/modelopt/torch/distill/layerwise_distillation_model.py
@@ -0,0 +1,113 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Meta-model wrapper to support layerwise-enabled knowledge-distillation learning."""
+
+import warnings
+from typing import Any
+
+import torch.nn as nn
+
+from .distillation_model import DistillationModel, student_output_capture_fwd_hook
+
+__all__ = ["LayerwiseDistillationModel"]
+
+
+class LayerwiseDistillationModel(DistillationModel):
+    """Meta-model wrapper to support layerwise-enabled knowledge-distillation learning.
+
+    The LayerwiseDistillationModel is a subclass of the DistillationModel that injects teacher inputs
+    into the corresponding student layers. This accommodates the case where the student model is the
+    teacher with specific submodules replaced, which now need to be trained to mimic the original
+    submodule in the teacher.
+    """
+
+    def modify(self, *args, **kwargs):
+        """Modify the distillation model."""
+        super().modify(*args, **kwargs)
+
+        # Freeze student layers except those in criterion.
+        self.requires_grad_(False)
+        for student_layer, _ in self._layers_to_loss:
+            student_layer.requires_grad_(True)
+
+        # Make lm heads (if we have them) no-ops to save compute.
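+        # The layerwise losses only consume the intermediate activations captured by the hooks,
+        # so the final projection head is dead compute for both models; `export()` restores it.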
+        if hasattr(self, "lm_head"):
+            self._lm_head = self.lm_head
+            self.lm_head = nn.Identity()
+        if hasattr(self._teacher_model, "lm_head"):
+            self._teacher_model._lm_head = self._teacher_model.lm_head
+            self._teacher_model.lm_head = nn.Identity()
+
+        return self
+
+    def _register_hooks(self):
+        """Register hooks for intermediate tensors from teacher models and the student model."""
+        for student_layer, teacher_layer in self._layers_to_loss:
+            setattr(student_layer, "_teacher_layer", [teacher_layer])
+            handle_s1 = student_layer.register_forward_pre_hook(student_input_bypass_fwd_hook)
+            setattr(student_layer, "_intermediate_output", None)
+            handle_s2 = student_layer.register_forward_hook(student_output_capture_fwd_hook)
+            setattr(teacher_layer, "_intermediate_input", None)
+            setattr(teacher_layer, "_intermediate_output", None)
+            handle_t = teacher_layer.register_forward_hook(teacher_input_output_capture_fwd_hook)
+            self._hook_handles.update([handle_s1, handle_s2, handle_t])
+
+    def export(self):
+        """Export the distillation model."""
+        for student_layer, _ in self._layers_to_loss:
+            delattr(student_layer, "_teacher_layer")
+
+        if hasattr(self, "_lm_head"):
+            self.lm_head = self._lm_head
+        if hasattr(self._teacher_model, "_lm_head"):
+            self._teacher_model.lm_head = self._teacher_model._lm_head
+
+        return super().export()
+
+
+def student_input_bypass_fwd_hook(module: nn.Module, input: Any):
+    """A hook to inject teacher input into corresponding student layer."""
+    # NOTE: Defined externally to allow pickling during DDP initialization.
+
+    if getattr(module, "_only_teacher_fwd", False):
+        return input  # Might be hooked on entire model fwd
+
+    teacher_layer = module._teacher_layer[0]
+    teacher_input = teacher_layer._intermediate_input
+    if teacher_input is None:
+        warnings.warn(
+            f"Teacher's Module `{type(teacher_layer).__name__}` has no intermediate input stored."
+            " This is expected when the `only_student_forward` context manager is in use."
+        )
+        return input
+
+    teacher_layer._intermediate_input = None  # reset
+    return teacher_input
+
+
+def teacher_input_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any):
+    """A hook to capture layer input and output."""
+    # NOTE: Defined externally to allow pickling during DDP initialization.
+
+    if module._intermediate_output is not None:
+        # NOTE: cannot tell if train or eval since teacher is always eval
+        warnings.warn(
+            f"Teacher's Module `{type(module).__name__}` already has an intermediate output stored."
+            " This is expected when `DistillationModel.compute_kd_loss` is not called in eval mode."
+        )
+
+    module._intermediate_input = input
+    module._intermediate_output = output
diff --git a/modelopt/torch/distill/mode.py b/modelopt/torch/distill/mode.py
index 75ea751f4..18ccfd9bb 100644
--- a/modelopt/torch/distill/mode.py
+++ b/modelopt/torch/distill/mode.py
@@ -21,24 +21,23 @@
 import warnings
 
 import torch.nn as nn
-from torch.nn.modules.loss import _Loss as Loss
 
 from modelopt.torch.opt.config import ModeloptBaseConfig
 from modelopt.torch.opt.conversion import ModeloptStateManager
+from modelopt.torch.opt.dynamic import _DMRegistryCls
 from modelopt.torch.opt.mode import (
     ConvertEntrypoint,
     ConvertReturnType,
-    MetadataDict,
     ModeDescriptor,
     RestoreEntrypoint,
-    UpdateEntrypoint,
     _ModeRegistryCls,
 )
 from modelopt.torch.utils import init_model_from_model_like, unwrap_model
 
-from .config import ExportStudentConfig, KDLossConfig
+from .config import ExportStudentConfig, KDLossConfig, LayerwiseKDConfig
 from .distillation_model import DistillationModel
-from .registry import DistillationDMRegistry
+from .layerwise_distillation_model import LayerwiseDistillationModel
+from .registry import DistillationDMRegistry, LayerwiseDistillationDMRegistry
 
 
 DistillModeRegistry = _ModeRegistryCls("distill")
@@ -75,17 +74,35 @@ def restore(self) -> RestoreEntrypoint:
         """The mode's entrypoint for restoring a model."""
         raise NotImplementedError(f"{self.name} mode does not support restore.")
 
-    @property
-    def update_for_new_mode(self) -> UpdateEntrypoint:
-        """The mode's entrypoint for updating the models state for adding new mode."""
-        return _reset_kd_state_config
-
     @property
     def save_mode_in_state(self) -> bool:
         """Whether the mode should be saved into the modelopt state."""
         return False
 
 
+@DistillModeRegistry.register_mode
+class LayerwiseKDModeDescriptor(KnowledgeDistillationModeDescriptor):
+    """Class to describe the Layerwise Knowledge-Distillation mode.
+
+    The properties of this mode can be inspected via the source code.
+    """
+
+    @property
+    def name(self) -> str:
+        """Returns the value (str representation) of the mode."""
+        return "layerwise_kd"
+
+    @property
+    def config_class(self) -> type[ModeloptBaseConfig]:
+        """Specifies the config class for the mode."""
+        return LayerwiseKDConfig
+
+    @property
+    def convert(self) -> ConvertEntrypoint:
+        """The mode's entrypoint for converting a model."""
+        return _convert_for_layerwise
+
+
 @DistillModeRegistry.register_mode
 class ExportStudentModeDescriptor(ModeDescriptor):
     """Class to describe the specific Export mode to be used with Knowledge Distillation.
@@ -124,7 +141,12 @@ def save_mode_in_state(self) -> bool:
         return False
 
 
-def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType:
+def _convert_for_kd(
+    model: nn.Module,
+    config: KDLossConfig,
+    model_cls: type[nn.Module] = DistillationModel,
+    model_registry: _DMRegistryCls = DistillationDMRegistry,
+) -> ConvertReturnType:
     """Function for converting a model to a distillation meta-model.
 
     This is the only utility needed to use the ``modelopt.torch.distill`` API directly.
@@ -158,12 +180,12 @@
     # initialize distillation model
     original_cls = type(student)
-    if original_cls not in DistillationDMRegistry:
-        DistillationDMRegistry.register({original_cls: "student_class"})(DistillationModel)
+    if original_cls not in model_registry:
+        model_registry.register({original_cls: "student_class"})(model_cls)
         # TODO (lucasl): look into ways to avoid registering every class manually
        # (e.g. by just registering nn.Module and disable the "forward" check for the inherited class check
 
-    distillation_model = DistillationDMRegistry.convert(student)
+    distillation_model = model_registry.convert(student)
     distillation_model.modify(
         **{**config, "teacher_model": teacher}  # overwrite with instantiated teacher
     )
@@ -174,11 +196,14 @@ def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType:
     return distillation_model, metadata
 
 
-def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict):
-    """Function for resetting the state's config."""
-    config.teacher_model = nn.Module
-    config.criterion = Loss()
-    config.loss_balancer = None
+def _convert_for_layerwise(model: nn.Module, config: LayerwiseKDConfig) -> ConvertReturnType:
+    """Function for converting a model to a layerwise distillation meta-model."""
+    return _convert_for_kd(
+        model,
+        config,
+        model_cls=LayerwiseDistillationModel,
+        model_registry=LayerwiseDistillationDMRegistry,
+    )
 
 
 def _export_student(model: nn.Module, config: ExportStudentConfig) -> ConvertReturnType:
diff --git a/modelopt/torch/distill/registry.py b/modelopt/torch/distill/registry.py
index 905378cc7..69ea47308 100644
--- a/modelopt/torch/distill/registry.py
+++ b/modelopt/torch/distill/registry.py
@@ -17,7 +17,10 @@
 
 from modelopt.torch.opt.dynamic import _DMRegistryCls
 
-__all__ = ["DistillationDMRegistry"]
+__all__ = ["DistillationDMRegistry", "LayerwiseDistillationDMRegistry"]
 
 DistillationDMRegistry = _DMRegistryCls(prefix="Distill")  # global instance for the registry
+
+# Need a separate registry due to registration override issues when using a single registry for both.
+LayerwiseDistillationDMRegistry = _DMRegistryCls(prefix="LayerwiseDistill")
diff --git a/tests/unit/torch/distill/test_distill.py b/tests/unit/torch/distill/test_distill.py
index 10241f076..69dec86b7 100644
--- a/tests/unit/torch/distill/test_distill.py
+++ b/tests/unit/torch/distill/test_distill.py
@@ -20,7 +20,6 @@
 import torch
 import torch.nn as nn
 from _test_utils.torch.vision_models import get_tiny_mobilenet_and_input
-from torch.nn.modules.loss import _Loss as Loss
 from torchvision.models import alexnet
 
 import modelopt.torch.distill as mtd
@@ -37,7 +36,7 @@ def tiny_mobilenet():
 
 
 def tiny_alexnet():
-    return alexnet(num_classes=10)  # Same class as tiny_mobilenet
+    return alexnet(num_classes=10)  # same num classes as tiny_mobilenet
 
 
 @pytest.fixture
@@ -168,13 +167,6 @@ def test_distillation_export(distillation_model, tmp_path):
     assert not hasattr(model_exported, "_teacher_model")
     assert hasattr(model_exported, mto.ModeloptStateManager._state_key)
 
-    # Test if kd_loss config has been cleaned up
-    manager = mto.ModeloptStateManager(model_exported)
-    cfg = manager._state[-2][1]["config"]
-    assert cfg["teacher_model"] == nn.Module
-    assert isinstance(next(iter(cfg["criterion"].values())), Loss)
-    assert cfg["loss_balancer"] is None
-
     mto.save(model_exported, tmp_path / "ckpt.pt")
     new_student = tiny_mobilenet()
     new_student_restored = mto.restore(new_student, tmp_path / "ckpt.pt")
diff --git a/tests/unit/torch/distill/test_layerwise.py b/tests/unit/torch/distill/test_layerwise.py
new file mode 100644
index 000000000..aea1dd635
--- /dev/null
+++ b/tests/unit/torch/distill/test_layerwise.py
@@ -0,0 +1,234 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+
+import pytest
+import torch
+from _test_utils.torch.vision_models import get_tiny_mobilenet_and_input
+
+import modelopt.torch.distill as mtd
+import modelopt.torch.opt as mto
+
+
+def get_input_tensor():
+    """Dummy input tensor."""
+    return torch.rand(2, 3, 112, 112)
+
+
+def tiny_mobilenet():
+    return get_tiny_mobilenet_and_input()[0]
+
+
+@pytest.fixture
+def layerwise_distillation_model():
+    student = tiny_mobilenet().train()
+    config = {
+        "teacher_model": tiny_mobilenet(),
+        "criterion": {
+            ("features.2", "features.2"): torch.nn.MSELoss(),
+        },
+        "loss_balancer": mtd.StaticLossBalancer(),
+    }
+    layerwise_model = mtd.convert(student, mode=[("layerwise_kd", config)])
+
+    return layerwise_model
+
+
+def test_layerwise_hooks_registration(layerwise_distillation_model):
+    """Test that layerwise-specific hooks are registered correctly."""
+    # Check that student layers have _teacher_layer attribute
+    for student_layer, teacher_layer in layerwise_distillation_model._layers_to_loss:
+        assert hasattr(student_layer, "_teacher_layer")
+        assert student_layer._teacher_layer[0] is teacher_layer
+        assert hasattr(student_layer, "_intermediate_output")
+
+        # Check that teacher layers have both input and output capture attributes
+        assert hasattr(teacher_layer, "_intermediate_input")
+        assert hasattr(teacher_layer, "_intermediate_output")
+
+
+def test_layerwise_forward_pass(layerwise_distillation_model):
+    """Test that forward pass works and captures both teacher inputs and outputs."""
+    layerwise_distillation_model.train()
+    input_tensor = get_input_tensor()
+
+    layerwise_distillation_model(input_tensor)
+
+    # Teacher outputs should be captured; teacher inputs have already been consumed by the student hooks
+    for student_layer, teacher_layer in layerwise_distillation_model._layers_to_loss:
+        assert teacher_layer._intermediate_input is None
+        assert teacher_layer._intermediate_output is not None
+        assert student_layer._intermediate_output is not None
+
+
+def test_layerwise_input_injection(layerwise_distillation_model):
+    """Test that teacher inputs are injected into student layers during layerwise distillation."""
+    layerwise_distillation_model.train()
+    input_tensor = get_input_tensor()
+
+    # Perform forward pass
+    layerwise_distillation_model(input_tensor)
+
+    # Verify that teacher inputs were captured (they should be reset after injection)
+    # After forward, teacher inputs should have been consumed by student layers
+    for student_layer, teacher_layer in layerwise_distillation_model._layers_to_loss:
+        # After full forward pass, teacher_layer._intermediate_input should be None
+        # because it gets consumed by the student layerwise hook
+        assert teacher_layer._intermediate_input is None
+
+
+def test_layerwise_loss_computation(layerwise_distillation_model):
+    """Test that loss computation works with layerwise distillation."""
+    layerwise_distillation_model.train()
+    input_tensor = get_input_tensor()
+
+    output = layerwise_distillation_model(input_tensor)
+    loss = layerwise_distillation_model.compute_kd_loss(student_loss=output.mean())
+
+    assert isinstance(loss, torch.Tensor)
+    assert loss.numel() == 1
+    assert loss.requires_grad
+
+
+def test_layerwise_only_student_forward(layerwise_distillation_model):
+    """Test that only_student_forward context manager works with layerwise distillation."""
+    layerwise_distillation_model.train()
+    input_tensor = get_input_tensor()
+
+    # When using only_student_forward, teacher inputs should not be captured
+    with warnings.catch_warnings(record=True) as w:
+        with layerwise_distillation_model.only_student_forward():
+            layerwise_distillation_model(input_tensor)
+
+    # Should get warning about missing teacher input
+    warning_messages = [str(warning.message) for warning in w]
+    assert any("has no intermediate input stored" in msg for msg in warning_messages)
+
+    # Verify teacher didn't run
+    for student_layer, teacher_layer in layerwise_distillation_model._layers_to_loss:
+        assert teacher_layer._intermediate_input is None
+        assert teacher_layer._intermediate_output is None
+        assert student_layer._intermediate_output is not None
+
+
+def test_layerwise_only_teacher_forward(layerwise_distillation_model):
+    """Test that only_teacher_forward context manager works with layerwise distillation."""
+    layerwise_distillation_model.train()
+    input_tensor = get_input_tensor()
+
+    with layerwise_distillation_model.only_teacher_forward():
+        layerwise_distillation_model(input_tensor)
+
+    # Verify teacher ran and student didn't
+    for student_layer, teacher_layer in layerwise_distillation_model._layers_to_loss:
+        assert teacher_layer._intermediate_input is not None
+        assert teacher_layer._intermediate_output is not None
+        assert student_layer._intermediate_output is None
+
+
+def test_layerwise_export(layerwise_distillation_model):
+    """Test that export correctly cleans up layerwise-specific attributes."""
+    # Check that _teacher_layer exists before export
+    for student_layer, _ in layerwise_distillation_model._layers_to_loss:
+        assert hasattr(student_layer, "_teacher_layer")
+
+    # Export the model
+    exported_model = mtd.export(layerwise_distillation_model)
+
+    # Check that _teacher_layer is removed after export
+    for student_layer in exported_model.modules():
+        assert not hasattr(student_layer, "_teacher_layer")
+
+    assert not hasattr(exported_model, "_teacher_model")
+    assert not isinstance(exported_model, mtd.LayerwiseDistillationModel)
+
+
+def test_layerwise_save_restore(layerwise_distillation_model, tmp_path):
+    """Test that save/restore works correctly with layerwise distillation."""
+    mto.save(layerwise_distillation_model, tmp_path / "ckpt.pt")
+
+    new_student = tiny_mobilenet()
+    restored_model = mto.restore(new_student, tmp_path / "ckpt.pt")
+
+    # Ensure state is not actually restored (expected behavior from test_distill.py)
+    manager = mto.ModeloptStateManager(restored_model)
+    assert not manager.has_state
+    assert isinstance(restored_model, type(new_student))
+
+
+def test_layerwise_multiloss():
+    """Test layerwise distillation with multiple loss functions."""
+    student = tiny_mobilenet().train()
+    config = {
+        "teacher_model": tiny_mobilenet(),
+        "criterion": {
+            ("features.1", "features.1"): torch.nn.MSELoss(),
+            ("features.3", "features.3"): torch.nn.MSELoss(),
+        },
+        "loss_balancer": mtd.StaticLossBalancer([0.5, 0.5]),
+    }
+    layerwise_model = mtd.convert(student, mode=[("layerwise_kd", config)])
+
+    # Verify hooks are registered for all layers
+    assert len(layerwise_model._layers_to_loss) == 2
+
+    # Test forward pass
+    output = layerwise_model(get_input_tensor())
+    loss = layerwise_model.compute_kd_loss(student_loss=output.mean())
+
+    assert isinstance(loss, torch.Tensor)
+    assert loss.numel() == 1
+
+
+def test_layerwise_gradient_flow():
+    """Test that gradients flow correctly through layerwise distillation."""
+    student = tiny_mobilenet().train()
+    config = {
+        "teacher_model": tiny_mobilenet(),
+        "criterion": {
+            ("features.2", "features.2"): torch.nn.MSELoss(),
+        },
+        "loss_balancer": None,
+    }
+    layerwise_model = mtd.convert(student, mode=[("layerwise_kd", config)])
+
+    # Save param snapshots by name
+    param_snapshots = {
+        name: p.clone() for name, p in layerwise_model.named_parameters() if p.requires_grad
+    }
+
+    # Forward and backward
+    optimizer = torch.optim.SGD(layerwise_model.parameters(), lr=0.5)
+    optimizer.zero_grad()
+    layerwise_model(get_input_tensor())
+    loss = layerwise_model.compute_kd_loss()
+    loss.backward()
+    optimizer.step()
+
+    # Check: parameters in only the target layer(s) are changed
+    updated_any = False
+    for name, param in layerwise_model.named_parameters():
+        if not param.requires_grad:
+            continue
+        changed = not torch.allclose(param, param_snapshots[name])
+        if "features.2" in name:
+            assert changed, f"'{name}' parameters did not change!"
+            updated_any = True
+        else:
+            assert not changed, f"Parameters in unrelated layer '{name}' changed!"
+    assert updated_any, (
+        "No parameters were updated in 'features.2' or related layers during training"
+    )
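Example usage of the new `layerwise_kd` mode (a minimal sketch distilled from the unit tests above; `build_student()`, `build_teacher()`, and `batch` are hypothetical placeholders for your own model constructors and input data):

    import torch.nn as nn

    import modelopt.torch.distill as mtd

    student = build_student()  # e.g. the teacher architecture with some submodules swapped out
    config = {
        "teacher_model": build_teacher(),
        "criterion": {("features.2", "features.2"): nn.MSELoss()},  # (student_layer, teacher_layer) pairs
        "loss_balancer": mtd.StaticLossBalancer(),
    }
    model = mtd.convert(student, mode=[("layerwise_kd", config)])

    output = model(batch)
    loss = model.compute_kd_loss(student_loss=output.mean())  # output.mean() stands in for a real task loss
    loss.backward()

    final_model = mtd.export(model)  # removes hooks and the teacher, restores any swapped lm_head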