@@ -351,17 +351,19 @@ def vecdot(x1, x2, axis=-1):
351351 x2_nd = x2 .ndim
352352 x1_shape = x1 .shape
353353 x2_shape = x2 .shape
354+ if axis >= 0 :
355+ raise ValueError ("`axis` must be negative" )
356+ axis = operator .index (axis )
357+ x1_axis = normalize_axis_index (axis , x1_nd )
358+ x2_axis = normalize_axis_index (axis , x2_nd )
359+ if x1_shape [x1_axis ] != x2_shape [x2_axis ]:
360+ raise ValueError (
361+ "given axis must have the same shape for `x1` and `x2`"
362+ )
354363 if x1_nd > x2_nd :
355364 x2_shape = (1 ,) * (x1_nd - x2_nd ) + x2_shape
356- x2_nd = len (x2_shape )
357365 elif x2_nd > x1_nd :
358366 x1_shape = (1 ,) * (x2_nd - x1_nd ) + x1_shape
359- x1_nd = len (x1_shape )
360- axis = normalize_axis_index (operator .index (axis ), min (x1_nd , x2_nd ))
361- if x1_shape [axis ] != x2_shape [axis ]:
362- raise ValueError (
363- "given axis must have the same shape for `x1` and `x2`"
364- )
365367 try :
366368 broadcast_sh = _broadcast_shape_impl (
367369 [
@@ -371,8 +373,10 @@ def vecdot(x1, x2, axis=-1):
371373 )
372374 except ValueError :
373375 raise ValueError ("mismatch in `vecdot` dimensions" )
376+ broadcast_nd = len (broadcast_sh )
377+ contracted_axis = normalize_axis_index (axis , broadcast_nd )
374378 res_sh = tuple (
375- [broadcast_sh [i ] for i in range (len ( broadcast_sh )) if i != axis ]
379+ [broadcast_sh [i ] for i in range (broadcast_nd ) if i != contracted_axis ]
376380 )
377381 # type validation
378382 sycl_dev = exec_q .sycl_device
@@ -410,9 +414,8 @@ def vecdot(x1, x2, axis=-1):
410414 x1 = dpt .broadcast_to (x1 , broadcast_sh )
411415 if x2 .shape != broadcast_sh :
412416 x2 = dpt .broadcast_to (x2 , broadcast_sh )
413- x1 = dpt .moveaxis (x1 , axis , - 1 )
414- x2 = dpt .moveaxis (x2 , axis , - 1 )
415-
417+ x1 = dpt .moveaxis (x1 , contracted_axis , - 1 )
418+ x2 = dpt .moveaxis (x2 , contracted_axis , - 1 )
416419 out = dpt .empty (
417420 res_sh ,
418421 dtype = res_dt ,
@@ -455,8 +458,8 @@ def vecdot(x1, x2, axis=-1):
455458 x1 = dpt .broadcast_to (x1 , broadcast_sh )
456459 if buf2 .shape != broadcast_sh :
457460 buf2 = dpt .broadcast_to (buf2 , broadcast_sh )
458- x1 = dpt .moveaxis (x1 , axis , - 1 )
459- buf2 = dpt .moveaxis (buf2 , axis , - 1 )
461+ x1 = dpt .moveaxis (x1 , contracted_axis , - 1 )
462+ buf2 = dpt .moveaxis (buf2 , contracted_axis , - 1 )
460463 out = dpt .empty (
461464 res_sh ,
462465 dtype = res_dt ,
@@ -497,8 +500,8 @@ def vecdot(x1, x2, axis=-1):
497500 buf1 = dpt .broadcast_to (buf1 , broadcast_sh )
498501 if x2 .shape != broadcast_sh :
499502 x2 = dpt .broadcast_to (x2 , broadcast_sh )
500- buf1 = dpt .moveaxis (buf1 , axis , - 1 )
501- x2 = dpt .moveaxis (x2 , axis , - 1 )
503+ buf1 = dpt .moveaxis (buf1 , contracted_axis , - 1 )
504+ x2 = dpt .moveaxis (x2 , contracted_axis , - 1 )
502505 out = dpt .empty (
503506 res_sh ,
504507 dtype = res_dt ,
@@ -544,8 +547,8 @@ def vecdot(x1, x2, axis=-1):
544547 buf1 = dpt .broadcast_to (buf1 , broadcast_sh )
545548 if buf2 .shape != broadcast_sh :
546549 buf2 = dpt .broadcast_to (buf2 , broadcast_sh )
547- buf1 = dpt .moveaxis (buf1 , axis , - 1 )
548- buf2 = dpt .moveaxis (buf2 , axis , - 1 )
550+ buf1 = dpt .moveaxis (buf1 , contracted_axis , - 1 )
551+ buf2 = dpt .moveaxis (buf2 , contracted_axis , - 1 )
549552 out = dpt .empty (
550553 res_sh ,
551554 dtype = res_dt ,
0 commit comments