@@ -389,45 +389,75 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
389389 return R
390390
391391
392- def _mock_extract (ary , ary_mask , p ):
393- exec_q = dpctl .utils .get_execution_queue (
394- (
395- ary .sycl_queue ,
396- ary_mask .sycl_queue ,
392+ def _extract_impl (ary , ary_mask , axis = 0 ):
393+ """Extract elements of ary by applying mask starting from slot
394+ dimension axis"""
395+ if not isinstance (ary , dpt .usm_ndarray ):
396+ raise TypeError (
397+ f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary )} "
398+ )
399+ if not isinstance (ary_mask , dpt .usm_ndarray ):
400+ raise TypeError (
401+ f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary_mask )} "
397402 )
403+ exec_q = dpctl .utils .get_execution_queue (
404+ (ary .sycl_queue , ary_mask .sycl_queue )
398405 )
399406 if exec_q is None :
400407 raise dpctl .utils .ExecutionPlacementError (
401- "Can not automatically determine where to allocate the "
402- "result or performance execution. "
403- "Use `usm_ndarray.to_device` method to migrate data to "
404- "be associated with the same queue."
408+ "arrays have different associated queues. "
409+ "Use `Y.to_device(X.device)` to migrate."
405410 )
406-
407- res_usm_type = dpctl .utils .get_coerced_usm_type (
408- (
409- ary .usm_type ,
410- ary_mask .usm_type ,
411+ ary_nd = ary .ndim
412+ pp = normalize_axis_index (operator .index (axis ), ary_nd )
413+ mask_nd = ary_mask .ndim
414+ if pp < 0 or pp + mask_nd > ary_nd :
415+ raise ValueError (
416+ "Parameter p is inconsistent with input array dimensions"
411417 )
418+ mask_nelems = ary_mask .size
419+ cumsum = dpt .empty (mask_nelems , dtype = dpt .int64 , device = ary_mask .device )
420+ exec_q = cumsum .sycl_queue
421+ mask_count = ti .mask_positions (ary_mask , cumsum , sycl_queue = exec_q )
422+ dst_shape = ary .shape [:pp ] + (mask_count ,) + ary .shape [pp + mask_nd :]
423+ dst = dpt .empty (
424+ dst_shape , dtype = ary .dtype , usm_type = ary .usm_type , device = ary .device
412425 )
413- ary_np = dpt .asnumpy (ary )
414- mask_np = dpt .asnumpy (ary_mask )
415- res_np = ary_np [(slice (None ),) * p + (mask_np ,)]
416- res = dpt .empty (
417- res_np .shape , dtype = ary .dtype , usm_type = res_usm_type , sycl_queue = exec_q
426+ hev , _ = ti ._extract (
427+ src = ary ,
428+ cumsum = cumsum ,
429+ axis_start = pp ,
430+ axis_end = pp + mask_nd ,
431+ dst = dst ,
432+ sycl_queue = exec_q ,
418433 )
419- res [...] = res_np
420- return res
434+ hev . wait ()
435+ return dst
421436
422437
423- def _mock_nonzero (ary ):
438+ def _nonzero_impl (ary ):
424439 if not isinstance (ary , dpt .usm_ndarray ):
425- raise TypeError
426- q = ary .sycl_queue
440+ raise TypeError (
441+ f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary )} "
442+ )
443+ exec_q = ary .sycl_queue
427444 usm_type = ary .usm_type
428- ary_np = dpt .asnumpy (ary )
429- nz = ary_np .nonzero ()
430- return tuple (dpt .asarray (i , usm_type = usm_type , sycl_queue = q ) for i in nz )
445+ mask_nelems = ary .size
446+ cumsum = dpt .empty (
447+ mask_nelems , dtype = dpt .int64 , sycl_queue = exec_q , order = "C"
448+ )
449+ mask_count = ti .mask_positions (ary , cumsum , sycl_queue = exec_q )
450+ indexes = dpt .empty (
451+ (ary .ndim , mask_count ),
452+ dtype = cumsum .dtype ,
453+ usm_type = usm_type ,
454+ sycl_queue = exec_q ,
455+ order = "C" ,
456+ )
457+ hev , _ = ti ._nonzero (cumsum , indexes , ary .shape , exec_q )
458+ res = tuple (indexes [i , :] for i in range (ary .ndim ))
459+ hev .wait ()
460+ return res
431461
432462
433463def _take_multi_index (ary , inds , p ):
@@ -473,34 +503,57 @@ def _take_multi_index(ary, inds, p):
473503 return res
474504
475505
476- def _mock_place (ary , ary_mask , p , vals ):
506+ def _place_impl (ary , ary_mask , vals , axis = 0 ):
507+ """Extract elements of ary by applying mask starting from slot
508+ dimension axis"""
477509 if not isinstance (ary , dpt .usm_ndarray ):
478- raise TypeError
510+ raise TypeError (
511+ f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary )} "
512+ )
479513 if not isinstance (ary_mask , dpt .usm_ndarray ):
480- raise TypeError
514+ raise TypeError (
515+ f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary_mask )} "
516+ )
517+ if not isinstance (vals , dpt .usm_ndarray ):
518+ raise TypeError (
519+ f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary_mask )} "
520+ )
481521 exec_q = dpctl .utils .get_execution_queue (
482- (ary .sycl_queue , ary_mask .sycl_queue )
522+ (ary .sycl_queue , ary_mask .sycl_queue , vals . sycl_queue )
483523 )
484- if exec_q is not None and isinstance (vals , dpt .usm_ndarray ):
485- exec_q = dpctl .utils .get_execution_queue ((exec_q , vals .sycl_queue ))
486524 if exec_q is None :
487525 raise dpctl .utils .ExecutionPlacementError (
488- "Can not automatically determine where to allocate the "
489- "result or performance execution. "
490- "Use `usm_ndarray.to_device` method to migrate data to "
491- "be associated with the same queue."
526+ "arrays have different associated queues. "
527+ "Use `Y.to_device(X.device)` to migrate."
492528 )
493-
494- ary_np = dpt .asnumpy (ary )
495- mask_np = dpt .asnumpy (ary_mask )
496- if isinstance (vals , dpt .usm_ndarray ) or hasattr (
497- vals , "__sycl_usm_array_interface__"
498- ):
499- vals_np = dpt .asnumpy (vals )
529+ ary_nd = ary .ndim
530+ pp = normalize_axis_index (operator .index (axis ), ary_nd )
531+ mask_nd = ary_mask .ndim
532+ if pp < 0 or pp + mask_nd > ary_nd :
533+ raise ValueError (
534+ "Parameter p is inconsistent with input array dimensions"
535+ )
536+ mask_nelems = ary_mask .size
537+ cumsum = dpt .empty (mask_nelems , dtype = dpt .int64 , device = ary_mask .device )
538+ exec_q = cumsum .sycl_queue
539+ mask_count = ti .mask_positions (ary_mask , cumsum , sycl_queue = exec_q )
540+ expected_vals_shape = (
541+ ary .shape [:pp ] + (mask_count ,) + ary .shape [pp + mask_nd :]
542+ )
543+ if vals .dtype == ary .dtype :
544+ rhs = vals
500545 else :
501- vals_np = vals
502- ary_np [(slice (None ),) * p + (mask_np ,)] = vals_np
503- ary [...] = ary_np
546+ rhs = dpt .astype (vals , ary .dtype )
547+ rhs = dpt .broadcast_to (rhs , expected_vals_shape )
548+ hev , _ = ti ._place (
549+ dst = ary ,
550+ cumsum = cumsum ,
551+ axis_start = pp ,
552+ axis_end = pp + mask_nd ,
553+ rhs = rhs ,
554+ sycl_queue = exec_q ,
555+ )
556+ hev .wait ()
504557 return
505558
506559
0 commit comments