Skip to content

Commit 3799515

Browse files
author
Yoshiaki Bando
committed
Format xtensor_type_caster_base.hpp, pytensor.hpp, and pyarray.hpp
1 parent c67248e commit 3799515

File tree

3 files changed

+161
-114
lines changed

3 files changed

+161
-114
lines changed

include/xtensor-python/pyarray.hpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,25 @@ namespace pybind11
7373

7474
// Type caster for casting ndarray to xexpression<pyarray>
7575
template<typename T>
76-
struct type_caster<xt::xexpression<xt::pyarray<T>>> : pyobject_caster<xt::pyarray<T>>{
77-
public:
76+
struct type_caster<xt::xexpression<xt::pyarray<T>>> : pyobject_caster<xt::pyarray<T>>
77+
{
7878
using Type = xt::xexpression<xt::pyarray<T>>;
7979

80-
operator Type&() { return this->value; }
81-
operator const Type&() { return this->value; }
80+
operator Type&()
81+
{
82+
return this->value;
83+
}
84+
85+
operator const Type&()
86+
{
87+
return this->value;
88+
}
8289
};
8390

8491
// Type caster for casting xarray to ndarray
8592
template<class T>
86-
struct type_caster<xt::xarray<T>> : xtensor_type_caster_base<xt::xarray<T>> {
93+
struct type_caster<xt::xarray<T>> : xtensor_type_caster_base<xt::xarray<T>>
94+
{
8795
};
8896
}
8997
}

include/xtensor-python/pytensor.hpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,25 @@ namespace pybind11
7575

7676
// Type caster for casting ndarray to xexpression<pytensor>
7777
template<class T, std::size_t N>
78-
struct type_caster<xt::xexpression<xt::pytensor<T, N>>> : pyobject_caster<xt::pytensor<T, N>>{
79-
public:
78+
struct type_caster<xt::xexpression<xt::pytensor<T, N>>> : pyobject_caster<xt::pytensor<T, N>>
79+
{
8080
using Type = xt::xexpression<xt::pytensor<T, N>>;
8181

82-
operator Type&() { return this->value; }
83-
operator const Type&() { return this->value; }
82+
operator Type&()
83+
{
84+
return this->value;
85+
}
86+
87+
operator const Type&()
88+
{
89+
return this->value;
90+
}
8491
};
8592

8693
// Type caster for casting xt::xtensor to ndarray
8794
template<class T, std::size_t N>
88-
struct type_caster<xt::xtensor<T, N>> : xtensor_type_caster_base<xt::xtensor<T, N>> {
95+
struct type_caster<xt::xtensor<T, N>> : xtensor_type_caster_base<xt::xtensor<T, N>>
96+
{
8997
};
9098
}
9199
}
Lines changed: 135 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/*
22
xtensor-python/xtensor_type_caster.hpp: Transparent conversion for xtensor and xarray
33
4-
This code was inspired by the following code written by Wenzei Jakob
4+
This code is based on the following code written by Wenzei Jakob
55
66
pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices
77
@@ -19,110 +19,141 @@
1919
#include <pybind11/pybind11.h>
2020
#include <pybind11/numpy.h>
2121

22-
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
23-
NAMESPACE_BEGIN(detail)
24-
25-
26-
// Casts an xtensor (or xarray) type to numpy array. If given a base, the numpy array references the src data,
27-
// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
28-
template<typename Type>
29-
handle xtensor_array_cast(Type const &src, handle base = handle(), bool writeable = true) {
30-
std::vector<ssize_t> python_strides(src.strides().size());
31-
std::transform(src.strides().begin(), src.strides().end(), python_strides.data(),
32-
[](auto v) { return sizeof(typename Type::value_type) * v; });
33-
34-
array a(src.shape(), python_strides, src.begin(), base);
35-
36-
if (!writeable)
37-
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
38-
39-
return a.release();
40-
}
41-
42-
43-
// Takes an lvalue ref to some xtensor (or xarray) type and a (python) base object, creating a numpy array that
44-
// reference the xtensor object's data with `base` as the python-registered base class (if omitted,
45-
// the base will be set to None, and lifetime management is up to the caller). The numpy array is
46-
// non-writeable if the given type is const.
47-
template <typename Type, typename CType>
48-
handle xtensor_ref_array(CType &src, handle parent = none()) {
49-
return xtensor_array_cast<Type>(src, parent, !std::is_const<CType>::value);
50-
}
51-
52-
53-
// Takes a pointer to xtensor (or xarray), builds a capsule around it, then returns a numpy
54-
// array that references the encapsulated data with a python-side reference to the capsule to tie
55-
// its destruction to that of any dependent python objects. Const-ness is determined by whether or
56-
// not the CType of the pointer given is const.
57-
template <typename Type, typename CType>
58-
handle xtensor_encapsulate(CType *src) {
59-
capsule base(src, [](void *o) { delete static_cast<CType *>(o); });
60-
return xtensor_ref_array<Type>(*src, base);
61-
}
62-
63-
// Base class of type_caster for xtensor and xarray
64-
template<class Type>
65-
struct xtensor_type_caster_base {
66-
bool load(handle src, bool) {
67-
return false;
68-
}
69-
70-
private:
71-
// Cast implementation
72-
template <typename CType>
73-
static handle cast_impl(CType *src, return_value_policy policy, handle parent) {
74-
switch (policy) {
75-
case return_value_policy::take_ownership:
76-
case return_value_policy::automatic:
77-
return xtensor_encapsulate<Type>(src);
78-
case return_value_policy::move:
79-
return xtensor_encapsulate<Type>(new CType(std::move(*src)));
80-
case return_value_policy::copy:
81-
return xtensor_array_cast<Type>(*src);
82-
case return_value_policy::reference:
83-
case return_value_policy::automatic_reference:
84-
return xtensor_ref_array<Type>(*src);
85-
case return_value_policy::reference_internal:
86-
return xtensor_ref_array<Type>(*src, parent);
87-
default:
88-
throw cast_error("unhandled return_value_policy: should not happen!");
22+
namespace pybind11
23+
{
24+
namespace detail
25+
{
26+
// Casts an xtensor (or xarray) type to numpy array. If given a base, the numpy array references the src data,
27+
// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
28+
template<typename Type>
29+
handle xtensor_array_cast(Type const &src, handle base = handle(), bool writeable = true)
30+
{
31+
std::vector<ssize_t> python_strides(src.strides().size());
32+
std::transform(src.strides().begin(), src.strides().end(), python_strides.data(),
33+
[](auto v) { return sizeof(typename Type::value_type) * v; });
34+
35+
array a(src.shape(), python_strides, src.begin(), base);
36+
37+
if (!writeable)
38+
{
39+
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
40+
}
41+
42+
return a.release();
43+
}
44+
45+
// Takes an lvalue ref to some xtensor (or xarray) type and a (python) base object, creating a numpy array that
46+
// reference the xtensor object's data with `base` as the python-registered base class (if omitted,
47+
// the base will be set to None, and lifetime management is up to the caller). The numpy array is
48+
// non-writeable if the given type is const.
49+
template <typename Type, typename CType>
50+
handle xtensor_ref_array(CType &src, handle parent = none())
51+
{
52+
return xtensor_array_cast<Type>(src, parent, !std::is_const<CType>::value);
53+
}
54+
55+
// Takes a pointer to xtensor (or xarray), builds a capsule around it, then returns a numpy
56+
// array that references the encapsulated data with a python-side reference to the capsule to tie
57+
// its destruction to that of any dependent python objects. Const-ness is determined by whether or
58+
// not the CType of the pointer given is const.
59+
template <typename Type, typename CType>
60+
handle xtensor_encapsulate(CType *src)
61+
{
62+
capsule base(src, [](void *o) { delete static_cast<CType *>(o); });
63+
return xtensor_ref_array<Type>(*src, base);
64+
}
65+
66+
// Base class of type_caster for xtensor and xarray
67+
template<class Type>
68+
struct xtensor_type_caster_base
69+
{
70+
bool load(handle src, bool)
71+
{
72+
return false;
73+
}
74+
75+
private:
76+
77+
// Cast implementation
78+
template <typename CType>
79+
static handle cast_impl(CType *src, return_value_policy policy, handle parent)
80+
{
81+
switch (policy)
82+
{
83+
case return_value_policy::take_ownership:
84+
case return_value_policy::automatic:
85+
return xtensor_encapsulate<Type>(src);
86+
case return_value_policy::move:
87+
return xtensor_encapsulate<Type>(new CType(std::move(*src)));
88+
case return_value_policy::copy:
89+
return xtensor_array_cast<Type>(*src);
90+
case return_value_policy::reference:
91+
case return_value_policy::automatic_reference:
92+
return xtensor_ref_array<Type>(*src);
93+
case return_value_policy::reference_internal:
94+
return xtensor_ref_array<Type>(*src, parent);
95+
default:
96+
throw cast_error("unhandled return_value_policy: should not happen!");
97+
};
98+
}
99+
100+
public:
101+
102+
// Normal returned non-reference, non-const value:
103+
static handle cast(Type &&src, return_value_policy /* policy */, handle parent)
104+
{
105+
return cast_impl(&src, return_value_policy::move, parent);
106+
}
107+
108+
// If you return a non-reference const, we mark the numpy array readonly:
109+
static handle cast(const Type &&src, return_value_policy /* policy */, handle parent)
110+
{
111+
return cast_impl(&src, return_value_policy::move, parent);
112+
}
113+
114+
// lvalue reference return; default (automatic) becomes copy
115+
static handle cast(Type &src, return_value_policy policy, handle parent)
116+
{
117+
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
118+
{
119+
policy = return_value_policy::copy;
120+
}
121+
122+
return cast_impl(&src, policy, parent);
123+
}
124+
125+
// const lvalue reference return; default (automatic) becomes copy
126+
static handle cast(const Type &src, return_value_policy policy, handle parent)
127+
{
128+
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
129+
{
130+
policy = return_value_policy::copy;
131+
}
132+
133+
return cast(&src, policy, parent);
134+
}
135+
136+
// non-const pointer return
137+
static handle cast(Type *src, return_value_policy policy, handle parent)
138+
{
139+
return cast_impl(src, policy, parent);
140+
}
141+
142+
// const pointer return
143+
static handle cast(const Type *src, return_value_policy policy, handle parent)
144+
{
145+
return cast_impl(src, policy, parent);
146+
}
147+
148+
static PYBIND11_DESCR name()
149+
{
150+
return _("xt::xtensor");
151+
}
152+
153+
template <typename T>
154+
using cast_op_type = movable_cast_op_type<T>;
89155
};
90156
}
91-
public:
92-
// Normal returned non-reference, non-const value:
93-
static handle cast(Type &&src, return_value_policy /* policy */, handle parent) {
94-
return cast_impl(&src, return_value_policy::move, parent);
95-
}
96-
// If you return a non-reference const, we mark the numpy array readonly:
97-
static handle cast(const Type &&src, return_value_policy /* policy */, handle parent) {
98-
return cast_impl(&src, return_value_policy::move, parent);
99-
}
100-
// lvalue reference return; default (automatic) becomes copy
101-
static handle cast(Type &src, return_value_policy policy, handle parent) {
102-
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
103-
policy = return_value_policy::copy;
104-
return cast_impl(&src, policy, parent);
105-
}
106-
// const lvalue reference return; default (automatic) becomes copy
107-
static handle cast(const Type &src, return_value_policy policy, handle parent) {
108-
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
109-
policy = return_value_policy::copy;
110-
return cast(&src, policy, parent);
111-
}
112-
// non-const pointer return
113-
static handle cast(Type *src, return_value_policy policy, handle parent) {
114-
return cast_impl(src, policy, parent);
115-
}
116-
// const pointer return
117-
static handle cast(const Type *src, return_value_policy policy, handle parent) {
118-
return cast_impl(src, policy, parent);
119-
}
120-
121-
static PYBIND11_DESCR name() { return _("xt::xtensor"); }
122-
template <typename T> using cast_op_type = movable_cast_op_type<T>;
123-
};
124-
125-
NAMESPACE_END(detail)
126-
NAMESPACE_END(PYBIND11_NAMESPACE)
157+
}
127158

128159
#endif

0 commit comments

Comments
 (0)