@@ -930,140 +930,138 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):
930930 return out
931931
def _inplace_op(self, o1, o2):
    """Apply this binary function in place, writing the result into ``o1``.

    Args:
        o1: Left-hand operand. Must be a writable
            ``dpctl.tensor.usm_ndarray``; it is both an input and the
            destination of the operation.
        o2: Right-hand operand. Either a ``usm_ndarray`` (used as is) or an
            object that is converted with ``dpt.asarray`` on the execution
            queue. Its shape must broadcast to exactly the shape of ``o1``.

    Returns:
        ``o1``, the same object that was passed in.

    Raises:
        ValueError: If no dedicated in-place implementation is available,
            ``o1`` is read-only, the shapes cannot be broadcast, the
            broadcast shape differs from the shape of ``o1``, ``o2`` has an
            unsupported data type, the input types are not supported, or
            the required result type differs from the type of ``o1``.
        TypeError: If ``o1`` is not a ``usm_ndarray`` or the shape of
            ``o2`` cannot be inferred as a list or tuple.
        ExecutionPlacementError: If a single execution queue cannot be
            determined from the two operands.
    """
    if self.binary_inplace_fn_ is None:
        raise ValueError(
            "binary function does not have a dedicated in-place "
            "implementation"
        )
    if not isinstance(o1, dpt.usm_ndarray):
        raise TypeError(
            "Expected first argument to be "
            f"dpctl.tensor.usm_ndarray, got {type(o1)}"
        )
    if not o1.flags.writable:
        raise ValueError("provided left-hand side array is read-only")

    # Resolve the execution queue and the coerced USM type of the operands.
    q1, o1_usm_type = o1.sycl_queue, o1.usm_type
    q2, o2_usm_type = _get_queue_usm_type(o2)
    if q2 is None:
        # The right-hand side carries no queue of its own, so the
        # left-hand side alone decides the placement.
        exec_q = q1
        res_usm_type = o1_usm_type
    else:
        exec_q = dpctl.utils.get_execution_queue((q1, q2))
        if exec_q is None:
            raise ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments."
            )
        res_usm_type = dpctl.utils.get_coerced_usm_type(
            (
                o1_usm_type,
                o2_usm_type,
            )
        )
    dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)

    # The destination cannot grow: broadcasting must reproduce o1's shape.
    o1_shape = o1.shape
    o2_shape = _get_shape(o2)
    if not isinstance(o2_shape, (tuple, list)):
        raise TypeError(
            "Shape of second argument can not be inferred. "
            "Expected list or tuple."
        )
    try:
        res_shape = _broadcast_shape_impl(
            [
                o1_shape,
                o2_shape,
            ]
        )
    except ValueError as exc:
        # Re-raise with a message naming both shapes, keeping the
        # original failure attached as the explicit cause.
        raise ValueError(
            "operands could not be broadcast together with shapes "
            f"{o1_shape} and {o2_shape}"
        ) from exc

    if res_shape != o1_shape:
        raise ValueError(
            "The shape of the non-broadcastable left-hand "
            f"side {o1_shape} is inconsistent with the "
            f"broadcast shape {res_shape}."
        )

    # Resolve the data types for the selected device.
    sycl_dev = exec_q.sycl_device
    o1_dtype = o1.dtype
    o2_dtype = _get_dtype(o2, sycl_dev)
    if not _validate_dtype(o2_dtype):
        raise ValueError("Operand has an unsupported data type")

    o1_dtype, o2_dtype = self.weak_type_resolver_(
        o1_dtype, o2_dtype, sycl_dev
    )

    # NOTE(review): buf_dt appears to be the type the right-hand side must
    # be copied to before the kernel runs (None when no copy is needed),
    # and res_dt the result type -- confirm against
    # _find_buf_dtype_in_place_op.
    buf_dt, res_dt = _find_buf_dtype_in_place_op(
        o1_dtype,
        o2_dtype,
        self.result_type_resolver_fn_,
        sycl_dev,
    )

    if res_dt is None:
        raise ValueError(
            f"function '{self.name_}' does not support input types "
            f"({o1_dtype}, {o2_dtype}), "
            "and the inputs could not be safely coerced to any "
            "supported types according to the casting rule "
            "''same_kind''."
        )

    if res_dt != o1_dtype:
        raise ValueError(
            f"Output array of type {res_dt} is needed, got {o1_dtype}"
        )

    _manager = SequentialOrderManager[exec_q]
    if isinstance(o2, dpt.usm_ndarray):
        src2 = o2
        if (
            ti._array_overlap(o2, o1)
            and not ti._same_logical_tensors(o2, o1)
            and buf_dt is None
        ):
            # o2 overlaps o1 in memory without being the same logical
            # tensor: force the buffered path below so the kernel reads
            # from a separate copy of o2 (presumably to avoid reading
            # elements that are being overwritten).
            buf_dt = o2_dtype
    else:
        src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)

    if buf_dt is None:
        # Direct path: operate on the right-hand side as is, broadcasting
        # it to the destination shape if required.
        if src2.shape != res_shape:
            src2 = dpt.broadcast_to(src2, res_shape)
        dep_evs = _manager.submitted_events
        ht_, comp_ev = self.binary_inplace_fn_(
            lhs=o1,
            rhs=src2,
            sycl_queue=exec_q,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_, comp_ev)
    else:
        # Buffered path: copy the right-hand side into a temporary of type
        # buf_dt, then run the in-place kernel once the copy is done.
        buf = dpt.empty_like(src2, dtype=buf_dt)
        dep_evs = _manager.submitted_events
        (
            ht_copy_ev,
            copy_ev,
        ) = ti._copy_usm_ndarray_into_usm_ndarray(
            src=src2,
            dst=buf,
            sycl_queue=exec_q,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_copy_ev, copy_ev)

        buf = dpt.broadcast_to(buf, res_shape)
        ht_, bf_ev = self.binary_inplace_fn_(
            lhs=o1,
            rhs=buf,
            sycl_queue=exec_q,
            depends=[copy_ev],
        )
        _manager.add_event_pair(ht_, bf_ev)

    return o1
0 commit comments