3131 _acceptance_fn_default ,
3232 _find_buf_dtype ,
3333 _find_buf_dtype2 ,
34- _find_inplace_dtype ,
3534 _to_device_supported_dtype ,
3635)
3736
@@ -79,8 +78,8 @@ def __call__(self, x, out=None, order="K"):
7978 )
8079
8180 if out .shape != x .shape :
82- raise TypeError (
83- "The shape of input and output arrays are inconsistent."
81+ raise ValueError (
82+ "The shape of input and output arrays are inconsistent. "
8483 f"Expected output shape is { x .shape } , got { out .shape } "
8584 )
8685
@@ -104,7 +103,7 @@ def __call__(self, x, out=None, order="K"):
104103 dpctl .utils .get_execution_queue ((x .sycl_queue , out .sycl_queue ))
105104 is None
106105 ):
107- raise TypeError (
106+ raise ExecutionPlacementError (
108107 "Input and output allocation queues are not compatible"
109108 )
110109
@@ -302,8 +301,6 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
302301 o1_kind_num = _weak_type_num_kind (o1_dtype )
303302 o2_kind_num = _strong_dtype_num_kind (o2_dtype )
304303 if o1_kind_num > o2_kind_num :
305- if isinstance (o1_dtype , WeakBooleanType ):
306- return dpt .bool , o2_dtype
307304 if isinstance (o1_dtype , WeakIntegralType ):
308305 return dpt .int64 , o2_dtype
309306 if isinstance (o1_dtype , WeakComplexType ):
@@ -323,8 +320,6 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
323320 o1_kind_num = _strong_dtype_num_kind (o1_dtype )
324321 o2_kind_num = _weak_type_num_kind (o2_dtype )
325322 if o2_kind_num > o1_kind_num :
326- if isinstance (o2_dtype , WeakBooleanType ):
327- return o1_dtype , dpt .bool
328323 if isinstance (o2_dtype , WeakIntegralType ):
329324 return o1_dtype , dpt .int64
330325 if isinstance (o2_dtype , WeakComplexType ):
@@ -383,14 +378,6 @@ def __repr__(self):
383378 return f"<{ self .__name__ } '{ self .name_ } '>"
384379
385380 def __call__ (self , o1 , o2 , out = None , order = "K" ):
386- # FIXME: replace with check against base array
387- # when views can be identified
388- if self .binary_inplace_fn_ :
389- if o1 is out :
390- return self ._inplace (o1 , o2 )
391- elif o2 is out :
392- return self ._inplace (o2 , o1 )
393-
394381 if order not in ["K" , "C" , "F" , "A" ]:
395382 order = "K"
396383 q1 , o1_usm_type = _get_queue_usm_type (o1 )
@@ -472,31 +459,90 @@ def __call__(self, o1, o2, out=None, order="K"):
472459 "supported types according to the casting rule ''safe''."
473460 )
474461
462+ orig_out = out
475463 if out is not None :
476464 if not isinstance (out , dpt .usm_ndarray ):
477465 raise TypeError (
478466 f"output array must be of usm_ndarray type, got { type (out )} "
479467 )
480468
481469 if out .shape != res_shape :
482- raise TypeError (
483- "The shape of input and output arrays are inconsistent."
470+ raise ValueError (
471+ "The shape of input and output arrays are inconsistent. "
484472 f"Expected output shape is { o1_shape } , got { out .shape } "
485473 )
486474
487- if ti ._array_overlap (o1 , out ) or ti ._array_overlap (o2 , out ):
488- raise TypeError ("Input and output arrays have memory overlap" )
475+ if res_dt != out .dtype :
476+ raise TypeError (
477+ f"Output array of type { res_dt } is needed,"
478+ f"got { out .dtype } "
479+ )
489480
490481 if (
491- dpctl .utils .get_execution_queue (
492- (o1 .sycl_queue , o2 .sycl_queue , out .sycl_queue )
493- )
482+ dpctl .utils .get_execution_queue ((exec_q , out .sycl_queue ))
494483 is None
495484 ):
496- raise TypeError (
485+ raise ExecutionPlacementError (
497486 "Input and output allocation queues are not compatible"
498487 )
499488
489+ if isinstance (o1 , dpt .usm_ndarray ):
490+ if ti ._array_overlap (o1 , out ) and buf1_dt is None :
491+ if not ti ._same_logical_tensors (o1 , out ):
492+ out = dpt .empty_like (out )
493+ elif self .binary_inplace_fn_ is not None :
494+ # if there is a dedicated in-place kernel
495+ # it can be called here, otherwise continues
496+ if isinstance (o2 , dpt .usm_ndarray ):
497+ src2 = o2
498+ if (
499+ ti ._array_overlap (o2 , out )
500+ and not ti ._same_logical_tensors (o2 , out )
501+ and buf2_dt is None
502+ ):
503+ buf2_dt = o2_dtype
504+ else :
505+ src2 = dpt .asarray (
506+ o2 , dtype = o2_dtype , sycl_queue = exec_q
507+ )
508+ if buf2_dt is None :
509+ if src2 .shape != res_shape :
510+ src2 = dpt .broadcast_to (src2 , res_shape )
511+ ht_ , _ = self .binary_inplace_fn_ (
512+ lhs = o1 , rhs = src2 , sycl_queue = exec_q
513+ )
514+ ht_ .wait ()
515+ else :
516+ buf2 = dpt .empty_like (src2 , dtype = buf2_dt )
517+ (
518+ ht_copy_ev ,
519+ copy_ev ,
520+ ) = ti ._copy_usm_ndarray_into_usm_ndarray (
521+ src = src2 , dst = buf2 , sycl_queue = exec_q
522+ )
523+
524+ buf2 = dpt .broadcast_to (buf2 , res_shape )
525+ ht_ , _ = self .binary_inplace_fn_ (
526+ lhs = o1 ,
527+ rhs = buf2 ,
528+ sycl_queue = exec_q ,
529+ depends = [copy_ev ],
530+ )
531+ ht_copy_ev .wait ()
532+ ht_ .wait ()
533+
534+ return out
535+
536+ if isinstance (o2 , dpt .usm_ndarray ):
537+ if (
538+ ti ._array_overlap (o2 , out )
539+ and not ti ._same_logical_tensors (o2 , out )
540+ and buf2_dt is None
541+ ):
542+ # should not reach if out is reallocated
543+ # after being checked against o1
544+ out = dpt .empty_like (out )
545+
500546 if isinstance (o1 , dpt .usm_ndarray ):
501547 src1 = o1
502548 else :
@@ -532,19 +578,24 @@ def __call__(self, o1, o2, out=None, order="K"):
532578 sycl_queue = exec_q ,
533579 order = order ,
534580 )
535- else :
536- if res_dt != out .dtype :
537- raise TypeError (
538- f"Output array of type { res_dt } is needed,"
539- f"got { out .dtype } "
540- )
541-
542- src1 = dpt .broadcast_to (src1 , res_shape )
543- src2 = dpt .broadcast_to (src2 , res_shape )
544- ht_ , _ = self .binary_fn_ (
581+ if src1 .shape != res_shape :
582+ src1 = dpt .broadcast_to (src1 , res_shape )
583+ if src2 .shape != res_shape :
584+ src2 = dpt .broadcast_to (src2 , res_shape )
585+ ht_binary_ev , binary_ev = self .binary_fn_ (
545586 src1 = src1 , src2 = src2 , dst = out , sycl_queue = exec_q
546587 )
547- ht_ .wait ()
588+ if not (orig_out is None or orig_out is out ):
589+ # Copy the out data from temporary buffer to original memory
590+ ht_copy_out_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
591+ src = out ,
592+ dst = orig_out ,
593+ sycl_queue = exec_q ,
594+ depends = [binary_ev ],
595+ )
596+ ht_copy_out_ev .wait ()
597+ out = orig_out
598+ ht_binary_ev .wait ()
548599 return out
549600 elif buf1_dt is None :
550601 if order == "K" :
@@ -575,18 +626,28 @@ def __call__(self, o1, o2, out=None, order="K"):
575626 f"Output array of type { res_dt } is needed,"
576627 f"got { out .dtype } "
577628 )
578-
579- src1 = dpt .broadcast_to (src1 , res_shape )
629+ if src1 . shape != res_shape :
630+ src1 = dpt .broadcast_to (src1 , res_shape )
580631 buf2 = dpt .broadcast_to (buf2 , res_shape )
581- ht_ , _ = self .binary_fn_ (
632+ ht_binary_ev , binary_ev = self .binary_fn_ (
582633 src1 = src1 ,
583634 src2 = buf2 ,
584635 dst = out ,
585636 sycl_queue = exec_q ,
586637 depends = [copy_ev ],
587638 )
639+ if not (orig_out is None or orig_out is out ):
640+ # Copy the out data from temporary buffer to original memory
641+ ht_copy_out_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
642+ src = out ,
643+ dst = orig_out ,
644+ sycl_queue = exec_q ,
645+ depends = [binary_ev ],
646+ )
647+ ht_copy_out_ev .wait ()
648+ out = orig_out
588649 ht_copy_ev .wait ()
589- ht_ .wait ()
650+ ht_binary_ev .wait ()
590651 return out
591652 elif buf2_dt is None :
592653 if order == "K" :
@@ -611,24 +672,29 @@ def __call__(self, o1, o2, out=None, order="K"):
611672 sycl_queue = exec_q ,
612673 order = order ,
613674 )
614- else :
615- if res_dt != out .dtype :
616- raise TypeError (
617- f"Output array of type { res_dt } is needed,"
618- f"got { out .dtype } "
619- )
620675
621676 buf1 = dpt .broadcast_to (buf1 , res_shape )
622- src2 = dpt .broadcast_to (src2 , res_shape )
623- ht_ , _ = self .binary_fn_ (
677+ if src2 .shape != res_shape :
678+ src2 = dpt .broadcast_to (src2 , res_shape )
679+ ht_binary_ev , binary_ev = self .binary_fn_ (
624680 src1 = buf1 ,
625681 src2 = src2 ,
626682 dst = out ,
627683 sycl_queue = exec_q ,
628684 depends = [copy_ev ],
629685 )
686+ if not (orig_out is None or orig_out is out ):
687+ # Copy the out data from temporary buffer to original memory
688+ ht_copy_out_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
689+ src = out ,
690+ dst = orig_out ,
691+ sycl_queue = exec_q ,
692+ depends = [binary_ev ],
693+ )
694+ ht_copy_out_ev .wait ()
695+ out = orig_out
630696 ht_copy_ev .wait ()
631- ht_ .wait ()
697+ ht_binary_ev .wait ()
632698 return out
633699
634700 if order in ["K" , "A" ]:
@@ -665,11 +731,6 @@ def __call__(self, o1, o2, out=None, order="K"):
665731 sycl_queue = exec_q ,
666732 order = order ,
667733 )
668- else :
669- if res_dt != out .dtype :
670- raise TypeError (
671- f"Output array of type { res_dt } is needed, got { out .dtype } "
672- )
673734
674735 buf1 = dpt .broadcast_to (buf1 , res_shape )
675736 buf2 = dpt .broadcast_to (buf2 , res_shape )
@@ -682,116 +743,3 @@ def __call__(self, o1, o2, out=None, order="K"):
682743 )
683744 dpctl .SyclEvent .wait_for ([ht_copy1_ev , ht_copy2_ev , ht_ ])
684745 return out
685-
686- def _inplace (self , lhs , val ):
687- if self .binary_inplace_fn_ is None :
688- raise ValueError (
689- f"In-place operation not supported for ufunc '{ self .name_ } '"
690- )
691- if not isinstance (lhs , dpt .usm_ndarray ):
692- raise TypeError (
693- f"Expected dpctl.tensor.usm_ndarray, got { type (lhs )} "
694- )
695- q1 , lhs_usm_type = _get_queue_usm_type (lhs )
696- q2 , val_usm_type = _get_queue_usm_type (val )
697- if q2 is None :
698- exec_q = q1
699- usm_type = lhs_usm_type
700- else :
701- exec_q = dpctl .utils .get_execution_queue ((q1 , q2 ))
702- if exec_q is None :
703- raise ExecutionPlacementError (
704- "Execution placement can not be unambiguously inferred "
705- "from input arguments."
706- )
707- usm_type = dpctl .utils .get_coerced_usm_type (
708- (
709- lhs_usm_type ,
710- val_usm_type ,
711- )
712- )
713- dpctl .utils .validate_usm_type (usm_type , allow_none = False )
714- lhs_shape = _get_shape (lhs )
715- val_shape = _get_shape (val )
716- if not all (
717- isinstance (s , (tuple , list ))
718- for s in (
719- lhs_shape ,
720- val_shape ,
721- )
722- ):
723- raise TypeError (
724- "Shape of arguments can not be inferred. "
725- "Arguments are expected to be "
726- "lists, tuples, or both"
727- )
728- try :
729- res_shape = _broadcast_shape_impl (
730- [
731- lhs_shape ,
732- val_shape ,
733- ]
734- )
735- except ValueError :
736- raise ValueError (
737- "operands could not be broadcast together with shapes "
738- f"{ lhs_shape } and { val_shape } "
739- )
740- if res_shape != lhs_shape :
741- raise ValueError (
742- f"output shape { lhs_shape } does not match "
743- f"broadcast shape { res_shape } "
744- )
745- sycl_dev = exec_q .sycl_device
746- lhs_dtype = lhs .dtype
747- val_dtype = _get_dtype (val , sycl_dev )
748- if not _validate_dtype (val_dtype ):
749- raise ValueError ("Input operand of unsupported type" )
750-
751- lhs_dtype , val_dtype = _resolve_weak_types (
752- lhs_dtype , val_dtype , sycl_dev
753- )
754-
755- buf_dt = _find_inplace_dtype (
756- lhs_dtype , val_dtype , self .result_type_resolver_fn_ , sycl_dev
757- )
758-
759- if buf_dt is None :
760- raise TypeError (
761- f"In-place '{ self .name_ } ' does not support input types "
762- f"({ lhs_dtype } , { val_dtype } ), "
763- "and the inputs could not be safely coerced to any "
764- "supported types according to the casting rule ''safe''."
765- )
766-
767- if isinstance (val , dpt .usm_ndarray ):
768- rhs = val
769- overlap = ti ._array_overlap (lhs , rhs )
770- else :
771- rhs = dpt .asarray (val , dtype = val_dtype , sycl_queue = exec_q )
772- overlap = False
773-
774- if buf_dt == val_dtype and overlap is False :
775- rhs = dpt .broadcast_to (rhs , res_shape )
776- ht_ , _ = self .binary_inplace_fn_ (
777- lhs = lhs , rhs = rhs , sycl_queue = exec_q
778- )
779- ht_ .wait ()
780-
781- else :
782- buf = dpt .empty_like (rhs , dtype = buf_dt )
783- ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
784- src = rhs , dst = buf , sycl_queue = exec_q
785- )
786-
787- buf = dpt .broadcast_to (buf , res_shape )
788- ht_ , _ = self .binary_inplace_fn_ (
789- lhs = lhs ,
790- rhs = buf ,
791- sycl_queue = exec_q ,
792- depends = [copy_ev ],
793- )
794- ht_copy_ev .wait ()
795- ht_ .wait ()
796-
797- return lhs
0 commit comments