@@ -930,126 +930,134 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):
930930 return out
931931
932932 def _inplace_op (self , o1 , o2 ):
933- if not isinstance (o1 , dpt .usm_ndarray ):
934- raise TypeError (
935- "Expected first argument to be "
936- f"dpctl.tensor.usm_ndarray, got { type (o1 )} "
937- )
938- if not o1 .flags .writable :
939- raise ValueError ("provided left-hand side array is read-only" )
940- q1 , o1_usm_type = o1 .sycl_queue , o1 .usm_type
941- q2 , o2_usm_type = _get_queue_usm_type (o2 )
942- if q2 is None :
943- exec_q = q1
944- res_usm_type = o1_usm_type
945- else :
946- exec_q = dpctl .utils .get_execution_queue ((q1 , q2 ))
947- if exec_q is None :
948- raise ExecutionPlacementError (
949- "Execution placement can not be unambiguously inferred "
950- "from input arguments."
933+ if self .binary_inplace_fn_ is not None :
934+ if not isinstance (o1 , dpt .usm_ndarray ):
935+ raise TypeError (
936+ "Expected first argument to be "
937+ f"dpctl.tensor.usm_ndarray, got { type (o1 )} "
951938 )
952- res_usm_type = dpctl .utils .get_coerced_usm_type (
953- (
954- o1_usm_type ,
955- o2_usm_type ,
939+ if not o1 .flags .writable :
940+ raise ValueError ("provided left-hand side array is read-only" )
941+ q1 , o1_usm_type = o1 .sycl_queue , o1 .usm_type
942+ q2 , o2_usm_type = _get_queue_usm_type (o2 )
943+ if q2 is None :
944+ exec_q = q1
945+ res_usm_type = o1_usm_type
946+ else :
947+ exec_q = dpctl .utils .get_execution_queue ((q1 , q2 ))
948+ if exec_q is None :
949+ raise ExecutionPlacementError (
950+ "Execution placement can not be unambiguously inferred "
951+ "from input arguments."
952+ )
953+ res_usm_type = dpctl .utils .get_coerced_usm_type (
954+ (
955+ o1_usm_type ,
956+ o2_usm_type ,
957+ )
956958 )
959+ dpctl .utils .validate_usm_type (res_usm_type , allow_none = False )
960+ o1_shape = o1 .shape
961+ o2_shape = _get_shape (o2 )
962+ if not isinstance (o2_shape , (tuple , list )):
963+ raise TypeError (
964+ "Shape of second argument can not be inferred. "
965+ "Expected list or tuple."
966+ )
967+ try :
968+ res_shape = _broadcast_shape_impl (
969+ [
970+ o1_shape ,
971+ o2_shape ,
972+ ]
973+ )
974+ except ValueError :
975+ raise ValueError (
976+ "operands could not be broadcast together with shapes "
977+ f"{ o1_shape } and { o2_shape } "
978+ )
979+ if res_shape != o1_shape :
980+ raise ValueError ("" )
981+ sycl_dev = exec_q .sycl_device
982+ o1_dtype = o1 .dtype
983+ o2_dtype = _get_dtype (o2 , sycl_dev )
984+ if not _validate_dtype (o2_dtype ):
985+ raise ValueError ("Operand has an unsupported data type" )
986+
987+ o1_dtype , o2_dtype = self .weak_type_resolver_ (
988+ o1_dtype , o2_dtype , sycl_dev
957989 )
958- dpctl .utils .validate_usm_type (res_usm_type , allow_none = False )
959- o1_shape = o1 .shape
960- o2_shape = _get_shape (o2 )
961- if not isinstance (o2_shape , (tuple , list )):
962- raise TypeError (
963- "Shape of second argument can not be inferred. "
964- "Expected list or tuple."
965- )
966- try :
967- res_shape = _broadcast_shape_impl (
968- [
969- o1_shape ,
970- o2_shape ,
971- ]
972- )
973- except ValueError :
974- raise ValueError (
975- "operands could not be broadcast together with shapes "
976- f"{ o1_shape } and { o2_shape } "
990+
991+ buf_dt , res_dt = _find_buf_dtype_in_place_op (
992+ o1_dtype ,
993+ o2_dtype ,
994+ self .result_type_resolver_fn_ ,
995+ sycl_dev ,
977996 )
978- if res_shape != o1_shape :
979- raise ValueError ("" )
980- sycl_dev = exec_q .sycl_device
981- o1_dtype = o1 .dtype
982- o2_dtype = _get_dtype (o2 , sycl_dev )
983- if not _validate_dtype (o2_dtype ):
984- raise ValueError ("Operand has an unsupported data type" )
985997
986- o1_dtype , o2_dtype = self .weak_type_resolver_ (
987- o1_dtype , o2_dtype , sycl_dev
988- )
998+ if res_dt is None :
999+ raise ValueError (
1000+ f"function '{ self .name_ } ' does not support input types "
1001+ f"({ o1_dtype } , { o2_dtype } ), "
1002+ "and the inputs could not be safely coerced to any "
1003+ "supported types according to the casting rule "
1004+ "''same_kind''."
1005+ )
9891006
990- buf_dt , res_dt = _find_buf_dtype_in_place_op (
991- o1_dtype ,
992- o2_dtype ,
993- self .result_type_resolver_fn_ ,
994- sycl_dev ,
995- )
1007+ if res_dt != o1_dtype :
1008+ raise ValueError (
1009+ f"Output array of type { res_dt } is needed, "
1010+ f"got { o1_dtype } "
1011+ )
9961012
997- if res_dt is None :
998- raise ValueError (
999- f"function '{ self .name_ } ' does not support input types "
1000- f"({ o1_dtype } , { o2_dtype } ), "
1001- "and the inputs could not be safely coerced to any "
1002- "supported types according to the casting rule ''same_kind''."
1003- )
1013+ _manager = SequentialOrderManager [exec_q ]
1014+ if isinstance (o2 , dpt .usm_ndarray ):
1015+ src2 = o2
1016+ if (
1017+ ti ._array_overlap (o2 , o1 )
1018+ and not ti ._same_logical_tensors (o2 , o1 )
1019+ and buf_dt is None
1020+ ):
1021+ buf_dt = o2_dtype
1022+ else :
1023+ src2 = dpt .asarray (o2 , dtype = o2_dtype , sycl_queue = exec_q )
1024+ if buf_dt is None :
1025+ if src2 .shape != res_shape :
1026+ src2 = dpt .broadcast_to (src2 , res_shape )
1027+ dep_evs = _manager .submitted_events
1028+ ht_ , comp_ev = self .binary_inplace_fn_ (
1029+ lhs = o1 ,
1030+ rhs = src2 ,
1031+ sycl_queue = exec_q ,
1032+ depends = dep_evs ,
1033+ )
1034+ _manager .add_event_pair (ht_ , comp_ev )
1035+ else :
1036+ buf = dpt .empty_like (src2 , dtype = buf_dt )
1037+ dep_evs = _manager .submitted_events
1038+ (
1039+ ht_copy_ev ,
1040+ copy_ev ,
1041+ ) = ti ._copy_usm_ndarray_into_usm_ndarray (
1042+ src = src2 ,
1043+ dst = buf ,
1044+ sycl_queue = exec_q ,
1045+ depends = dep_evs ,
1046+ )
1047+ _manager .add_event_pair (ht_copy_ev , copy_ev )
10041048
1005- if res_dt != o1_dtype :
1006- raise ValueError (
1007- f"Output array of type { res_dt } is needed, " f"got { o1_dtype } "
1008- )
1049+ buf = dpt .broadcast_to (buf , res_shape )
1050+ ht_ , bf_ev = self .binary_inplace_fn_ (
1051+ lhs = o1 ,
1052+ rhs = buf ,
1053+ sycl_queue = exec_q ,
1054+ depends = [copy_ev ],
1055+ )
1056+ _manager .add_event_pair (ht_ , bf_ev )
10091057
1010- _manager = SequentialOrderManager [exec_q ]
1011- if isinstance (o2 , dpt .usm_ndarray ):
1012- src2 = o2
1013- if (
1014- ti ._array_overlap (o2 , o1 )
1015- and not ti ._same_logical_tensors (o2 , o1 )
1016- and buf_dt is None
1017- ):
1018- buf_dt = o2_dtype
1019- else :
1020- src2 = dpt .asarray (o2 , dtype = o2_dtype , sycl_queue = exec_q )
1021- if buf_dt is None :
1022- if src2 .shape != res_shape :
1023- src2 = dpt .broadcast_to (src2 , res_shape )
1024- dep_evs = _manager .submitted_events
1025- ht_ , comp_ev = self .binary_inplace_fn_ (
1026- lhs = o1 ,
1027- rhs = src2 ,
1028- sycl_queue = exec_q ,
1029- depends = dep_evs ,
1030- )
1031- _manager .add_event_pair (ht_ , comp_ev )
1058+ return o1
10321059 else :
1033- buf = dpt .empty_like (src2 , dtype = buf_dt )
1034- dep_evs = _manager .submitted_events
1035- (
1036- ht_copy_ev ,
1037- copy_ev ,
1038- ) = ti ._copy_usm_ndarray_into_usm_ndarray (
1039- src = src2 ,
1040- dst = buf ,
1041- sycl_queue = exec_q ,
1042- depends = dep_evs ,
1043- )
1044- _manager .add_event_pair (ht_copy_ev , copy_ev )
1045-
1046- buf = dpt .broadcast_to (buf , res_shape )
1047- ht_ , bf_ev = self .binary_inplace_fn_ (
1048- lhs = o1 ,
1049- rhs = buf ,
1050- sycl_queue = exec_q ,
1051- depends = [copy_ev ],
1060+ raise ValueError (
1061+ "binary function does not have a dedicated in-place "
1062+ "implementation"
10521063 )
1053- _manager .add_event_pair (ht_ , bf_ev )
1054-
1055- return o1
0 commit comments