@@ -1021,6 +1021,22 @@ def ones(
10211021 return res
10221022
10231023
1024+ def _cast_fill_val (fill_val , dt ):
1025+ """
1026+ Casts the Python scalar `fill_val` to another Python type coercible to the
1027+ requested data type `dt`, if necessary.
1028+ """
1029+ val_type = type (fill_val )
1030+ if val_type in [float , complex ] and np .issubdtype (dt , np .integer ):
1031+ return int (fill_val .real )
1032+ elif val_type is complex and np .issubdtype (dt , np .floating ):
1033+ return fill_val .real
1034+ elif val_type is int and np .issubdtype (dt , np .integer ):
1035+ return _to_scalar (fill_val , dt )
1036+ else :
1037+ return fill_val
1038+
1039+
10241040def full (
10251041 shape ,
10261042 fill_value ,
@@ -1097,21 +1113,15 @@ def full(
10971113
10981114 sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
10991115 usm_type = usm_type if usm_type is not None else "device"
1100- fill_value_type = type (fill_value )
1101- dtype = _get_dtype (dtype , sycl_queue , ref_type = fill_value_type )
1116+ dtype = _get_dtype (dtype , sycl_queue , ref_type = type (fill_value ))
11021117 res = dpt .usm_ndarray (
11031118 shape ,
11041119 dtype = dtype ,
11051120 buffer = usm_type ,
11061121 order = order ,
11071122 buffer_ctor_kwargs = {"queue" : sycl_queue },
11081123 )
1109- if fill_value_type in [float , complex ] and np .issubdtype (dtype , np .integer ):
1110- fill_value = int (fill_value .real )
1111- elif fill_value_type is complex and np .issubdtype (dtype , np .floating ):
1112- fill_value = fill_value .real
1113- elif fill_value_type is int and np .issubdtype (dtype , np .integer ):
1114- fill_value = _to_scalar (fill_value , dtype )
1124+ fill_value = _cast_fill_val (fill_value , dtype )
11151125
11161126 _manager = dpctl .utils .SequentialOrderManager [sycl_queue ]
11171127 # populating new allocation, no dependent events
@@ -1479,26 +1489,15 @@ def full_like(
14791489 )
14801490 _manager .add_event_pair (hev , copy_ev )
14811491 return res
1482- else :
1483- fill_value_type = type (fill_value )
1484- dtype = _get_dtype (dtype , sycl_queue , ref_type = fill_value_type )
1485- res = _empty_like_orderK (x , dtype , usm_type , sycl_queue )
1486- if fill_value_type in [float , complex ] and np .issubdtype (
1487- dtype , np .integer
1488- ):
1489- fill_value = int (fill_value .real )
1490- elif fill_value_type is complex and np .issubdtype (
1491- dtype , np .floating
1492- ):
1493- fill_value = fill_value .real
1494- elif fill_value_type is int and np .issubdtype (dtype , np .integer ):
1495- fill_value = _to_scalar (fill_value , dtype )
14961492
1497- _manager = dpctl .utils .SequentialOrderManager [sycl_queue ]
1498- # populating new allocation, no dependent events
1499- hev , full_ev = ti ._full_usm_ndarray (fill_value , res , sycl_queue )
1500- _manager .add_event_pair (hev , full_ev )
1501- return res
1493+ dtype = _get_dtype (dtype , sycl_queue , ref_type = type (fill_value ))
1494+ res = _empty_like_orderK (x , dtype , usm_type , sycl_queue )
1495+ fill_value = _cast_fill_val (fill_value , dtype )
1496+ _manager = dpctl .utils .SequentialOrderManager [sycl_queue ]
1497+ # populating new allocation, no dependent events
1498+ hev , full_ev = ti ._full_usm_ndarray (fill_value , res , sycl_queue )
1499+ _manager .add_event_pair (hev , full_ev )
1500+ return res
15021501 else :
15031502 return full (
15041503 sh ,
0 commit comments