@@ -47,9 +47,33 @@ def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
4747 self .unary_fn_ = unary_dp_impl_fn
4848 self .__doc__ = docs
4949
50- def __call__ (self , x , order = "K" ):
50+ def __call__ (self , x , out = None , order = "K" ):
5151 if not isinstance (x , dpt .usm_ndarray ):
5252 raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
53+
54+ if out is not None :
55+ if not isinstance (out , dpt .usm_ndarray ):
56+ raise TypeError (
57+ f"output array must be of usm_ndarray type, got { type (out )} "
58+ )
59+
60+ if out .shape != x .shape :
61+ raise TypeError (
62+ "The shape of input and output arrays are inconsistent."
63+ f"Expected output shape is { x .shape } , got { out .shape } "
64+ )
65+
66+ if ti ._array_overlap (x , out ):
67+ raise TypeError ("Input and output arrays have memory overlap" )
68+
69+ if (
70+ dpctl .utils .get_execution_queue ((x .sycl_queue , out .sycl_queue ))
71+ is None
72+ ):
73+ raise TypeError (
74+ "Input and output allocation queues are not compatible"
75+ )
76+
5377 if order not in ["C" , "F" , "K" , "A" ]:
5478 order = "K"
5579 buf_dt , res_dt = _find_buf_dtype (
@@ -59,17 +83,24 @@ def __call__(self, x, order="K"):
5983 raise RuntimeError
6084 exec_q = x .sycl_queue
6185 if buf_dt is None :
62- if order == "K" :
63- r = _empty_like_orderK (x , res_dt )
86+ if out is None :
87+ if order == "K" :
88+ out = _empty_like_orderK (x , res_dt )
89+ else :
90+ if order == "A" :
91+ order = "F" if x .flags .f_contiguous else "C"
92+ out = dpt .empty_like (x , dtype = res_dt , order = order )
6493 else :
65- if order == "A" :
66- order = "F" if x .flags .f_contiguous else "C"
67- r = dpt .empty_like (x , dtype = res_dt , order = order )
94+ if res_dt != out .dtype :
95+ raise TypeError (
96+ f"Expected output array of type { res_dt } is supported"
97+ f", got { out .dtype } "
98+ )
6899
69- ht , _ = self .unary_fn_ (x , r , sycl_queue = exec_q )
100+ ht , _ = self .unary_fn_ (x , out , sycl_queue = exec_q )
70101 ht .wait ()
71102
72- return r
103+ return out
73104 if order == "K" :
74105 buf = _empty_like_orderK (x , buf_dt )
75106 else :
@@ -80,16 +111,23 @@ def __call__(self, x, order="K"):
80111 ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
81112 src = x , dst = buf , sycl_queue = exec_q
82113 )
83- if order == "K" :
84- r = _empty_like_orderK (buf , res_dt )
114+ if out is None :
115+ if order == "K" :
116+ out = _empty_like_orderK (buf , res_dt )
117+ else :
118+ out = dpt .empty_like (buf , dtype = res_dt , order = order )
85119 else :
86- r = dpt .empty_like (buf , dtype = res_dt , order = order )
120+ if buf_dt != out .dtype :
121+ raise TypeError (
122+ f"Expected output array of type { buf_dt } is supported,"
123+ f"got { out .dtype } "
124+ )
87125
88- ht , _ = self .unary_fn_ (buf , r , sycl_queue = exec_q , depends = [copy_ev ])
126+ ht , _ = self .unary_fn_ (buf , out , sycl_queue = exec_q , depends = [copy_ev ])
89127 ht_copy_ev .wait ()
90128 ht .wait ()
91129
92- return r
130+ return out
93131
94132
95133def _get_queue_usm_type (o ):
@@ -281,7 +319,7 @@ def __str__(self):
281319 def __repr__ (self ):
282320 return f"<BinaryElementwiseFunc '{ self .name_ } '>"
283321
284- def __call__ (self , o1 , o2 , order = "K" ):
322+ def __call__ (self , o1 , o2 , out = None , order = "K" ):
285323 if order not in ["K" , "C" , "F" , "A" ]:
286324 order = "K"
287325 q1 , o1_usm_type = _get_queue_usm_type (o1 )
@@ -358,6 +396,31 @@ def __call__(self, o1, o2, order="K"):
358396 "supported types according to the casting rule ''safe''."
359397 )
360398
399+ if out is not None :
400+ if not isinstance (out , dpt .usm_ndarray ):
401+ raise TypeError (
402+ f"output array must be of usm_ndarray type, got { type (out )} "
403+ )
404+
405+ if out .shape != o1_shape or out .shape != o2_shape :
406+ raise TypeError (
407+ "The shape of input and output arrays are inconsistent."
408+ f"Expected output shape is { o1_shape } , got { out .shape } "
409+ )
410+
411+ if ti ._array_overlap (o1 , out ) or ti ._array_overlap (o2 , out ):
412+ raise TypeError ("Input and output arrays have memory overlap" )
413+
414+ if (
415+ dpctl .utils .get_execution_queue (
416+ (o1 .sycl_queue , o2 .sycl_queue , out .sycl_queue )
417+ )
418+ is None
419+ ):
420+ raise TypeError (
421+ "Input and output allocation queues are not compatible"
422+ )
423+
361424 if isinstance (o1 , dpt .usm_ndarray ):
362425 src1 = o1
363426 else :
@@ -368,37 +431,45 @@ def __call__(self, o1, o2, order="K"):
368431 src2 = dpt .asarray (o2 , dtype = o2_dtype , sycl_queue = exec_q )
369432
370433 if buf1_dt is None and buf2_dt is None :
371- if order == "K" :
372- r = _empty_like_pair_orderK (
373- src1 , src2 , res_dt , res_usm_type , exec_q
374- )
375- else :
376- if order == "A" :
377- order = (
378- "F"
379- if all (
380- arr .flags .f_contiguous
381- for arr in (
382- src1 ,
383- src2 ,
434+ if out is None :
435+ if order == "K" :
436+ out = _empty_like_pair_orderK (
437+ src1 , src2 , res_dt , res_usm_type , exec_q
438+ )
439+ else :
440+ if order == "A" :
441+ order = (
442+ "F"
443+ if all (
444+ arr .flags .f_contiguous
445+ for arr in (
446+ src1 ,
447+ src2 ,
448+ )
384449 )
450+ else "C"
385451 )
386- else "C"
452+ out = dpt .empty (
453+ res_shape ,
454+ dtype = res_dt ,
455+ usm_type = res_usm_type ,
456+ sycl_queue = exec_q ,
457+ order = order ,
387458 )
388- r = dpt . empty (
389- res_shape ,
390- dtype = res_dt ,
391- usm_type = res_usm_type ,
392- sycl_queue = exec_q ,
393- order = order ,
394- )
459+ else :
460+ if res_dt != out . dtype :
461+ raise TypeError (
462+ f"Output array of type { res_dt } is needed,"
463+ f"got { out . dtype } "
464+ )
465+
395466 src1 = dpt .broadcast_to (src1 , res_shape )
396467 src2 = dpt .broadcast_to (src2 , res_shape )
397468 ht_ , _ = self .binary_fn_ (
398- src1 = src1 , src2 = src2 , dst = r , sycl_queue = exec_q
469+ src1 = src1 , src2 = src2 , dst = out , sycl_queue = exec_q
399470 )
400471 ht_ .wait ()
401- return r
472+ return out
402473 elif buf1_dt is None :
403474 if order == "K" :
404475 buf2 = _empty_like_orderK (src2 , buf2_dt )
@@ -409,30 +480,38 @@ def __call__(self, o1, o2, order="K"):
409480 ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
410481 src = src2 , dst = buf2 , sycl_queue = exec_q
411482 )
412- if order == "K" :
413- r = _empty_like_pair_orderK (
414- src1 , buf2 , res_dt , res_usm_type , exec_q
415- )
483+ if out is None :
484+ if order == "K" :
485+ out = _empty_like_pair_orderK (
486+ src1 , buf2 , res_dt , res_usm_type , exec_q
487+ )
488+ else :
489+ out = dpt .empty (
490+ res_shape ,
491+ dtype = res_dt ,
492+ usm_type = res_usm_type ,
493+ sycl_queue = exec_q ,
494+ order = order ,
495+ )
416496 else :
417- r = dpt .empty (
418- res_shape ,
419- dtype = res_dt ,
420- usm_type = res_usm_type ,
421- sycl_queue = exec_q ,
422- order = order ,
423- )
497+ if res_dt != out .dtype :
498+ raise TypeError (
499+ f"Output array of type { res_dt } is needed,"
500+ f"got { out .dtype } "
501+ )
502+
424503 src1 = dpt .broadcast_to (src1 , res_shape )
425504 buf2 = dpt .broadcast_to (buf2 , res_shape )
426505 ht_ , _ = self .binary_fn_ (
427506 src1 = src1 ,
428507 src2 = buf2 ,
429- dst = r ,
508+ dst = out ,
430509 sycl_queue = exec_q ,
431510 depends = [copy_ev ],
432511 )
433512 ht_copy_ev .wait ()
434513 ht_ .wait ()
435- return r
514+ return out
436515 elif buf2_dt is None :
437516 if order == "K" :
438517 buf1 = _empty_like_orderK (src1 , buf1_dt )
@@ -443,30 +522,38 @@ def __call__(self, o1, o2, order="K"):
443522 ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
444523 src = src1 , dst = buf1 , sycl_queue = exec_q
445524 )
446- if order == "K" :
447- r = _empty_like_pair_orderK (
448- buf1 , src2 , res_dt , res_usm_type , exec_q
449- )
525+ if out is None :
526+ if order == "K" :
527+ out = _empty_like_pair_orderK (
528+ buf1 , src2 , res_dt , res_usm_type , exec_q
529+ )
530+ else :
531+ out = dpt .empty (
532+ res_shape ,
533+ dtype = res_dt ,
534+ usm_type = res_usm_type ,
535+ sycl_queue = exec_q ,
536+ order = order ,
537+ )
450538 else :
451- r = dpt .empty (
452- res_shape ,
453- dtype = res_dt ,
454- usm_type = res_usm_type ,
455- sycl_queue = exec_q ,
456- order = order ,
457- )
539+ if res_dt != out .dtype :
540+ raise TypeError (
541+ f"Output array of type { res_dt } is needed,"
542+ f"got { out .dtype } "
543+ )
544+
458545 buf1 = dpt .broadcast_to (buf1 , res_shape )
459546 src2 = dpt .broadcast_to (src2 , res_shape )
460547 ht_ , _ = self .binary_fn_ (
461548 src1 = buf1 ,
462549 src2 = src2 ,
463- dst = r ,
550+ dst = out ,
464551 sycl_queue = exec_q ,
465552 depends = [copy_ev ],
466553 )
467554 ht_copy_ev .wait ()
468555 ht_ .wait ()
469- return r
556+ return out
470557
471558 if order in ["K" , "A" ]:
472559 if src1 .flags .f_contiguous and src2 .flags .f_contiguous :
@@ -489,26 +576,33 @@ def __call__(self, o1, o2, order="K"):
489576 ht_copy2_ev , copy2_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
490577 src = src2 , dst = buf2 , sycl_queue = exec_q
491578 )
492- if order == "K" :
493- r = _empty_like_pair_orderK (
494- buf1 , buf2 , res_dt , res_usm_type , exec_q
495- )
579+ if out is None :
580+ if order == "K" :
581+ out = _empty_like_pair_orderK (
582+ buf1 , buf2 , res_dt , res_usm_type , exec_q
583+ )
584+ else :
585+ out = dpt .empty (
586+ res_shape ,
587+ dtype = res_dt ,
588+ usm_type = res_usm_type ,
589+ sycl_queue = exec_q ,
590+ order = order ,
591+ )
496592 else :
497- r = dpt .empty (
498- res_shape ,
499- dtype = res_dt ,
500- usm_type = res_usm_type ,
501- sycl_queue = exec_q ,
502- order = order ,
503- )
593+ if res_dt != out .dtype :
594+ raise TypeError (
595+ f"Output array of type { res_dt } is needed, got { out .dtype } "
596+ )
597+
504598 buf1 = dpt .broadcast_to (buf1 , res_shape )
505599 buf2 = dpt .broadcast_to (buf2 , res_shape )
506600 ht_ , _ = self .binary_fn_ (
507601 src1 = buf1 ,
508602 src2 = buf2 ,
509- dst = r ,
603+ dst = out ,
510604 sycl_queue = exec_q ,
511605 depends = [copy1_ev , copy2_ev ],
512606 )
513607 dpctl .SyclEvent .wait_for ([ht_copy1_ev , ht_copy2_ev , ht_ ])
514- return r
608+ return out
0 commit comments