Skip to content

Commit c67248e

Browse files
author
Yoshiaki Bando
committed
Add type_caster classes.
They enable casing ndarray to xexpression<pyarray> and casing xarray to ndarray
1 parent ca42d8b commit c67248e

File tree

4 files changed

+161
-0
lines changed

4 files changed

+161
-0
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ set(XTENSOR_PYTHON_HEADERS
3838
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pytensor.hpp
3939
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyvectorize.hpp
4040
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_python_config.hpp
41+
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_type_caster_base.hpp
4142
)
4243

4344
OPTION(BUILD_TESTS "xtensor test suite" OFF)

include/xtensor-python/pyarray.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "pycontainer.hpp"
2121
#include "pystrides_adaptor.hpp"
22+
#include "xtensor_type_caster_base.hpp"
2223

2324
namespace xt
2425
{
@@ -69,6 +70,21 @@ namespace pybind11
6970

7071
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
7172
};
73+
74+
// Type caster for casting ndarray to xexpression<pyarray>
75+
template<typename T>
76+
struct type_caster<xt::xexpression<xt::pyarray<T>>> : pyobject_caster<xt::pyarray<T>>{
77+
public:
78+
using Type = xt::xexpression<xt::pyarray<T>>;
79+
80+
operator Type&() { return this->value; }
81+
operator const Type&() { return this->value; }
82+
};
83+
84+
// Type caster for casting xarray to ndarray
85+
template<class T>
86+
struct type_caster<xt::xarray<T>> : xtensor_type_caster_base<xt::xarray<T>> {
87+
};
7288
}
7389
}
7490

include/xtensor-python/pytensor.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "pycontainer.hpp"
2222
#include "pystrides_adaptor.hpp"
23+
#include "xtensor_type_caster_base.hpp"
2324

2425
namespace xt
2526
{
@@ -71,6 +72,21 @@ namespace pybind11
7172

7273
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
7374
};
75+
76+
// Type caster for casting ndarray to xexpression<pytensor>
77+
template<class T, std::size_t N>
78+
struct type_caster<xt::xexpression<xt::pytensor<T, N>>> : pyobject_caster<xt::pytensor<T, N>>{
79+
public:
80+
using Type = xt::xexpression<xt::pytensor<T, N>>;
81+
82+
operator Type&() { return this->value; }
83+
operator const Type&() { return this->value; }
84+
};
85+
86+
// Type caster for casting xt::xtensor to ndarray
87+
template<class T, std::size_t N>
88+
struct type_caster<xt::xtensor<T, N>> : xtensor_type_caster_base<xt::xtensor<T, N>> {
89+
};
7490
}
7591
}
7692

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
xtensor-python/xtensor_type_caster.hpp: Transparent conversion for xtensor and xarray
3+
4+
This code was inspired by the following code written by Wenzei Jakob
5+
6+
pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices
7+
8+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
9+
10+
All rights reserved. Use of this source code is governed by a
11+
BSD-style license that can be found in the LICENSE file.
12+
*/
13+
14+
15+
#ifndef XTENSOR_TYPE_CASTER_HPP
16+
#define XTENSOR_TYPE_CASTER_HPP
17+
18+
#include "xtensor/xtensor.hpp"
19+
#include <pybind11/pybind11.h>
20+
#include <pybind11/numpy.h>
21+
22+
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
23+
NAMESPACE_BEGIN(detail)
24+
25+
26+
// Casts an xtensor (or xarray) type to numpy array. If given a base, the numpy array references the src data,
27+
// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
28+
template<typename Type>
29+
handle xtensor_array_cast(Type const &src, handle base = handle(), bool writeable = true) {
30+
std::vector<ssize_t> python_strides(src.strides().size());
31+
std::transform(src.strides().begin(), src.strides().end(), python_strides.data(),
32+
[](auto v) { return sizeof(typename Type::value_type) * v; });
33+
34+
array a(src.shape(), python_strides, src.begin(), base);
35+
36+
if (!writeable)
37+
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
38+
39+
return a.release();
40+
}
41+
42+
43+
// Takes an lvalue ref to some xtensor (or xarray) type and a (python) base object, creating a numpy array that
44+
// reference the xtensor object's data with `base` as the python-registered base class (if omitted,
45+
// the base will be set to None, and lifetime management is up to the caller). The numpy array is
46+
// non-writeable if the given type is const.
47+
template <typename Type, typename CType>
48+
handle xtensor_ref_array(CType &src, handle parent = none()) {
49+
return xtensor_array_cast<Type>(src, parent, !std::is_const<CType>::value);
50+
}
51+
52+
53+
// Takes a pointer to xtensor (or xarray), builds a capsule around it, then returns a numpy
54+
// array that references the encapsulated data with a python-side reference to the capsule to tie
55+
// its destruction to that of any dependent python objects. Const-ness is determined by whether or
56+
// not the CType of the pointer given is const.
57+
template <typename Type, typename CType>
58+
handle xtensor_encapsulate(CType *src) {
59+
capsule base(src, [](void *o) { delete static_cast<CType *>(o); });
60+
return xtensor_ref_array<Type>(*src, base);
61+
}
62+
63+
// Base class of type_caster for xtensor and xarray
64+
template<class Type>
65+
struct xtensor_type_caster_base {
66+
bool load(handle src, bool) {
67+
return false;
68+
}
69+
70+
private:
71+
// Cast implementation
72+
template <typename CType>
73+
static handle cast_impl(CType *src, return_value_policy policy, handle parent) {
74+
switch (policy) {
75+
case return_value_policy::take_ownership:
76+
case return_value_policy::automatic:
77+
return xtensor_encapsulate<Type>(src);
78+
case return_value_policy::move:
79+
return xtensor_encapsulate<Type>(new CType(std::move(*src)));
80+
case return_value_policy::copy:
81+
return xtensor_array_cast<Type>(*src);
82+
case return_value_policy::reference:
83+
case return_value_policy::automatic_reference:
84+
return xtensor_ref_array<Type>(*src);
85+
case return_value_policy::reference_internal:
86+
return xtensor_ref_array<Type>(*src, parent);
87+
default:
88+
throw cast_error("unhandled return_value_policy: should not happen!");
89+
};
90+
}
91+
public:
92+
// Normal returned non-reference, non-const value:
93+
static handle cast(Type &&src, return_value_policy /* policy */, handle parent) {
94+
return cast_impl(&src, return_value_policy::move, parent);
95+
}
96+
// If you return a non-reference const, we mark the numpy array readonly:
97+
static handle cast(const Type &&src, return_value_policy /* policy */, handle parent) {
98+
return cast_impl(&src, return_value_policy::move, parent);
99+
}
100+
// lvalue reference return; default (automatic) becomes copy
101+
static handle cast(Type &src, return_value_policy policy, handle parent) {
102+
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
103+
policy = return_value_policy::copy;
104+
return cast_impl(&src, policy, parent);
105+
}
106+
// const lvalue reference return; default (automatic) becomes copy
107+
static handle cast(const Type &src, return_value_policy policy, handle parent) {
108+
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
109+
policy = return_value_policy::copy;
110+
return cast(&src, policy, parent);
111+
}
112+
// non-const pointer return
113+
static handle cast(Type *src, return_value_policy policy, handle parent) {
114+
return cast_impl(src, policy, parent);
115+
}
116+
// const pointer return
117+
static handle cast(const Type *src, return_value_policy policy, handle parent) {
118+
return cast_impl(src, policy, parent);
119+
}
120+
121+
static PYBIND11_DESCR name() { return _("xt::xtensor"); }
122+
template <typename T> using cast_op_type = movable_cast_op_type<T>;
123+
};
124+
125+
NAMESPACE_END(detail)
126+
NAMESPACE_END(PYBIND11_NAMESPACE)
127+
128+
#endif

0 commit comments

Comments
 (0)