File tree Expand file tree Collapse file tree 1 file changed +6
-3
lines changed
Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Original file line number Diff line number Diff line change 1010
1111from torchao .float8 .float8_utils import _round_scale_down_to_power_of_2
1212from torchao .testing .utils import skip_if_rocm
13+ from torchao .utils import get_current_accelerator_device
14+
15+ _DEVICE = get_current_accelerator_device ()
1316
1417
1518# source for notable single-precision cases:
1619# https://en.wikipedia.org/wiki/Single-precision_floating-point_format
17- @unittest .skipIf (not torch .cuda .is_available (), "CUDA not available" )
20+ @unittest .skipIf (not torch .accelerator .is_available (), "GPU not available" )
1821@pytest .mark .parametrize (
1922 "test_case" ,
2023 [
@@ -38,8 +41,8 @@ def test_round_scale_down_to_power_of_2_valid_inputs(
3841):
3942 test_case_name , input , expected_result = test_case
4043 input_tensor , expected_tensor = (
41- torch .tensor (input , dtype = torch .float32 ).cuda ( ),
42- torch .tensor (expected_result , dtype = torch .float32 ).cuda ( ),
44+ torch .tensor (input , dtype = torch .float32 ).to ( _DEVICE ),
45+ torch .tensor (expected_result , dtype = torch .float32 ).to ( _DEVICE ),
4346 )
4447 result = _round_scale_down_to_power_of_2 (input_tensor )
4548
You can’t perform that action at this time.
0 commit comments