11/*
22 xtensor-python/xtensor_type_caster.hpp: Transparent conversion for xtensor and xarray
33
4- This code was inspired by the following code written by Wenzei Jakob
4+ This code is based on the following code written by Wenzei Jakob
55
66 pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices
77
1919#include < pybind11/pybind11.h>
2020#include < pybind11/numpy.h>
2121
22- NAMESPACE_BEGIN (PYBIND11_NAMESPACE)
23- NAMESPACE_BEGIN(detail)
24-
25-
26- // Casts an xtensor (or xarray) type to numpy array. If given a base, the numpy array references the src data,
27- // otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
28- template<typename Type>
29- handle xtensor_array_cast(Type const &src, handle base = handle(), bool writeable = true) {
30- std::vector<ssize_t > python_strides (src.strides ().size ());
31- std::transform (src.strides ().begin (), src.strides ().end (), python_strides.data (),
32- [](auto v) { return sizeof (typename Type::value_type) * v; });
33-
34- array a (src.shape (), python_strides, src.begin (), base);
35-
36- if (!writeable)
37- array_proxy (a.ptr ())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
38-
39- return a.release ();
40- }
41-
42-
43- // Takes an lvalue ref to some xtensor (or xarray) type and a (python) base object, creating a numpy array that
44- // reference the xtensor object's data with `base` as the python-registered base class (if omitted,
45- // the base will be set to None, and lifetime management is up to the caller). The numpy array is
46- // non-writeable if the given type is const.
47- template <typename Type, typename CType>
48- handle xtensor_ref_array (CType &src, handle parent = none()) {
49- return xtensor_array_cast<Type>(src, parent, !std::is_const<CType>::value);
50- }
51-
52-
53- // Takes a pointer to xtensor (or xarray), builds a capsule around it, then returns a numpy
54- // array that references the encapsulated data with a python-side reference to the capsule to tie
55- // its destruction to that of any dependent python objects. Const-ness is determined by whether or
56- // not the CType of the pointer given is const.
57- template <typename Type, typename CType>
58- handle xtensor_encapsulate (CType *src) {
59- capsule base (src, [](void *o) { delete static_cast <CType *>(o); });
60- return xtensor_ref_array<Type>(*src, base);
61- }
62-
63- // Base class of type_caster for xtensor and xarray
64- template <class Type >
65- struct xtensor_type_caster_base {
66- bool load (handle src, bool ) {
67- return false ;
68- }
69-
70- private:
71- // Cast implementation
72- template <typename CType>
73- static handle cast_impl (CType *src, return_value_policy policy, handle parent) {
74- switch (policy) {
75- case return_value_policy::take_ownership:
76- case return_value_policy::automatic:
77- return xtensor_encapsulate<Type>(src);
78- case return_value_policy::move:
79- return xtensor_encapsulate<Type>(new CType (std::move (*src)));
80- case return_value_policy::copy:
81- return xtensor_array_cast<Type>(*src);
82- case return_value_policy::reference:
83- case return_value_policy::automatic_reference:
84- return xtensor_ref_array<Type>(*src);
85- case return_value_policy::reference_internal:
86- return xtensor_ref_array<Type>(*src, parent);
87- default :
88- throw cast_error (" unhandled return_value_policy: should not happen!" );
22+ namespace pybind11
23+ {
24+ namespace detail
25+ {
26+ // Casts an xtensor (or xarray) type to numpy array. If given a base, the numpy array references the src data,
27+ // otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
28+ template <typename Type>
29+ handle xtensor_array_cast (Type const &src, handle base = handle(), bool writeable = true)
30+ {
31+ std::vector<ssize_t > python_strides (src.strides ().size ());
32+ std::transform (src.strides ().begin (), src.strides ().end (), python_strides.data (),
33+ [](auto v) { return sizeof (typename Type::value_type) * v; });
34+
35+ array a (src.shape (), python_strides, src.begin (), base);
36+
37+ if (!writeable)
38+ {
39+ array_proxy (a.ptr ())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
40+ }
41+
42+ return a.release ();
43+ }
44+
45+ // Takes an lvalue ref to some xtensor (or xarray) type and a (python) base object, creating a numpy array that
46+ // reference the xtensor object's data with `base` as the python-registered base class (if omitted,
47+ // the base will be set to None, and lifetime management is up to the caller). The numpy array is
48+ // non-writeable if the given type is const.
49+ template <typename Type, typename CType>
50+ handle xtensor_ref_array (CType &src, handle parent = none())
51+ {
52+ return xtensor_array_cast<Type>(src, parent, !std::is_const<CType>::value);
53+ }
54+
55+ // Takes a pointer to xtensor (or xarray), builds a capsule around it, then returns a numpy
56+ // array that references the encapsulated data with a python-side reference to the capsule to tie
57+ // its destruction to that of any dependent python objects. Const-ness is determined by whether or
58+ // not the CType of the pointer given is const.
59+ template <typename Type, typename CType>
60+ handle xtensor_encapsulate (CType *src)
61+ {
62+ capsule base (src, [](void *o) { delete static_cast <CType *>(o); });
63+ return xtensor_ref_array<Type>(*src, base);
64+ }
65+
66+ // Base class of type_caster for xtensor and xarray
67+ template <class Type >
68+ struct xtensor_type_caster_base
69+ {
70+ bool load (handle src, bool )
71+ {
72+ return false ;
73+ }
74+
75+ private:
76+
77+ // Cast implementation
78+ template <typename CType>
79+ static handle cast_impl (CType *src, return_value_policy policy, handle parent)
80+ {
81+ switch (policy)
82+ {
83+ case return_value_policy::take_ownership:
84+ case return_value_policy::automatic:
85+ return xtensor_encapsulate<Type>(src);
86+ case return_value_policy::move:
87+ return xtensor_encapsulate<Type>(new CType (std::move (*src)));
88+ case return_value_policy::copy:
89+ return xtensor_array_cast<Type>(*src);
90+ case return_value_policy::reference:
91+ case return_value_policy::automatic_reference:
92+ return xtensor_ref_array<Type>(*src);
93+ case return_value_policy::reference_internal:
94+ return xtensor_ref_array<Type>(*src, parent);
95+ default :
96+ throw cast_error (" unhandled return_value_policy: should not happen!" );
97+ };
98+ }
99+
100+ public:
101+
102+ // Normal returned non-reference, non-const value:
103+ static handle cast (Type &&src, return_value_policy /* policy */ , handle parent)
104+ {
105+ return cast_impl (&src, return_value_policy::move, parent);
106+ }
107+
108+ // If you return a non-reference const, we mark the numpy array readonly:
109+ static handle cast (const Type &&src, return_value_policy /* policy */ , handle parent)
110+ {
111+ return cast_impl (&src, return_value_policy::move, parent);
112+ }
113+
114+ // lvalue reference return; default (automatic) becomes copy
115+ static handle cast (Type &src, return_value_policy policy, handle parent)
116+ {
117+ if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
118+ {
119+ policy = return_value_policy::copy;
120+ }
121+
122+ return cast_impl (&src, policy, parent);
123+ }
124+
125+ // const lvalue reference return; default (automatic) becomes copy
126+ static handle cast (const Type &src, return_value_policy policy, handle parent)
127+ {
128+ if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
129+ {
130+ policy = return_value_policy::copy;
131+ }
132+
133+ return cast (&src, policy, parent);
134+ }
135+
136+ // non-const pointer return
137+ static handle cast (Type *src, return_value_policy policy, handle parent)
138+ {
139+ return cast_impl (src, policy, parent);
140+ }
141+
142+ // const pointer return
143+ static handle cast (const Type *src, return_value_policy policy, handle parent)
144+ {
145+ return cast_impl (src, policy, parent);
146+ }
147+
148+ static PYBIND11_DESCR name ()
149+ {
150+ return _ (" xt::xtensor" );
151+ }
152+
153+ template <typename T>
154+ using cast_op_type = movable_cast_op_type<T>;
89155 };
90156 }
91- public:
92- // Normal returned non-reference, non-const value:
93- static handle cast (Type &&src, return_value_policy /* policy */ , handle parent) {
94- return cast_impl (&src, return_value_policy::move, parent);
95- }
96- // If you return a non-reference const, we mark the numpy array readonly:
97- static handle cast (const Type &&src, return_value_policy /* policy */ , handle parent) {
98- return cast_impl (&src, return_value_policy::move, parent);
99- }
100- // lvalue reference return; default (automatic) becomes copy
101- static handle cast (Type &src, return_value_policy policy, handle parent) {
102- if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
103- policy = return_value_policy::copy;
104- return cast_impl (&src, policy, parent);
105- }
106- // const lvalue reference return; default (automatic) becomes copy
107- static handle cast (const Type &src, return_value_policy policy, handle parent) {
108- if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
109- policy = return_value_policy::copy;
110- return cast (&src, policy, parent);
111- }
112- // non-const pointer return
113- static handle cast (Type *src, return_value_policy policy, handle parent) {
114- return cast_impl (src, policy, parent);
115- }
116- // const pointer return
117- static handle cast (const Type *src, return_value_policy policy, handle parent) {
118- return cast_impl (src, policy, parent);
119- }
120-
121- static PYBIND11_DESCR name () { return _ (" xt::xtensor" ); }
122- template <typename T> using cast_op_type = movable_cast_op_type<T>;
123- };
124-
125- NAMESPACE_END (detail)
126- NAMESPACE_END(PYBIND11_NAMESPACE)
157+ }
127158
128159#endif
0 commit comments