@@ -86,20 +86,23 @@ def test_get_current_device_type_outside_device_ctxt (self):
8686 def test_get_current_device_type_inside_device_ctxt (self ):
8787 self .assertEqual (dpctl .get_current_device_type (), None )
8888
89- with dpctl .device_context (dpctl . device_type . gpu ):
89+ with dpctl .device_context ("opencl: gpu:0" ):
9090 self .assertEqual (dpctl .get_current_device_type (), dpctl .device_type .gpu )
9191
9292 self .assertEqual (dpctl .get_current_device_type (), None )
9393
94- @unittest .skipIf (not dpctl .has_cpu_queues (), "No CPU platforms available" )
94+ @unittest .skipUnless (dpctl .get_num_queues (backend_ty = "opencl" ,
95+ device_ty = "cpu" ) > 0 ,
96+ "No OpenCL CPU queues available" )
9597 def test_get_current_device_type_inside_nested_device_ctxt (self ):
9698 self .assertEqual (dpctl .get_current_device_type (), None )
9799
98- with dpctl .device_context (dpctl . device_type . cpu ):
100+ with dpctl .device_context ("opencl: cpu:0" ):
99101 self .assertEqual (dpctl .get_current_device_type (), dpctl .device_type .cpu )
100102
101- with dpctl .device_context (dpctl .device_type .gpu ):
102- self .assertEqual (dpctl .get_current_device_type (), dpctl .device_type .gpu )
103+ with dpctl .device_context ("opencl:gpu:0" ):
104+ self .assertEqual (dpctl .get_current_device_type (), dpctl .
105+ device_type .gpu )
103106 self .assertEqual (dpctl .get_current_device_type (), dpctl .device_type .cpu )
104107
105108 self .assertEqual (dpctl .get_current_device_type (), None )
0 commit comments