@@ -366,11 +366,11 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
366366 skip_if_dtype_not_supported (op1_dtype , q )
367367 skip_if_dtype_not_supported (op2_dtype , q )
368368
369- if dpt .can_cast (op2_dtype , op1_dtype , casting = "safe" ):
370- sz = 127
371- ar1 = dpt .ones (sz , dtype = op1_dtype )
372- ar2 = dpt .ones_like (ar1 , dtype = op2_dtype )
369+ sz = 127
370+ ar1 = dpt .ones (sz , dtype = op1_dtype )
371+ ar2 = dpt .ones_like (ar1 , dtype = op2_dtype )
373372
373+ if dpt .can_cast (op2_dtype , op1_dtype , casting = "safe" ):
374374 ar1 += ar2
375375 assert (
376376 dpt .asnumpy (ar1 ) == np .full (ar1 .shape , 2 , dtype = ar1 .dtype )
@@ -385,7 +385,8 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
385385 ).all ()
386386
387387 else :
388- assert pytest .raises (TypeError )
388+ with pytest .raises (TypeError ):
389+ ar1 += ar2
389390
390391
391392def test_add_inplace_broadcasting ():
@@ -396,3 +397,40 @@ def test_add_inplace_broadcasting():
396397
397398 m += v
398399 assert (dpt .asnumpy (m ) == np .arange (1 , 6 , dtype = "i4" )[np .newaxis , :]).all ()
400+
401+
402+ def test_add_inplace_errors ():
403+ get_queue_or_skip ()
404+ try :
405+ gpu_queue = dpctl .SyclQueue ("gpu" )
406+ except dpctl .SyclQueueCreationError :
407+ pytest .skip ("SyclQueue('gpu') failed, skipping" )
408+ try :
409+ cpu_queue = dpctl .SyclQueue ("cpu" )
410+ except dpctl .SyclQueueCreationError :
411+ pytest .skip ("SyclQueue('cpu') failed, skipping" )
412+
413+ ar1 = dpt .ones (2 , dtype = "float32" , sycl_queue = gpu_queue )
414+ ar2 = dpt .ones_like (ar1 , sycl_queue = cpu_queue )
415+ with pytest .raises (ExecutionPlacementError ):
416+ ar1 += ar2
417+
418+ ar1 = dpt .ones (2 , dtype = "float32" )
419+ ar2 = dpt .ones (3 , dtype = "float32" )
420+ with pytest .raises (ValueError ):
421+ ar1 += ar2
422+
423+ ar1 = np .ones (2 , dtype = "float32" )
424+ ar2 = dpt .ones (2 , dtype = "float32" )
425+ with pytest .raises (TypeError ):
426+ ar1 += ar2
427+
428+ ar1 = dpt .ones (2 , dtype = "float32" )
429+ ar2 = dict ()
430+ with pytest .raises (ValueError ):
431+ ar1 += ar2
432+
433+ ar1 = dpt .ones ((2 , 1 ), dtype = "float32" )
434+ ar2 = dpt .ones ((1 , 2 ), dtype = "float32" )
435+ with pytest .raises (ValueError ):
436+ ar1 += ar2
0 commit comments