@@ -10,80 +10,78 @@ namespace xgboost {
1010namespace sycl {
1111
1212::sycl::queue* DeviceManager::GetQueue (const DeviceOrd& device_spec) const {
13- if (!device_spec.IsSycl ()) {
14- LOG (WARNING) << " Sycl kernel is executed with non-sycl context: "
15- << device_spec.Name () << " . "
16- << " Default sycl device_selector will be used." ;
17- }
13+ if (!device_spec.IsSycl ()) {
14+ LOG (WARNING) << " Sycl kernel is executed with non-sycl context: " << device_spec.Name () << " . "
15+ << " Default sycl device_selector will be used." ;
16+ }
1817
19- size_t queue_idx;
20- bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal ) ||
21- (collective::IsDistributed ());
22- DeviceRegister& device_register = GetDevicesRegister ();
23- if (not_use_default_selector) {
24- if (device_spec.IsSyclDefault ()) {
25- auto & devices = device_register.devices ;
26- const int device_idx = collective::IsDistributed ()
27- ? collective::GetRank () % devices.size ()
28- : device_spec.ordinal ;
29- CHECK_LT (device_idx, devices.size ());
30- queue_idx = device_idx;
31- } else if (device_spec.IsSyclCPU ()) {
32- auto & cpu_devices_idxes = device_register.cpu_devices_idxes ;
33- const int device_idx = collective::IsDistributed ()
34- ? collective::GetRank () % cpu_devices_idxes.size ()
35- : device_spec.ordinal ;
36- CHECK_LT (device_idx, cpu_devices_idxes.size ());
37- queue_idx = cpu_devices_idxes[device_idx];
38- } else if (device_spec.IsSyclGPU ()) {
39- auto & gpu_devices_idxes = device_register.gpu_devices_idxes ;
40- const int device_idx = collective::IsDistributed ()
41- ? collective::GetRank () % gpu_devices_idxes.size ()
42- : device_spec.ordinal ;
43- CHECK_LT (device_idx, gpu_devices_idxes.size ());
44- queue_idx = gpu_devices_idxes[device_idx];
45- } else {
46- LOG (WARNING) << device_spec << " is not sycl, sycl:cpu or sycl:gpu" ;
47- auto device = ::sycl::queue (::sycl::default_selector_v).get_device ();
48- queue_idx = device_register.devices .at (device);
49- }
18+ size_t queue_idx;
19+ bool not_use_default_selector =
20+ (device_spec.ordinal != kDefaultOrdinal ) || (collective::IsDistributed ());
21+ DeviceRegister& device_register = GetDevicesRegister ();
22+ if (not_use_default_selector) {
23+ if (device_spec.IsSyclDefault ()) {
24+ auto & devices = device_register.devices ;
25+ const int device_idx = collective::IsDistributed () ? collective::GetRank () % devices.size ()
26+ : device_spec.ordinal ;
27+ CHECK_LT (device_idx, devices.size ());
28+ queue_idx = device_idx;
29+ } else if (device_spec.IsSyclCPU ()) {
30+ auto & cpu_devices_idxes = device_register.cpu_devices_idxes ;
31+ const int device_idx = collective::IsDistributed ()
32+ ? collective::GetRank () % cpu_devices_idxes.size ()
33+ : device_spec.ordinal ;
34+ CHECK_LT (device_idx, cpu_devices_idxes.size ());
35+ queue_idx = cpu_devices_idxes[device_idx];
36+ } else if (device_spec.IsSyclGPU ()) {
37+ auto & gpu_devices_idxes = device_register.gpu_devices_idxes ;
38+ const int device_idx = collective::IsDistributed ()
39+ ? collective::GetRank () % gpu_devices_idxes.size ()
40+ : device_spec.ordinal ;
41+ CHECK_LT (device_idx, gpu_devices_idxes.size ());
42+ queue_idx = gpu_devices_idxes[device_idx];
43+ } else {
44+ LOG (WARNING) << device_spec << " is not sycl, sycl:cpu or sycl:gpu" ;
45+ auto device = ::sycl::queue (::sycl::default_selector_v).get_device ();
46+ queue_idx = device_register.devices .at (device);
47+ }
48+ } else {
49+ if (device_spec.IsSyclCPU ()) {
50+ auto device = ::sycl::queue (::sycl::cpu_selector_v).get_device ();
51+ queue_idx = device_register.devices .at (device);
52+ } else if (device_spec.IsSyclGPU ()) {
53+ auto device = ::sycl::queue (::sycl::gpu_selector_v).get_device ();
54+ queue_idx = device_register.devices .at (device);
5055 } else {
51- if (device_spec.IsSyclCPU ()) {
52- auto device = ::sycl::queue (::sycl::cpu_selector_v).get_device ();
53- queue_idx = device_register.devices .at (device);
54- } else if (device_spec.IsSyclGPU ()) {
55- auto device = ::sycl::queue (::sycl::gpu_selector_v).get_device ();
56- queue_idx = device_register.devices .at (device);
57- } else {
58- auto device = ::sycl::queue (::sycl::default_selector_v).get_device ();
59- queue_idx = device_register.devices .at (device);
60- }
56+ auto device = ::sycl::queue (::sycl::default_selector_v).get_device ();
57+ queue_idx = device_register.devices .at (device);
6158 }
62- return &(device_register.queues [queue_idx]);
59+ }
60+ return &(device_register.queues [queue_idx]);
6361}
6462
6563DeviceManager::DeviceRegister& DeviceManager::GetDevicesRegister () const {
66- static DeviceRegister device_register;
64+ static DeviceRegister device_register;
6765
68- if (device_register.devices .size () == 0 ) {
69- std::lock_guard<std::mutex> guard (device_registering_mutex);
70- std::vector<::sycl::device> devices = ::sycl::device::get_devices ();
71- for (size_t i = 0 ; i < devices.size (); i++) {
72- LOG (INFO) << " device_index = " << i << " , name = "
73- << devices[i].get_info <::sycl::info::device::name>();
74- }
66+ if (device_register.devices .size () == 0 ) {
67+ std::lock_guard<std::mutex> guard (device_registering_mutex);
68+ std::vector<::sycl::device> devices = ::sycl::device::get_devices ();
69+ for (size_t i = 0 ; i < devices.size (); i++) {
70+ LOG (INFO) << " device_index = " << i
71+ << " , name = " << devices[i].get_info <::sycl::info::device::name>();
72+ }
7573
76- for (size_t i = 0 ; i < devices.size (); i++) {
77- device_register.devices [devices[i]] = i;
78- device_register.queues .push_back (::sycl::queue (devices[i]));
79- if (devices[i].is_cpu ()) {
80- device_register.cpu_devices_idxes .push_back (i);
81- } else if (devices[i].is_gpu ()) {
82- device_register.gpu_devices_idxes .push_back (i);
83- }
84- }
74+ for (size_t i = 0 ; i < devices.size (); i++) {
75+ device_register.devices [devices[i]] = i;
76+ device_register.queues .push_back (::sycl::queue (devices[i]));
77+ if (devices[i].is_cpu ()) {
78+ device_register.cpu_devices_idxes .push_back (i);
79+ } else if (devices[i].is_gpu ()) {
80+ device_register.gpu_devices_idxes .push_back (i);
81+ }
8582 }
86- return device_register;
83+ }
84+ return device_register;
8785}
8886
8987} // namespace sycl
0 commit comments