@@ -483,21 +483,22 @@ def _coerce_and_infer_dt(*args, dt, sycl_queue, err_msg, allow_bool=False):
483483
484484def _round_for_arange (tmp ):
485485 k = int (tmp )
486- if k > 0 and float (k ) < tmp :
486+ if k >= 0 and float (k ) < tmp :
487487 tmp = tmp + 1
488488 return tmp
489489
490490
491491def _get_arange_length (start , stop , step ):
492492 "Compute length of arange sequence"
493493 span = stop - start
494- if type (step ) in [ int , float ] and type (span ) in [ int , float ] :
494+ if hasattr (step , "__float__" ) and hasattr (span , "__float__" ) :
495495 return _round_for_arange (span / step )
496496 tmp = span / step
497- if type (tmp ) is complex and tmp .imag == 0 :
497+ if hasattr (tmp , "__complex__" ):
498+ tmp = complex (tmp )
498499 tmp = tmp .real
499500 else :
500- return tmp
501+ tmp = float ( tmp )
501502 return _round_for_arange (tmp )
502503
503504
@@ -536,13 +537,18 @@ def arange(
536537 if stop is None :
537538 stop = start
538539 start = 0
540+ if step is None :
541+ step = 1
539542 dpctl .utils .validate_usm_type (usm_type , allow_none = False )
540543 sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
541- (start , stop , step ,), dt = _coerce_and_infer_dt (
544+ is_bool = False
545+ if dtype :
546+ is_bool = (dtype is bool ) or (dpt .dtype (dtype ) == dpt .bool )
547+ (start_ , stop_ , step_ ), dt = _coerce_and_infer_dt (
542548 start ,
543549 stop ,
544550 step ,
545- dt = dtype ,
551+ dt = dpt . int8 if is_bool else dtype ,
546552 sycl_queue = sycl_queue ,
547553 err_msg = "start, stop, and step must be Python scalars" ,
548554 allow_bool = False ,
@@ -554,18 +560,40 @@ def arange(
554560 sh = 0
555561 except TypeError :
556562 sh = 0
563+ if is_bool and sh > 2 :
564+ raise ValueError ("no fill-function for boolean data type" )
557565 res = dpt .usm_ndarray (
558566 (sh ,),
559567 dtype = dt ,
560568 buffer = usm_type ,
561569 order = "C" ,
562570 buffer_ctor_kwargs = {"queue" : sycl_queue },
563571 )
564- _step = (start + step ) - start
565- _step = dt .type (_step )
566- _start = dt .type (start )
572+ sc_ty = dt .type
573+ _first = sc_ty (start )
574+ if sh > 1 :
575+ _second = sc_ty (start + step )
576+ if dt in [dpt .uint8 , dpt .uint16 , dpt .uint32 , dpt .uint64 ]:
577+ int64_ty = dpt .int64 .type
578+ _step = int64_ty (_second ) - int64_ty (_first )
579+ else :
580+ _step = _second - _first
581+ _step = sc_ty (_step )
582+ else :
583+ _step = sc_ty (1 )
584+ _start = _first
567585 hev , _ = ti ._linspace_step (_start , _step , res , sycl_queue )
568586 hev .wait ()
587+ if is_bool :
588+ res_out = dpt .usm_ndarray (
589+ (sh ,),
590+ dtype = dpt .bool ,
591+ buffer = usm_type ,
592+ order = "C" ,
593+ buffer_ctor_kwargs = {"queue" : sycl_queue },
594+ )
595+ res_out [:] = res
596+ res = res_out
569597 return res
570598
571599
0 commit comments