|
26 | 26 | #include <algorithm> |
27 | 27 | #include <complex> |
28 | 28 | #include <cstdint> |
29 | | -#include <pybind11/complex.h> |
30 | 29 | #include <pybind11/pybind11.h> |
31 | 30 | #include <pybind11/stl.h> |
32 | 31 | #include <thread> |
|
37 | 36 | #include "copy_and_cast_usm_to_usm.hpp" |
38 | 37 | #include "copy_for_reshape.hpp" |
39 | 38 | #include "copy_numpy_ndarray_into_usm_ndarray.hpp" |
| 39 | +#include "device_support_queries.hpp" |
40 | 40 | #include "eye_ctor.hpp" |
41 | 41 | #include "full_ctor.hpp" |
42 | 42 | #include "linear_sequences.hpp" |
@@ -102,36 +102,6 @@ void init_dispatch_vectors(void) |
102 | 102 | return; |
103 | 103 | } |
104 | 104 |
|
105 | | -std::string get_default_device_fp_type(sycl::device d) |
106 | | -{ |
107 | | - if (d.has(sycl::aspect::fp64)) { |
108 | | - return "f8"; |
109 | | - } |
110 | | - else { |
111 | | - return "f4"; |
112 | | - } |
113 | | -} |
114 | | - |
115 | | -std::string get_default_device_int_type(sycl::device) |
116 | | -{ |
117 | | - return "i8"; |
118 | | -} |
119 | | - |
120 | | -std::string get_default_device_complex_type(sycl::device d) |
121 | | -{ |
122 | | - if (d.has(sycl::aspect::fp64)) { |
123 | | - return "c16"; |
124 | | - } |
125 | | - else { |
126 | | - return "c8"; |
127 | | - } |
128 | | -} |
129 | | - |
130 | | -std::string get_default_device_bool_type(sycl::device) |
131 | | -{ |
132 | | - return "b1"; |
133 | | -} |
134 | | - |
135 | 105 | } // namespace |
136 | 106 |
|
137 | 107 | PYBIND11_MODULE(_tensor_impl, m) |
@@ -209,57 +179,43 @@ PYBIND11_MODULE(_tensor_impl, m) |
209 | 179 | py::arg("k"), py::arg("dst"), py::arg("sycl_queue"), |
210 | 180 | py::arg("depends") = py::list()); |
211 | 181 |
|
212 | | - m.def("default_device_fp_type", [](sycl::queue q) -> std::string { |
213 | | - return get_default_device_fp_type(q.get_device()); |
214 | | - }); |
215 | | - m.def("default_device_fp_type_device", [](sycl::device dev) -> std::string { |
216 | | - return get_default_device_fp_type(dev); |
217 | | - }); |
218 | | - |
219 | | - m.def("default_device_int_type", [](sycl::queue q) -> std::string { |
220 | | - return get_default_device_int_type(q.get_device()); |
221 | | - }); |
222 | | - m.def("default_device_int_type_device", |
223 | | - [](sycl::device dev) -> std::string { |
224 | | - return get_default_device_int_type(dev); |
225 | | - }); |
226 | | - |
227 | | - m.def("default_device_bool_type", [](sycl::queue q) -> std::string { |
228 | | - return get_default_device_bool_type(q.get_device()); |
229 | | - }); |
230 | | - m.def("default_device_bool_type_device", |
231 | | - [](sycl::device dev) -> std::string { |
232 | | - return get_default_device_bool_type(dev); |
233 | | - }); |
234 | | - |
235 | | - m.def("default_device_complex_type", [](sycl::queue q) -> std::string { |
236 | | - return get_default_device_complex_type(q.get_device()); |
237 | | - }); |
238 | | - m.def("default_device_complex_type_device", |
239 | | - [](sycl::device dev) -> std::string { |
240 | | - return get_default_device_complex_type(dev); |
241 | | - }); |
242 | | - m.def( |
243 | | - "_tril", |
244 | | - [](dpctl::tensor::usm_ndarray src, dpctl::tensor::usm_ndarray dst, |
245 | | - py::ssize_t k, sycl::queue exec_q, |
246 | | - const std::vector<sycl::event> depends) |
247 | | - -> std::pair<sycl::event, sycl::event> { |
248 | | - return usm_ndarray_triul(exec_q, src, dst, 'l', k, depends); |
249 | | - }, |
250 | | - "Tril helper function.", py::arg("src"), py::arg("dst"), |
251 | | - py::arg("k") = 0, py::arg("sycl_queue"), |
252 | | - py::arg("depends") = py::list()); |
| 182 | + m.def("default_device_fp_type", |
| 183 | + dpctl::tensor::py_internal::default_device_fp_type, |
| 184 | + "Gives default floating point type supported by device.", |
| 185 | + py::arg("dev")); |
| 186 | + |
| 187 | + m.def("default_device_int_type", |
| 188 | + dpctl::tensor::py_internal::default_device_int_type, |
| 189 | + "Gives default integer type supported by device.", py::arg("dev")); |
| 190 | + |
| 191 | + m.def("default_device_bool_type", |
| 192 | + dpctl::tensor::py_internal::default_device_bool_type, |
| 193 | + "Gives default boolean type supported by device.", py::arg("dev")); |
| 194 | + |
| 195 | + m.def("default_device_complex_type", |
| 196 | + dpctl::tensor::py_internal::default_device_complex_type, |
| 197 | + "Gives default complex floating point type support by device.", |
| 198 | + py::arg("dev")); |
| 199 | + |
| 200 | + auto tril_fn = [](dpctl::tensor::usm_ndarray src, |
| 201 | + dpctl::tensor::usm_ndarray dst, py::ssize_t k, |
| 202 | + sycl::queue exec_q, |
| 203 | + const std::vector<sycl::event> depends) |
| 204 | + -> std::pair<sycl::event, sycl::event> { |
| 205 | + return usm_ndarray_triul(exec_q, src, dst, 'l', k, depends); |
| 206 | + }; |
| 207 | + m.def("_tril", tril_fn, "Tril helper function.", py::arg("src"), |
| 208 | + py::arg("dst"), py::arg("k") = 0, py::arg("sycl_queue"), |
| 209 | + py::arg("depends") = py::list()); |
253 | 210 |
|
254 | | - m.def( |
255 | | - "_triu", |
256 | | - [](dpctl::tensor::usm_ndarray src, dpctl::tensor::usm_ndarray dst, |
257 | | - py::ssize_t k, sycl::queue exec_q, |
258 | | - const std::vector<sycl::event> depends) |
259 | | - -> std::pair<sycl::event, sycl::event> { |
260 | | - return usm_ndarray_triul(exec_q, src, dst, 'u', k, depends); |
261 | | - }, |
262 | | - "Triu helper function.", py::arg("src"), py::arg("dst"), |
263 | | - py::arg("k") = 0, py::arg("sycl_queue"), |
264 | | - py::arg("depends") = py::list()); |
| 211 | + auto triu_fn = [](dpctl::tensor::usm_ndarray src, |
| 212 | + dpctl::tensor::usm_ndarray dst, py::ssize_t k, |
| 213 | + sycl::queue exec_q, |
| 214 | + const std::vector<sycl::event> depends) |
| 215 | + -> std::pair<sycl::event, sycl::event> { |
| 216 | + return usm_ndarray_triul(exec_q, src, dst, 'u', k, depends); |
| 217 | + }; |
| 218 | + m.def("_triu", triu_fn, "Triu helper function.", py::arg("src"), |
| 219 | + py::arg("dst"), py::arg("k") = 0, py::arg("sycl_queue"), |
| 220 | + py::arg("depends") = py::list()); |
265 | 221 | } |
0 commit comments