Skip to content

Commit 2562449

Browse files
Merge pull request #113 from yoshipon/xtensor_type_caster
Add type_caster classes for xt::xtensor, xt::xarray, and xt::xexpression.
2 parents ca42d8b + d4f6d7c commit 2562449

File tree

5 files changed

+212
-1
lines changed

5 files changed

+212
-1
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ matrix:
5959
env: PY=3
6060
env:
6161
global:
62-
- MINCONDA_VERSION="latest"
62+
- MINCONDA_VERSION="4.3.21"
6363
- MINCONDA_LINUX="Linux-x86_64"
6464
- MINCONDA_OSX="MacOSX-x86_64"
6565
before_install:

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ set(XTENSOR_PYTHON_HEADERS
3838
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pytensor.hpp
3939
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyvectorize.hpp
4040
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_python_config.hpp
41+
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_type_caster_base.hpp
4142
)
4243

4344
OPTION(BUILD_TESTS "xtensor test suite" OFF)

include/xtensor-python/pyarray.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "pycontainer.hpp"
2121
#include "pystrides_adaptor.hpp"
22+
#include "xtensor_type_caster_base.hpp"
2223

2324
namespace xt
2425
{
@@ -69,6 +70,29 @@ namespace pybind11
6970

7071
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
7172
};
73+
74+
// Type caster for casting ndarray to xexpression<pyarray>
75+
template<typename T>
76+
struct type_caster<xt::xexpression<xt::pyarray<T>>> : pyobject_caster<xt::pyarray<T>>
77+
{
78+
using Type = xt::xexpression<xt::pyarray<T>>;
79+
80+
operator Type&()
81+
{
82+
return this->value;
83+
}
84+
85+
operator const Type&()
86+
{
87+
return this->value;
88+
}
89+
};
90+
91+
// Type caster for casting xarray to ndarray
92+
template<class T>
93+
struct type_caster<xt::xarray<T>> : xtensor_type_caster_base<xt::xarray<T>>
94+
{
95+
};
7296
}
7397
}
7498

include/xtensor-python/pytensor.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "pycontainer.hpp"
2222
#include "pystrides_adaptor.hpp"
23+
#include "xtensor_type_caster_base.hpp"
2324

2425
namespace xt
2526
{
@@ -71,6 +72,29 @@ namespace pybind11
7172

7273
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
7374
};
75+
76+
// Type caster for casting ndarray to xexpression<pytensor>
77+
template<class T, std::size_t N>
78+
struct type_caster<xt::xexpression<xt::pytensor<T, N>>> : pyobject_caster<xt::pytensor<T, N>>
79+
{
80+
using Type = xt::xexpression<xt::pytensor<T, N>>;
81+
82+
operator Type&()
83+
{
84+
return this->value;
85+
}
86+
87+
operator const Type&()
88+
{
89+
return this->value;
90+
}
91+
};
92+
93+
// Type caster for casting xt::xtensor to ndarray
94+
template<class T, std::size_t N>
95+
struct type_caster<xt::xtensor<T, N>> : xtensor_type_caster_base<xt::xtensor<T, N>>
96+
{
97+
};
7498
}
7599
}
76100

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
/*
2+
xtensor-python/xtensor_type_caster.hpp: Transparent conversion for xtensor and xarray
3+
4+
This code is based on the following code written by Wenzei Jakob
5+
6+
pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices
7+
8+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
9+
10+
All rights reserved. Use of this source code is governed by a
11+
BSD-style license that can be found in the LICENSE file.
12+
*/
13+
14+
15+
#ifndef XTENSOR_TYPE_CASTER_HPP
16+
#define XTENSOR_TYPE_CASTER_HPP
17+
18+
#include "xtensor/xtensor.hpp"
19+
#include <pybind11/pybind11.h>
20+
#include <pybind11/numpy.h>
21+
22+
namespace pybind11
23+
{
24+
namespace detail
25+
{
26+
// Casts an xtensor (or xarray) type to numpy array. If given a base, the numpy array references the src data,
27+
// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
28+
template<typename Type>
29+
handle xtensor_array_cast(Type const &src, handle base = handle(), bool writeable = true)
30+
{
31+
std::vector<size_t> python_strides(src.strides().size());
32+
std::transform(src.strides().begin(), src.strides().end(), python_strides.begin(),
33+
[](auto v) { return sizeof(typename Type::value_type) * v; });
34+
35+
std::vector<size_t> python_shape(src.shape().size());
36+
std::copy(src.shape().begin(), src.shape().end(), python_shape.begin());
37+
38+
array a(python_shape, python_strides, src.begin(), base);
39+
40+
if (!writeable)
41+
{
42+
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
43+
}
44+
45+
return a.release();
46+
}
47+
48+
// Takes an lvalue ref to some xtensor (or xarray) type and a (python) base object, creating a numpy array that
49+
// reference the xtensor object's data with `base` as the python-registered base class (if omitted,
50+
// the base will be set to None, and lifetime management is up to the caller). The numpy array is
51+
// non-writeable if the given type is const.
52+
template <typename Type, typename CType>
53+
handle xtensor_ref_array(CType &src, handle parent = none())
54+
{
55+
return xtensor_array_cast<Type>(src, parent, !std::is_const<CType>::value);
56+
}
57+
58+
// Takes a pointer to xtensor (or xarray), builds a capsule around it, then returns a numpy
59+
// array that references the encapsulated data with a python-side reference to the capsule to tie
60+
// its destruction to that of any dependent python objects. Const-ness is determined by whether or
61+
// not the CType of the pointer given is const.
62+
template <typename Type, typename CType>
63+
handle xtensor_encapsulate(CType *src)
64+
{
65+
capsule base(src, [](void *o) { delete static_cast<CType *>(o); });
66+
return xtensor_ref_array<Type>(*src, base);
67+
}
68+
69+
// Base class of type_caster for xtensor and xarray
70+
template<class Type>
71+
struct xtensor_type_caster_base
72+
{
73+
bool load(handle src, bool)
74+
{
75+
return false;
76+
}
77+
78+
private:
79+
80+
// Cast implementation
81+
template <typename CType>
82+
static handle cast_impl(CType *src, return_value_policy policy, handle parent)
83+
{
84+
switch (policy)
85+
{
86+
case return_value_policy::take_ownership:
87+
case return_value_policy::automatic:
88+
return xtensor_encapsulate<Type>(src);
89+
case return_value_policy::move:
90+
return xtensor_encapsulate<Type>(new CType(std::move(*src)));
91+
case return_value_policy::copy:
92+
return xtensor_array_cast<Type>(*src);
93+
case return_value_policy::reference:
94+
case return_value_policy::automatic_reference:
95+
return xtensor_ref_array<Type>(*src);
96+
case return_value_policy::reference_internal:
97+
return xtensor_ref_array<Type>(*src, parent);
98+
default:
99+
throw cast_error("unhandled return_value_policy: should not happen!");
100+
};
101+
}
102+
103+
public:
104+
105+
// Normal returned non-reference, non-const value:
106+
static handle cast(Type &&src, return_value_policy /* policy */, handle parent)
107+
{
108+
return cast_impl(&src, return_value_policy::move, parent);
109+
}
110+
111+
// If you return a non-reference const, we mark the numpy array readonly:
112+
static handle cast(const Type &&src, return_value_policy /* policy */, handle parent)
113+
{
114+
return cast_impl(&src, return_value_policy::move, parent);
115+
}
116+
117+
// lvalue reference return; default (automatic) becomes copy
118+
static handle cast(Type &src, return_value_policy policy, handle parent)
119+
{
120+
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
121+
{
122+
policy = return_value_policy::copy;
123+
}
124+
125+
return cast_impl(&src, policy, parent);
126+
}
127+
128+
// const lvalue reference return; default (automatic) becomes copy
129+
static handle cast(const Type &src, return_value_policy policy, handle parent)
130+
{
131+
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
132+
{
133+
policy = return_value_policy::copy;
134+
}
135+
136+
return cast(&src, policy, parent);
137+
}
138+
139+
// non-const pointer return
140+
static handle cast(Type *src, return_value_policy policy, handle parent)
141+
{
142+
return cast_impl(src, policy, parent);
143+
}
144+
145+
// const pointer return
146+
static handle cast(const Type *src, return_value_policy policy, handle parent)
147+
{
148+
return cast_impl(src, policy, parent);
149+
}
150+
151+
static PYBIND11_DESCR name()
152+
{
153+
return _("xt::xtensor");
154+
}
155+
156+
template <typename T>
157+
using cast_op_type = cast_op_type<T>;
158+
};
159+
}
160+
}
161+
162+
#endif

0 commit comments

Comments
 (0)