@@ -351,7 +351,9 @@ def _empty_like_orderK(X, dt, usm_type=None, dev=None):
351351 )
352352 st = list (X .strides )
353353 perm = sorted (
354- range (X .ndim ), key = lambda d : builtins .abs (st [d ]), reverse = True
354+ range (X .ndim ),
355+ key = lambda d : builtins .abs (st [d ]) if X .shape [d ] > 1 else 0 ,
356+ reverse = True ,
355357 )
356358 inv_perm = sorted (range (X .ndim ), key = lambda i : perm [i ])
357359 sh = X .shape
@@ -395,9 +397,14 @@ def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
395397 max_ndim = max (nd1 , nd2 )
396398 st1 += [0 ] * (max_ndim - len (st1 ))
397399 st2 += [0 ] * (max_ndim - len (st2 ))
400+ sh1 = list (X1 .shape ) + [0 ] * (max_ndim - nd1 )
401+ sh2 = list (X2 .shape ) + [0 ] * (max_ndim - nd2 )
398402 perm = sorted (
399403 range (max_ndim ),
400- key = lambda d : (builtins .abs (st1 [d ]), builtins .abs (st2 [d ])),
404+ key = lambda d : (
405+ builtins .abs (st1 [d ]) if sh1 [d ] > 1 else 0 ,
406+ builtins .abs (st2 [d ]) if sh2 [d ] > 1 else 0 ,
407+ ),
401408 reverse = True ,
402409 )
403410 inv_perm = sorted (range (max_ndim ), key = lambda i : perm [i ])
@@ -417,6 +424,74 @@ def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
417424 return dpt .permute_dims (R , inv_perm )
418425
419426
427+ def _empty_like_triple_orderK (X1 , X2 , X3 , dt , res_shape , usm_type , dev ):
428+ if not isinstance (X1 , dpt .usm_ndarray ):
429+ raise TypeError (f"Expected usm_ndarray, got { type (X1 )} " )
430+ if not isinstance (X2 , dpt .usm_ndarray ):
431+ raise TypeError (f"Expected usm_ndarray, got { type (X2 )} " )
432+ if not isinstance (X3 , dpt .usm_ndarray ):
433+ raise TypeError (f"Expected usm_ndarray, got { type (X3 )} " )
434+ nd1 = X1 .ndim
435+ nd2 = X2 .ndim
436+ nd3 = X3 .ndim
437+ if X1 .shape == res_shape and X2 .shape == res_shape and len (res_shape ) > nd3 :
438+ return _empty_like_pair_orderK (X1 , X2 , dt , res_shape , usm_type , dev )
439+ elif (
440+ X2 .shape == res_shape and X3 .shape == res_shape and len (res_shape ) > nd1
441+ ):
442+ return _empty_like_pair_orderK (X2 , X3 , dt , res_shape , usm_type , dev )
443+ elif (
444+ X1 .shape == res_shape and X3 .shape == res_shape and len (res_shape ) > nd2
445+ ):
446+ return _empty_like_pair_orderK (X1 , X3 , dt , res_shape , usm_type , dev )
447+ fl1 = X1 .flags
448+ fl2 = X2 .flags
449+ fl3 = X3 .flags
450+ if fl1 ["C" ] or fl2 ["C" ] or fl3 ["C" ]:
451+ return dpt .empty (
452+ res_shape , dtype = dt , usm_type = usm_type , device = dev , order = "C"
453+ )
454+ if fl1 ["F" ] and fl2 ["F" ] and fl3 ["F" ]:
455+ return dpt .empty (
456+ res_shape , dtype = dt , usm_type = usm_type , device = dev , order = "F"
457+ )
458+ st1 = list (X1 .strides )
459+ st2 = list (X2 .strides )
460+ st3 = list (X3 .strides )
461+ max_ndim = max (nd1 , nd2 , nd3 )
462+ st1 += [0 ] * (max_ndim - len (st1 ))
463+ st2 += [0 ] * (max_ndim - len (st2 ))
464+ st3 += [0 ] * (max_ndim - len (st3 ))
465+ sh1 = list (X1 .shape ) + [0 ] * (max_ndim - nd1 )
466+ sh2 = list (X2 .shape ) + [0 ] * (max_ndim - nd2 )
467+ sh3 = list (X3 .shape ) + [0 ] * (max_ndim - nd3 )
468+ perm = sorted (
469+ range (max_ndim ),
470+ key = lambda d : (
471+ builtins .abs (st1 [d ]) if sh1 [d ] > 1 else 0 ,
472+ builtins .abs (st2 [d ]) if sh2 [d ] > 1 else 0 ,
473+ builtins .abs (st3 [d ]) if sh3 [d ] > 1 else 0 ,
474+ ),
475+ reverse = True ,
476+ )
477+ inv_perm = sorted (range (max_ndim ), key = lambda i : perm [i ])
478+ st1_sorted = [st1 [i ] for i in perm ]
479+ st2_sorted = [st2 [i ] for i in perm ]
480+ st3_sorted = [st3 [i ] for i in perm ]
481+ sh = res_shape
482+ sh_sorted = tuple (sh [i ] for i in perm )
483+ R = dpt .empty (sh_sorted , dtype = dt , usm_type = usm_type , device = dev , order = "C" )
484+ if max (min (st1_sorted ), min (st2_sorted ), min (st3_sorted )) < 0 :
485+ sl = tuple (
486+ slice (None , None , - 1 )
487+ if (st1_sorted [i ] < 0 and st2_sorted [i ] < 0 and st3_sorted [i ] < 0 )
488+ else slice (None , None , None )
489+ for i in range (nd1 )
490+ )
491+ R = R [sl ]
492+ return dpt .permute_dims (R , inv_perm )
493+
494+
420495def copy (usm_ary , order = "K" ):
421496 """copy(ary, order="K")
422497
0 commit comments