Skip to content

Commit 2886cbc

Browse files
pggPLtimmoon10
andauthored
[PyTorch debug] Fix test for debug tools (#2507)
* Skip delayed wgrad tests in distributed numerics when debug mode is enabled Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> * fix Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> --------- Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
1 parent b215116 commit 2886cbc

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

qa/L1_pytorch_distributed_unittest/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_
4444

4545
pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
4646
# standard numerics tests with initialized debug
47-
NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
47+
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
4848

4949
if [ "$RET" -ne 0 ]; then
5050
echo "Error in the following test cases:$FAILED_CASES"

tests/pytorch/distributed/run_numerics.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@
3838
NCCL_WORLD = None
3939
LOSS_FN = nn.MSELoss()
4040
QUANTIZATION = None
41+
NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED") or "0")
4142

42-
if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False):
43+
if NVTE_TEST_NVINSPECT_ENABLED:
4344
# The numerics of all the layers should work the same,
4445
# when debug=True. I fed them with dummy feature
4546
# to prevent switching off debug, which can happen if
@@ -745,6 +746,8 @@ def test_linear():
745746
for kwargs in kwargs_list:
746747
if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
747748
continue
749+
if kwargs.get("delay_wgrad_compute", False) and NVTE_TEST_NVINSPECT_ENABLED:
750+
continue
748751
for parallel_mode in ["column", "row"]:
749752
for sequence_parallel in [False, True]:
750753
_test_linear(parallel_mode, sequence_parallel, **kwargs)
@@ -924,6 +927,8 @@ def test_layernorm_linear():
924927
]
925928

926929
for kwargs in kwargs_list:
930+
if kwargs.get("delay_wgrad_compute", False) and NVTE_TEST_NVINSPECT_ENABLED:
931+
continue
927932
for parallel_mode in ["column"]:
928933
for sequence_parallel in [False, True]:
929934
_test_layernorm_linear(parallel_mode, sequence_parallel, **kwargs)
@@ -1034,6 +1039,8 @@ def test_layernorm_mlp():
10341039
]
10351040

10361041
for kwargs in kwargs_list:
1042+
if kwargs.get("delay_wgrad_compute", False) and NVTE_TEST_NVINSPECT_ENABLED:
1043+
continue
10371044
for set_parallel_mode in [True]:
10381045
for sequence_parallel in [False, True]:
10391046
_test_layernorm_mlp(set_parallel_mode, sequence_parallel, **kwargs)

0 commit comments

Comments
 (0)