4141]
4242
4343
44+ def _map_int_to_type (n , dt ):
45+ assert isinstance (n , int )
46+ assert n > 0
47+ if dt == dpt .int8 :
48+ return ((n + 128 ) % 256 ) - 128
49+ elif dt == dpt .uint8 :
50+ return n % 256
51+ elif dt == dpt .int16 :
52+ return ((n + 32768 ) % 65536 ) - 32768
53+ elif dt == dpt .uint16 :
54+ return n % 65536
55+ return n
56+
57+
4458def test_matrix_transpose ():
4559 get_queue_or_skip ()
4660
@@ -702,8 +716,8 @@ def test_vecdot_1d(dtype):
702716 v2 = dpt .ones (n , dtype = dtype )
703717
704718 r = dpt .vecdot (v1 , v2 )
705-
706- assert r == n
719+ expected_value = _map_int_to_type ( n , r . dtype )
720+ assert r == expected_value
707721
708722
709723@pytest .mark .parametrize ("dtype" , _numeric_types )
@@ -722,7 +736,8 @@ def test_vecdot_3d(dtype):
722736 m1 ,
723737 m2 ,
724738 )
725- assert dpt .all (r == n )
739+ expected_value = _map_int_to_type (n , r .dtype )
740+ assert dpt .all (r == expected_value )
726741
727742
728743@pytest .mark .parametrize ("dtype" , _numeric_types )
@@ -741,7 +756,8 @@ def test_vecdot_axis(dtype):
741756 m1 ,
742757 m2 ,
743758 )
744- assert dpt .all (r == n )
759+ expected_value = _map_int_to_type (n , r .dtype )
760+ assert dpt .all (r == expected_value )
745761
746762
747763@pytest .mark .parametrize ("dtype" , _numeric_types )
@@ -775,6 +791,7 @@ def test_vecdot_strided(dtype):
775791 m1 ,
776792 m2 ,
777793 )
794+ ref = _map_int_to_type (ref , r .dtype )
778795 assert dpt .all (r == ref )
779796
780797
0 commit comments