1515
1616#include " pybind11/complex.h"
1717#include " pybind11/pybind11.h"
18+ #include " pybind11/numpy.h"
1819
1920#ifndef FORCE_IMPORT_ARRAY
2021#define NO_IMPORT_ARRAY
@@ -129,8 +130,11 @@ namespace xt
129130
130131 namespace detail
131132 {
133+ template <class T , class E = void >
134+ struct numpy_traits ;
135+
132136 template <class T >
133- struct numpy_traits
137+ struct numpy_traits <T, std:: enable_if_t <pybind11::detail::satisfies_any_of<T, std::is_arithmetic, xtl::is_complex>::value>>
134138 {
135139 private:
136140
@@ -184,6 +188,47 @@ namespace xt
184188 {
185189 return numpy_enum_adjuster<NPY_LONGLONG != NPY_INT64>::pyarray_type (obj);
186190 }
191+
192+ template <class T >
193+ void default_initialize_impl (T& storage, std::false_type)
194+ {
195+ }
196+
197+ template <class T >
198+ void default_initialize_impl (T& storage, std::true_type)
199+ {
200+ using value_type = typename T::value_type;
201+ storage[0 ] = value_type{};
202+ }
203+
204+ template <class T >
205+ void default_initialize (T& storage)
206+ {
207+ using value_type = typename T::value_type;
208+ default_initialize_impl (storage, std::is_copy_assignable<value_type>());
209+ }
210+
211+ template <class T >
212+ bool check_array_type (const pybind11::handle& src, std::true_type)
213+ {
214+ int type_num = xt::detail::numpy_traits<T>::type_num;
215+ return xt::detail::pyarray_type (reinterpret_cast <PyArrayObject*>(src.ptr ())) == type_num;
216+ }
217+
218+ template <class T >
219+ bool check_array_type (const pybind11::handle& src, std::false_type)
220+ {
221+ return PyArray_EquivTypes ((PyArray_Descr*) pybind11::detail::array_proxy (src.ptr ())->descr ,
222+ (PyArray_Descr*) pybind11::dtype::of<T>().ptr ());
223+ }
224+
225+ template <class T >
226+ bool check_array (const pybind11::handle& src)
227+ {
228+ using is_arithmetic_type = std::integral_constant<bool , bool (pybind11::detail::satisfies_any_of<T, std::is_arithmetic, xtl::is_complex>::value)>;
229+ return PyArray_Check (src.ptr ()) &&
230+ check_array_type<T>(src, is_arithmetic_type{});
231+ }
187232 }
188233
189234 /* *****************************
@@ -232,9 +277,9 @@ namespace xt
232277 template <class D >
233278 inline bool pycontainer<D>::check_(pybind11::handle h)
234279 {
235- int type_num = detail::numpy_traits <value_type>::type_num ;
280+ auto dtype = pybind11:: detail::npy_format_descriptor <value_type>::dtype () ;
236281 return PyArray_Check (h.ptr ()) &&
237- PyArray_EquivTypenums (PyArray_TYPE (reinterpret_cast <PyArrayObject*>(h.ptr ())), type_num );
282+ PyArray_EquivTypes_ (PyArray_TYPE (reinterpret_cast <PyArrayObject*>(h.ptr ())), dtype. ptr () );
238283 }
239284
240285 template <class D >
@@ -244,8 +289,9 @@ namespace xt
244289 {
245290 return nullptr ;
246291 }
247- int type_num = detail::numpy_traits<value_type>::type_num;
248- auto res = PyArray_FromAny (ptr, PyArray_DescrFromType (type_num), 0 , 0 ,
292+
293+ auto dtype = pybind11::detail::npy_format_descriptor<value_type>::dtype ();
294+ auto res = PyArray_FromAny (ptr, (PyArray_Descr *) dtype.release ().ptr (), 0 , 0 ,
249295 NPY_ARRAY_ENSUREARRAY | NPY_ARRAY_FORCECAST, nullptr );
250296 return res;
251297 }
0 commit comments