@@ -68,6 +68,45 @@ int add(int i, int j)
6868 return i + j;
6969}
7070
71+ template <class T > std::string typestring () { return " Unknown" ; }
72+ template <> std::string typestring<uint8_t >() { return " uint8" ; }
73+ template <> std::string typestring<int8_t >() { return " int8" ; }
74+ template <> std::string typestring<uint16_t >() { return " uint16" ; }
75+ template <> std::string typestring<int16_t >() { return " int16" ; }
76+ template <> std::string typestring<uint32_t >() { return " uint32" ; }
77+ template <> std::string typestring<int32_t >() { return " int32" ; }
78+ template <> std::string typestring<uint64_t >() { return " uint64" ; }
79+ template <> std::string typestring<int64_t >() { return " int64" ; }
80+
81+ template <class T >
82+ inline std::string int_overload (xt::pyarray<T>& m)
83+ {
84+ return typestring<T>();
85+ }
86+
87+ void dump_numpy_constant ()
88+ {
89+ std::cout << " NPY_BOOL = " << NPY_BOOL << std::endl;
90+ std::cout << " NPY_BYTE = " << NPY_BYTE << std::endl;
91+ std::cout << " NPY_UBYTE = " << NPY_UBYTE << std::endl;
92+ std::cout << " NPY_INT8 = " << NPY_INT8 << std::endl;
93+ std::cout << " NPY_UINT8 = " << NPY_UINT8 << std::endl;
94+ std::cout << " NPY_SHORT = " << NPY_SHORT << std::endl;
95+ std::cout << " NPY_USHORT = " << NPY_USHORT << std::endl;
96+ std::cout << " NPY_INT16 = " << NPY_INT16 << std::endl;
97+ std::cout << " NPY_UINT16 = " << NPY_UINT16 << std::endl;
98+ std::cout << " NPY_INT = " << NPY_INT << std::endl;
99+ std::cout << " NPY_UINT = " << NPY_UINT << std::endl;
100+ std::cout << " NPY_INT32 = " << NPY_INT32 << std::endl;
101+ std::cout << " NPY_UINT32 = " << NPY_UINT32 << std::endl;
102+ std::cout << " NPY_LONG = " << NPY_LONG << std::endl;
103+ std::cout << " NPY_ULONG = " << NPY_ULONG << std::endl;
104+ std::cout << " NPY_LONGLONG = " << NPY_LONGLONG << std::endl;
105+ std::cout << " NPY_ULONGLONG = " << NPY_ULONGLONG << std::endl;
106+ std::cout << " NPY_INT64 = " << NPY_INT64 << std::endl;
107+ std::cout << " NPY_UINT64 = " << NPY_UINT64 << std::endl;
108+ }
109+
71110PYBIND11_PLUGIN (xtensor_python_test)
72111{
73112 xt::import_numpy ();
@@ -93,5 +132,16 @@ PYBIND11_PLUGIN(xtensor_python_test)
93132 return a.shape () == b.shape ();
94133 });
95134
135+ m.def (" int_overload" , int_overload<uint8_t >);
136+ m.def (" int_overload" , int_overload<int8_t >);
137+ m.def (" int_overload" , int_overload<uint16_t >);
138+ m.def (" int_overload" , int_overload<int16_t >);
139+ m.def (" int_overload" , int_overload<uint32_t >);
140+ m.def (" int_overload" , int_overload<int32_t >);
141+ m.def (" int_overload" , int_overload<uint64_t >);
142+ m.def (" int_overload" , int_overload<int64_t >);
143+
144+ m.def (" dump_numpy_constant" , dump_numpy_constant);
145+
96146 return m.ptr ();
97147}
0 commit comments