diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py
index df42fb1376..108b33fd86 100644
--- a/transformer_engine/debug/features/log_fp8_tensor_stats.py
+++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py
@@ -6,6 +6,7 @@
 
 from typing import Dict, Optional, List, Tuple
 from contextlib import contextmanager
+import warnings
 
 import torch
 import nvdlfw_inspect.api as debug_api
@@ -298,15 +299,26 @@ def inspect_tensor(
         API call used to collect the data about the tensor after process_tensor()/quantization.
         """
         assert rowwise_quantized_tensor is columnwise_quantized_tensor
-        assert (
-            quantizer is not None
-        ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats cannot be run without low-precision recipe."
+
+        # Skip logging if quantizer is None (layer runs in high precision)
+        if quantizer is None:
+            warnings.warn(
+                f"[LogFp8TensorStats] Skipping stats collection for layer '{layer_name}', "
+                f"tensor '{tensor_name}': layer runs in high precision (no quantizer)."
+            )
+            return
 
         quantized_tensor = rowwise_quantized_tensor
 
-        assert isinstance(
-            quantized_tensor, QuantizedTensor
-        ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats quantized_tensor must be a QuantizedTensor."
+        # Skip logging if quantized_tensor is not a QuantizedTensor (incompatible precision)
+        if not isinstance(quantized_tensor, QuantizedTensor):
+            warnings.warn(
+                f"[LogFp8TensorStats] Skipping stats collection for layer '{layer_name}', "
+                f"tensor '{tensor_name}': incompatible precision "
+                f"(expected QuantizedTensor, got {type(quantized_tensor).__name__})."
+            )
+            return
+
         recipe_name = _get_recipe_name(quantizer)
 
         for stat in config["stats"]:
diff --git a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py
index ec2b3c38d3..18ac8619f3 100644
--- a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py
+++ b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py
@@ -6,6 +6,7 @@
 
 from typing import Dict, Optional
 from contextlib import contextmanager
+import warnings
 
 import torch
 import nvdlfw_inspect.api as debug_api
@@ -152,23 +153,34 @@ def inspect_tensor(
         API call used to collect the data about the tensor after process_tensor()/quantization.
         """
         assert rowwise_quantized_tensor is columnwise_quantized_tensor
-        assert (
-            quantizer is not None
-        ), "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats cannot be run without NVFP4 quantizer."
+
+        # Skip logging if quantizer is None (layer runs in high precision)
+        if quantizer is None:
+            warnings.warn(
+                f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
+                f"tensor '{tensor_name}': layer runs in high precision (no quantizer)."
+            )
+            return
 
         quantized_tensor = rowwise_quantized_tensor
 
-        # Ensure we're working with NVFP4 tensors
+        # Skip logging if not NVFP4 quantizer (incompatible precision)
         if not isinstance(quantizer, NVFP4Quantizer):
-            raise ValueError(
-                "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats requires NVFP4Quantizer, "
-                f"but got {type(quantizer).__name__}"
+            warnings.warn(
+                f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
+                f"tensor '{tensor_name}': incompatible precision "
+                f"(expected NVFP4Quantizer, got {type(quantizer).__name__})."
             )
-
-        assert isinstance(quantized_tensor, NVFP4TensorStorage), (
-            "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats quantized_tensor must be a"
-            " NVFP4TensorStorage."
-        )
+            return
+
+        # Skip logging if quantized tensor is not NVFP4TensorStorage (incompatible precision)
+        if not isinstance(quantized_tensor, NVFP4TensorStorage):
+            warnings.warn(
+                f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
+                f"tensor '{tensor_name}': incompatible precision "
+                f"(expected NVFP4TensorStorage, got {type(quantized_tensor).__name__})."
+            )
+            return
 
         for stat in config["stats"]:
             self.check_if_stat_is_supported(stat)
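Behavioral note: with this change, enabling either stats feature on a layer that runs in high precision (or under a mismatched recipe) emits a UserWarning and skips collection, instead of aborting the step with an AssertionError or ValueError. Below is a minimal self-contained sketch of the new guard's control flow; QuantizedTensor and the inspect_tensor signature here are simplified stand-ins for illustration, not the actual TE API:

    import warnings

    class QuantizedTensor:  # stand-in for TE's QuantizedTensor
        pass

    def inspect_tensor(layer_name, tensor_name, quantizer, quantized_tensor):
        # Mirrors the new guards: warn and return early instead of asserting.
        if quantizer is None:
            warnings.warn(
                f"[LogFp8TensorStats] Skipping stats collection for layer "
                f"'{layer_name}', tensor '{tensor_name}': layer runs in high "
                f"precision (no quantizer)."
            )
            return
        if not isinstance(quantized_tensor, QuantizedTensor):
            warnings.warn(
                f"[LogFp8TensorStats] Skipping stats collection for layer "
                f"'{layer_name}', tensor '{tensor_name}': incompatible precision "
                f"(expected QuantizedTensor, got {type(quantized_tensor).__name__})."
            )
            return
        # ... stats collection would proceed here as before ...

    # A high-precision layer now warns and continues rather than raising:
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        inspect_tensor("decoder.0.fc1", "activation",
                       quantizer=None, quantized_tensor=None)
    assert "high precision" in str(caught[0].message)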