@@ -783,6 +783,17 @@ def test_tensordot_axes_errors():
783783 dpt .tensordot (m1 , m2 , axes = - 1 )
784784
785785
786+ # tests for gh-1570
787+ def test_tensordot_gemm_small_k_m ():
788+ get_queue_or_skip ()
789+
790+ x1 = dpt .asarray (1 , dtype = "i2" )
791+ x2 = dpt .asarray ([0 , 1 , 0 , 0 ], dtype = "i2" )
792+
793+ res = dpt .tensordot (x1 , x2 , axes = 0 )
794+ assert dpt .all (x2 == res )
795+
796+
786797@pytest .mark .parametrize ("dtype" , _numeric_types )
787798def test_vecdot_1d (dtype ):
788799 q = get_queue_or_skip ()
@@ -943,3 +954,29 @@ def test_vecdot_type_promotion(dt1, dt2):
943954 assert r .shape == tuple ()
944955 assert r .dtype == mul .dtype
945956 assert dpt .allclose (r , dpt .sum (mul , dtype = mul .dtype ))
957+
958+
959+ def test_vecdot_broadcast_o1_buffer ():
960+ get_queue_or_skip ()
961+
962+ v1 = dpt .arange (10 , dtype = "i2" )
963+ v2 = dpt .ones ((5 , 10 ), dtype = "i4" )
964+
965+ res1 = dpt .vecdot (v1 , v2 )
966+ assert res1 .shape == (5 ,)
967+
968+ res2 = dpt .vecdot (v2 , v1 )
969+ assert res2 .shape == (5 ,)
970+
971+
972+ def test_vecdot_contig_small ():
973+ get_queue_or_skip ()
974+
975+ n = 1
976+ for dt in [dpt .int16 , dpt .int32 , dpt .complex64 ]:
977+ v1 = dpt .zeros ((10 , n ), dtype = dt )
978+ v2 = dpt .ones_like (v1 , dtype = dt )
979+ v1 [- 1 ] = 1
980+ res = dpt .vecdot (v1 , v2 )
981+ assert dpt .all (res [:- 1 ] == 0 )
982+ assert res [- 1 ] == n
0 commit comments