@@ -1247,19 +1247,31 @@ def tril(X, k=0):
12471247
12481248 if k >= shape [nd - 1 ] - 1 :
12491249 res = dpt .empty (
1250- X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1250+ X .shape ,
1251+ dtype = X .dtype ,
1252+ order = order ,
1253+ usm_type = X .usm_type ,
1254+ sycl_queue = X .sycl_queue ,
12511255 )
12521256 hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
12531257 src = X , dst = res , sycl_queue = X .sycl_queue
12541258 )
12551259 hev .wait ()
12561260 elif k < - shape [nd - 2 ]:
12571261 res = dpt .zeros (
1258- X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1262+ X .shape ,
1263+ dtype = X .dtype ,
1264+ order = order ,
1265+ usm_type = X .usm_type ,
1266+ sycl_queue = X .sycl_queue ,
12591267 )
12601268 else :
12611269 res = dpt .empty (
1262- X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1270+ X .shape ,
1271+ dtype = X .dtype ,
1272+ order = order ,
1273+ usm_type = X .usm_type ,
1274+ sycl_queue = X .sycl_queue ,
12631275 )
12641276 hev , _ = ti ._tril (src = X , dst = res , k = k , sycl_queue = X .sycl_queue )
12651277 hev .wait ()
@@ -1290,19 +1302,31 @@ def triu(X, k=0):
12901302
12911303 if k > shape [nd - 1 ]:
12921304 res = dpt .zeros (
1293- X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1305+ X .shape ,
1306+ dtype = X .dtype ,
1307+ order = order ,
1308+ usm_type = X .usm_type ,
1309+ sycl_queue = X .sycl_queue ,
12941310 )
12951311 elif k <= - shape [nd - 2 ] + 1 :
12961312 res = dpt .empty (
1297- X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1313+ X .shape ,
1314+ dtype = X .dtype ,
1315+ order = order ,
1316+ usm_type = X .usm_type ,
1317+ sycl_queue = X .sycl_queue ,
12981318 )
12991319 hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
13001320 src = X , dst = res , sycl_queue = X .sycl_queue
13011321 )
13021322 hev .wait ()
13031323 else :
13041324 res = dpt .empty (
1305- X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1325+ X .shape ,
1326+ dtype = X .dtype ,
1327+ order = order ,
1328+ usm_type = X .usm_type ,
1329+ sycl_queue = X .sycl_queue ,
13061330 )
13071331 hev , _ = ti ._triu (src = X , dst = res , k = k , sycl_queue = X .sycl_queue )
13081332 hev .wait ()
0 commit comments