1+ /*
2+ xtensor-python/xtensor_type_caster.hpp: Transparent conversion for xtensor and xarray
3+
4+ This code was inspired by the following code written by Wenzei Jakob
5+
6+ pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices
7+
8+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
9+
10+ All rights reserved. Use of this source code is governed by a
11+ BSD-style license that can be found in the LICENSE file.
12+ */
13+
14+
15+ #ifndef XTENSOR_TYPE_CASTER_HPP
16+ #define XTENSOR_TYPE_CASTER_HPP
17+
18+ #include " xtensor/xtensor.hpp"
19+ #include < pybind11/pybind11.h>
20+ #include < pybind11/numpy.h>
21+
22+ NAMESPACE_BEGIN (PYBIND11_NAMESPACE)
23+ NAMESPACE_BEGIN(detail)
24+
25+
26+ // Casts an xtensor (or xarray) type to numpy array. If given a base, the numpy array references the src data,
27+ // otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
28+ template<typename Type>
29+ handle xtensor_array_cast(Type const &src, handle base = handle(), bool writeable = true) {
30+ std::vector<ssize_t > python_strides (src.strides ().size ());
31+ std::transform (src.strides ().begin (), src.strides ().end (), python_strides.data (),
32+ [](auto v) { return sizeof (typename Type::value_type) * v; });
33+
34+ array a (src.shape (), python_strides, src.begin (), base);
35+
36+ if (!writeable)
37+ array_proxy (a.ptr ())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
38+
39+ return a.release ();
40+ }
41+
42+
43+ // Takes an lvalue ref to some xtensor (or xarray) type and a (python) base object, creating a numpy array that
44+ // reference the xtensor object's data with `base` as the python-registered base class (if omitted,
45+ // the base will be set to None, and lifetime management is up to the caller). The numpy array is
46+ // non-writeable if the given type is const.
47+ template <typename Type, typename CType>
48+ handle xtensor_ref_array (CType &src, handle parent = none()) {
49+ return xtensor_array_cast<Type>(src, parent, !std::is_const<CType>::value);
50+ }
51+
52+
53+ // Takes a pointer to xtensor (or xarray), builds a capsule around it, then returns a numpy
54+ // array that references the encapsulated data with a python-side reference to the capsule to tie
55+ // its destruction to that of any dependent python objects. Const-ness is determined by whether or
56+ // not the CType of the pointer given is const.
57+ template <typename Type, typename CType>
58+ handle xtensor_encapsulate (CType *src) {
59+ capsule base (src, [](void *o) { delete static_cast <CType *>(o); });
60+ return xtensor_ref_array<Type>(*src, base);
61+ }
62+
63+ // Base class of type_caster for xtensor and xarray
64+ template <class Type >
65+ struct xtensor_type_caster_base {
66+ bool load (handle src, bool ) {
67+ return false ;
68+ }
69+
70+ private:
71+ // Cast implementation
72+ template <typename CType>
73+ static handle cast_impl (CType *src, return_value_policy policy, handle parent) {
74+ switch (policy) {
75+ case return_value_policy::take_ownership:
76+ case return_value_policy::automatic:
77+ return xtensor_encapsulate<Type>(src);
78+ case return_value_policy::move:
79+ return xtensor_encapsulate<Type>(new CType (std::move (*src)));
80+ case return_value_policy::copy:
81+ return xtensor_array_cast<Type>(*src);
82+ case return_value_policy::reference:
83+ case return_value_policy::automatic_reference:
84+ return xtensor_ref_array<Type>(*src);
85+ case return_value_policy::reference_internal:
86+ return xtensor_ref_array<Type>(*src, parent);
87+ default :
88+ throw cast_error (" unhandled return_value_policy: should not happen!" );
89+ };
90+ }
91+ public:
92+ // Normal returned non-reference, non-const value:
93+ static handle cast (Type &&src, return_value_policy /* policy */ , handle parent) {
94+ return cast_impl (&src, return_value_policy::move, parent);
95+ }
96+ // If you return a non-reference const, we mark the numpy array readonly:
97+ static handle cast (const Type &&src, return_value_policy /* policy */ , handle parent) {
98+ return cast_impl (&src, return_value_policy::move, parent);
99+ }
100+ // lvalue reference return; default (automatic) becomes copy
101+ static handle cast (Type &src, return_value_policy policy, handle parent) {
102+ if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
103+ policy = return_value_policy::copy;
104+ return cast_impl (&src, policy, parent);
105+ }
106+ // const lvalue reference return; default (automatic) becomes copy
107+ static handle cast (const Type &src, return_value_policy policy, handle parent) {
108+ if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
109+ policy = return_value_policy::copy;
110+ return cast (&src, policy, parent);
111+ }
112+ // non-const pointer return
113+ static handle cast (Type *src, return_value_policy policy, handle parent) {
114+ return cast_impl (src, policy, parent);
115+ }
116+ // const pointer return
117+ static handle cast (const Type *src, return_value_policy policy, handle parent) {
118+ return cast_impl (src, policy, parent);
119+ }
120+
121+ static PYBIND11_DESCR name () { return _ (" xt::xtensor" ); }
122+ template <typename T> using cast_op_type = movable_cast_op_type<T>;
123+ };
124+
125+ NAMESPACE_END (detail)
126+ NAMESPACE_END(PYBIND11_NAMESPACE)
127+
128+ #endif
0 commit comments