@@ -39,23 +39,60 @@ namespace pybind11
3939namespace detail
4040{
4141
42+ #define DPCTL_TYPE_CASTER (type , py_name ) \
43+ protected: \
44+ std::unique_ptr<type> value; \
45+ \
46+ public: \
47+ static constexpr auto name = py_name; \
48+ template < \
49+ typename T_, \
50+ ::pybind11::detail::enable_if_t< \
51+ std::is_same<type, ::pybind11::detail::remove_cv_t<T_>>::value, \
52+ int> = 0> \
53+ static ::pybind11::handle cast(T_ *src, \
54+ ::pybind11::return_value_policy policy, \
55+ ::pybind11::handle parent) \
56+ { \
57+ if (!src) \
58+ return ::pybind11::none().release(); \
59+ if (policy == ::pybind11::return_value_policy::take_ownership) { \
60+ auto h = cast(std::move(*src), policy, parent); \
61+ delete src; \
62+ return h; \
63+ } \
64+ return cast(*src, policy, parent); \
65+ } \
66+ operator type *() \
67+ { \
68+ return value.get(); \
69+ } /* NOLINT(bugprone-macro-parentheses) */ \
70+ operator type & () \
71+ { \
72+ return * value ; \
73+ } /* NOLINT(bugprone-macro-parentheses) */ \
74+ operator type && ( ) && \
75+ { \
76+ return std ::move (* value ); \
77+ } /* NOLINT(bugprone-macro-parentheses) */ \
78+ template < typename T_ > \
79+ using cast_op_type = ::pybind11 ::detail ::movable_cast_op_type < T_ >
80+
4281/* This type caster associates ``sycl::queue`` C++ class with
4382 * :class:`dpctl.SyclQueue` for the purposes of generation of
4483 * Python bindings by pybind11.
4584 */
4685template < > struct type_caster < sycl ::queue >
4786{
4887public :
49- PYBIND11_TYPE_CASTER (sycl ::queue , _ ("dpctl.SyclQueue" ));
50-
5188 bool load (handle src , bool )
5289 {
5390 PyObject * source = src .ptr ();
5491 if (PyObject_TypeCheck (source , & PySyclQueueType )) {
5592 DPCTLSyclQueueRef QRef = SyclQueue_GetQueueRef (
5693 reinterpret_cast < PySyclQueueObject * > (source ));
57- sycl :: queue * q = reinterpret_cast < sycl ::queue * > ( QRef );
58- value = * q ;
94+ value = std :: make_unique < sycl ::queue > (
95+ * ( reinterpret_cast < sycl :: queue * > ( QRef ))) ;
5996 return true;
6097 }
6198 else {
@@ -69,6 +106,8 @@ template <> struct type_caster<sycl::queue>
69106 auto tmp = SyclQueue_Make (reinterpret_cast < DPCTLSyclQueueRef > (& src ));
70107 return handle (reinterpret_cast < PyObject * > (tmp ));
71108 }
109+
110+ DPCTL_TYPE_CASTER (sycl ::queue , _ ("dpctl.SyclQueue" ));
72111};
73112
74113/* This type caster associates ``sycl::device`` C++ class with
@@ -78,20 +117,14 @@ template <> struct type_caster<sycl::queue>
78117template < > struct type_caster < sycl ::device >
79118{
80119public :
81- PYBIND11_TYPE_CASTER (sycl ::device , _ ("dpctl.SyclDevice" ));
82-
83120 bool load (handle src , bool )
84121 {
85122 PyObject * source = src .ptr ();
86123 if (PyObject_TypeCheck (source , & PySyclDeviceType )) {
87124 DPCTLSyclDeviceRef DRef = SyclDevice_GetDeviceRef (
88125 reinterpret_cast < PySyclDeviceObject * > (source ));
89- sycl ::device * d = reinterpret_cast < sycl ::device * > (DRef );
90- value = * d ;
91- return true;
92- }
93- else if (source == Py_None ) {
94- value = sycl ::device {};
126+ value = std ::make_unique < sycl ::device > (
127+ * (reinterpret_cast < sycl ::device * > (DRef )));
95128 return true;
96129 }
97130 else {
@@ -105,6 +138,8 @@ template <> struct type_caster<sycl::device>
105138 auto tmp = SyclDevice_Make (reinterpret_cast < DPCTLSyclDeviceRef > (& src ));
106139 return handle (reinterpret_cast < PyObject * > (tmp ));
107140 }
141+
142+ DPCTL_TYPE_CASTER (sycl ::device , _ ("dpctl.SyclDevice" ));
108143};
109144
110145/* This type caster associates ``sycl::context`` C++ class with
@@ -114,16 +149,14 @@ template <> struct type_caster<sycl::device>
114149template < > struct type_caster < sycl ::context >
115150{
116151public :
117- PYBIND11_TYPE_CASTER (sycl ::context , _ ("dpctl.SyclContext" ));
118-
119152 bool load (handle src , bool )
120153 {
121154 PyObject * source = src .ptr ();
122155 if (PyObject_TypeCheck (source , & PySyclContextType )) {
123156 DPCTLSyclContextRef CRef = SyclContext_GetContextRef (
124157 reinterpret_cast < PySyclContextObject * > (source ));
125- sycl :: context * ctx = reinterpret_cast < sycl ::context * > ( CRef );
126- value = * ctx ;
158+ value = std :: make_unique < sycl ::context > (
159+ * ( reinterpret_cast < sycl :: context * > ( CRef ))) ;
127160 return true;
128161 }
129162 else {
@@ -138,6 +171,8 @@ template <> struct type_caster<sycl::context>
138171 SyclContext_Make (reinterpret_cast < DPCTLSyclContextRef > (& src ));
139172 return handle (reinterpret_cast < PyObject * > (tmp ));
140173 }
174+
175+ DPCTL_TYPE_CASTER (sycl ::context , _ ("dpctl.SyclContext" ));
141176};
142177
143178/* This type caster associates ``sycl::event`` C++ class with
@@ -147,16 +182,14 @@ template <> struct type_caster<sycl::context>
147182template < > struct type_caster < sycl ::event >
148183{
149184public :
150- PYBIND11_TYPE_CASTER (sycl ::event , _ ("dpctl.SyclEvent" ));
151-
152185 bool load (handle src , bool )
153186 {
154187 PyObject * source = src .ptr ();
155188 if (PyObject_TypeCheck (source , & PySyclEventType )) {
156189 DPCTLSyclEventRef ERef = SyclEvent_GetEventRef (
157190 reinterpret_cast < PySyclEventObject * > (source ));
158- sycl :: event * ev = reinterpret_cast < sycl ::event * > ( ERef );
159- value = * ev ;
191+ value = std :: make_unique < sycl ::event > (
192+ * ( reinterpret_cast < sycl :: event * > ( ERef ))) ;
160193 return true;
161194 }
162195 else {
@@ -170,12 +203,102 @@ template <> struct type_caster<sycl::event>
170203 auto tmp = SyclEvent_Make (reinterpret_cast < DPCTLSyclEventRef > (& src ));
171204 return handle (reinterpret_cast < PyObject * > (tmp ));
172205 }
206+
207+ DPCTL_TYPE_CASTER (sycl ::event , _ ("dpctl.SyclEvent" ));
173208};
174209} // namespace detail
175210} // namespace pybind11
176211
177212namespace dpctl
178213{
214+
215+ namespace detail
216+ {
217+
218+ struct dpctl_api
219+ {
220+ public :
221+ static dpctl_api & get ()
222+ {
223+ static dpctl_api api ;
224+ return api ;
225+ }
226+
227+ py ::object sycl_queue_ ()
228+ {
229+ return * sycl_queue ;
230+ }
231+ py ::object default_usm_memory_ ()
232+ {
233+ return * default_usm_memory ;
234+ }
235+ py ::object default_usm_ndarray_ ()
236+ {
237+ return * default_usm_ndarray ;
238+ }
239+ py ::object as_usm_memory_ ()
240+ {
241+ return * as_usm_memory ;
242+ }
243+
244+ private :
245+ struct Deleter
246+ {
247+ void operator ()(py ::object * p ) const
248+ {
249+ bool guard = (Py_IsInitialized () && !_Py_IsFinalizing ());
250+
251+ if (guard ) {
252+ delete p ;
253+ }
254+ }
255+ };
256+
257+ std ::shared_ptr < py ::object > sycl_queue ;
258+ std ::shared_ptr < py ::object > default_usm_memory ;
259+ std ::shared_ptr < py ::object > default_usm_ndarray ;
260+ std ::shared_ptr < py ::object > as_usm_memory ;
261+
262+ dpctl_api () : sycl_queue {}, default_usm_memory {}, default_usm_ndarray {}
263+ {
264+ import_dpctl ();
265+
266+ sycl ::queue q_ ;
267+ py ::object py_sycl_queue = py ::cast (q_ );
268+ sycl_queue = std ::shared_ptr < py ::object > (new py ::object {py_sycl_queue },
269+ Deleter {});
270+
271+ py ::module_ mod_memory = py ::module_ ::import ("dpctl.memory" );
272+ py ::object py_as_usm_memory = mod_memory .attr ("as_usm_memory" );
273+ as_usm_memory = std ::shared_ptr < py ::object > (
274+ new py ::object {py_as_usm_memory }, Deleter {});
275+
276+ auto mem_kl = mod_memory .attr ("MemoryUSMHost" );
277+ py ::object py_default_usm_memory =
278+ mem_kl (1 , py ::arg ("queue" ) = py_sycl_queue );
279+ default_usm_memory = std ::shared_ptr < py ::object > (
280+ new py ::object {py_default_usm_memory }, Deleter {});
281+
282+ py ::module_ mod_usmarray =
283+ py ::module_ ::import ("dpctl.tensor._usmarray" );
284+ auto tensor_kl = mod_usmarray .attr ("usm_ndarray" );
285+
286+ py ::object py_default_usm_ndarray =
287+ tensor_kl (py ::tuple (), py ::arg ("dtype" ) = py ::str ("u1" ),
288+ py ::arg ("buffer" ) = py_default_usm_memory );
289+
290+ default_usm_ndarray = std ::shared_ptr < py ::object > (
291+ new py ::object {py_default_usm_ndarray }, Deleter {});
292+ }
293+
294+ public :
295+ dpctl_api (dpctl_api const & ) = delete ;
296+ void operator = (dpctl_api const & ) = delete ;
297+ ~dpctl_api (){};
298+ };
299+
300+ } // namespace detail
301+
179302namespace memory
180303{
181304
@@ -232,7 +355,9 @@ class usm_memory : public py::object
232355 }
233356 // END_TOKEN
234357
235- usm_memory () : py ::object (default_constructed (), stolen_t {})
358+ usm_memory ()
359+ : py ::object (::dpctl ::detail ::dpctl_api ::get ().default_usm_memory_ (),
360+ borrowed_t {})
236361 {
237362 if (!m_ptr )
238363 throw py ::error_already_set ();
@@ -267,26 +392,12 @@ class usm_memory : public py::object
267392 "cannot create a usm_memory from a nullptr" );
268393 return nullptr ;
269394 }
270- py ::module_ m = py ::module_ ::import ("dpctl.memory" );
271- auto convertor = m .attr ("as_usm_memory" );
272395
273- py ::object res ;
274- try {
275- res = convertor (py ::handle (o ));
276- } catch (const py ::error_already_set & e ) {
277- return nullptr ;
278- }
279- return res .ptr ();
280- }
396+ auto convertor = ::dpctl ::detail ::dpctl_api ::get ().as_usm_memory_ ();
281397
282- static PyObject * default_constructed ()
283- {
284- py ::module_ m = py ::module_ ::import ("dpctl.memory" );
285- auto kl = m .attr ("MemoryUSMDevice" );
286398 py ::object res ;
287399 try {
288- // allocate 1 byte
289- res = kl (1 );
400+ res = convertor (py ::handle (o ));
290401 } catch (const py ::error_already_set & e ) {
291402 return nullptr ;
292403 }
@@ -295,10 +406,7 @@ class usm_memory : public py::object
295406};
296407
297408} // end namespace memory
298- } // end namespace dpctl
299409
300- namespace dpctl
301- {
302410namespace tensor
303411{
304412class usm_ndarray : public py ::object
@@ -349,7 +457,9 @@ class usm_ndarray : public py::object
349457 }
350458 // END_TOKEN
351459
352- usm_ndarray () : py ::object (default_constructed (), stolen_t {})
460+ usm_ndarray ()
461+ : py ::object (::dpctl ::detail ::dpctl_api ::get ().default_usm_ndarray_ (),
462+ borrowed_t {})
353463 {
354464 if (!m_ptr )
355465 throw py ::error_already_set ();
@@ -481,21 +591,6 @@ class usm_ndarray : public py::object
481591
482592 return UsmNDArray_GetElementSize (raw_ar );
483593 }
484-
485- private :
486- static PyObject * default_constructed ()
487- {
488- py ::module_ m = py ::module_ ::import ("dpctl.tensor" );
489- auto kl = m .attr ("usm_ndarray" );
490- py ::object res ;
491- try {
492- // allocate 1 byte
493- res = kl (py ::make_tuple (), py ::arg ("dtype" ) = "u1" );
494- } catch (const py ::error_already_set & e ) {
495- return nullptr ;
496- }
497- return res .ptr ();
498- }
499594};
500595
501596} // end namespace tensor
0 commit comments