Skip to content

Commit aa49055

Browse files
authored
Merge pull request #124 from JohanMabille/int64
overload fix
2 parents 71c9161 + 01a7e74 commit aa49055

File tree

5 files changed

+97
-3
lines changed

5 files changed

+97
-3
lines changed

include/xtensor-python/pyarray.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#include "pystrides_adaptor.hpp"
2222
#include "xtensor_type_caster_base.hpp"
2323

24+
#include <iostream>
25+
2426
namespace xt
2527
{
2628
template <class T>
@@ -54,7 +56,7 @@ namespace pybind11
5456
return false;
5557
}
5658
int type_num = xt::detail::numpy_traits<T>::type_num;
57-
if (PyArray_TYPE(reinterpret_cast<PyArrayObject*>(src.ptr())) != type_num)
59+
if(xt::detail::pyarray_type(reinterpret_cast<PyArrayObject*>(src.ptr())) != type_num)
5860
{
5961
return false;
6062
}

include/xtensor-python/pycontainer.hpp

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,16 @@ namespace xt
125125
{
126126
private:
127127

128+
// On Windows 64 bits, NPY_INT != NPY_INT32 and NPY_UINT != NPY_UINT32
129+
// We use the NPY_INT32 and NPY_UINT32 which are consistent with the values
130+
// of NPY_LONG and NPY_ULONG
131+
// On Linux x64, NPY_INT64 != NPY_LONGLONG and NPY_UINT64 != NPY_ULONGLONG,
132+
// we use the values of NPY_INT64 and NPY_UINT64 which are consistent with the
133+
// values of NPY_LONG and NPY_ULONG.
128134
constexpr static const int value_list[15] = {
129135
NPY_BOOL,
130136
NPY_BYTE, NPY_UBYTE, NPY_SHORT, NPY_USHORT,
131-
NPY_INT, NPY_UINT, NPY_LONGLONG, NPY_ULONGLONG,
137+
NPY_INT32, NPY_UINT32, NPY_INT64, NPY_UINT64,
132138
NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE,
133139
NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE};
134140

@@ -138,6 +144,37 @@ namespace xt
138144

139145
static constexpr int type_num = value_list[pybind11::detail::is_fmt_numeric<value_type>::index];
140146
};
147+
148+
// On Linux x64, NPY_INT64 != NPY_LONGLONG and NPY_UINT64 != NPY_ULONGLONG
149+
// NPY_LONGLONG and NPY_ULONGLONG must be adjusted so the right type is
150+
// selected
151+
template <bool>
152+
struct numpy_enum_adjuster
153+
{
154+
static inline int pyarray_type(PyArrayObject* obj)
155+
{
156+
return PyArray_TYPE(obj);
157+
}
158+
};
159+
160+
template <>
161+
struct numpy_enum_adjuster<true>
162+
{
163+
static inline int pyarray_type(PyArrayObject* obj)
164+
{
165+
int res = PyArray_TYPE(obj);
166+
if(res == NPY_LONGLONG || res == NPY_ULONGLONG)
167+
{
168+
res -= 2;
169+
}
170+
return res;
171+
}
172+
};
173+
174+
inline int pyarray_type(PyArrayObject* obj)
175+
{
176+
return numpy_enum_adjuster<NPY_LONGLONG != NPY_INT64>::pyarray_type(obj);
177+
}
141178
}
142179

143180
/******************************

include/xtensor-python/pytensor.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ namespace pybind11
5555
return false;
5656
}
5757
int type_num = xt::detail::numpy_traits<T>::type_num;
58-
if (PyArray_TYPE(reinterpret_cast<PyArrayObject*>(src.ptr())) != type_num)
58+
if(xt::detail::pyarray_type(reinterpret_cast<PyArrayObject*>(src.ptr())) != type_num)
5959
{
6060
return false;
6161
}

test_python/main.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,45 @@ int add(int i, int j)
6868
return i + j;
6969
}
7070

71+
template <class T> std::string typestring() { return "Unknown"; }
72+
template <> std::string typestring<uint8_t>() { return "uint8"; }
73+
template <> std::string typestring<int8_t>() { return "int8"; }
74+
template <> std::string typestring<uint16_t>() { return "uint16"; }
75+
template <> std::string typestring<int16_t>() { return "int16"; }
76+
template <> std::string typestring<uint32_t>() { return "uint32"; }
77+
template <> std::string typestring<int32_t>() { return "int32"; }
78+
template <> std::string typestring<uint64_t>() { return "uint64"; }
79+
template <> std::string typestring<int64_t>() { return "int64"; }
80+
81+
template <class T>
82+
inline std::string int_overload(xt::pyarray<T>& m)
83+
{
84+
return typestring<T>();
85+
}
86+
87+
void dump_numpy_constant()
88+
{
89+
std::cout << "NPY_BOOL = " << NPY_BOOL << std::endl;
90+
std::cout << "NPY_BYTE = " << NPY_BYTE << std::endl;
91+
std::cout << "NPY_UBYTE = " << NPY_UBYTE << std::endl;
92+
std::cout << "NPY_INT8 = " << NPY_INT8 << std::endl;
93+
std::cout << "NPY_UINT8 = " << NPY_UINT8 << std::endl;
94+
std::cout << "NPY_SHORT = " << NPY_SHORT << std::endl;
95+
std::cout << "NPY_USHORT = " << NPY_USHORT << std::endl;
96+
std::cout << "NPY_INT16 = " << NPY_INT16 << std::endl;
97+
std::cout << "NPY_UINT16 = " << NPY_UINT16 << std::endl;
98+
std::cout << "NPY_INT = " << NPY_INT << std::endl;
99+
std::cout << "NPY_UINT = " << NPY_UINT << std::endl;
100+
std::cout << "NPY_INT32 = " << NPY_INT32 << std::endl;
101+
std::cout << "NPY_UINT32 = " << NPY_UINT32 << std::endl;
102+
std::cout << "NPY_LONG = " << NPY_LONG << std::endl;
103+
std::cout << "NPY_ULONG = " << NPY_ULONG << std::endl;
104+
std::cout << "NPY_LONGLONG = " << NPY_LONGLONG << std::endl;
105+
std::cout << "NPY_ULONGLONG = " << NPY_ULONGLONG << std::endl;
106+
std::cout << "NPY_INT64 = " << NPY_INT64 << std::endl;
107+
std::cout << "NPY_UINT64 = " << NPY_UINT64 << std::endl;
108+
}
109+
71110
PYBIND11_PLUGIN(xtensor_python_test)
72111
{
73112
xt::import_numpy();
@@ -93,5 +132,16 @@ PYBIND11_PLUGIN(xtensor_python_test)
93132
return a.shape() == b.shape();
94133
});
95134

135+
m.def("int_overload", int_overload<uint8_t>);
136+
m.def("int_overload", int_overload<int8_t>);
137+
m.def("int_overload", int_overload<uint16_t>);
138+
m.def("int_overload", int_overload<int16_t>);
139+
m.def("int_overload", int_overload<uint32_t>);
140+
m.def("int_overload", int_overload<int32_t>);
141+
m.def("int_overload", int_overload<uint64_t>);
142+
m.def("int_overload", int_overload<int64_t>);
143+
144+
m.def("dump_numpy_constant", dump_numpy_constant);
145+
96146
return m.ptr();
97147
}

test_python/test_pyarray.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,8 @@ def test_shape_comparison(self):
8282
self.assertFalse(xt.compare_shapes(x, y))
8383
self.assertTrue(xt.compare_shapes(x, z))
8484

85+
def test_int_overload(self):
86+
for dtype in [np.uint8, np.int8, np.uint16, np.int16, np.uint32, np.int32, np.uint64, np.int64]:
87+
b = xt.int_overload(np.ones((10), dtype))
88+
self.assertEqual(str(dtype.__name__), b)
89+

0 commit comments

Comments
 (0)