Skip to content

Commit 1d9498c

Browse files
committed
template reshape
1 parent f1a64a8 commit 1d9498c

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

include/xtensor-python/pycontainer.hpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,12 @@ namespace xt
8181
static constexpr layout_type static_layout = layout_type::dynamic;
8282
static constexpr bool contiguous_layout = false;
8383

84-
void reshape(const shape_type& shape);
85-
void reshape(const shape_type& shape, layout_type l);
86-
void reshape(const shape_type& shape, const strides_type& strides);
84+
template <class S = shape_type>
85+
void reshape(const S& shape);
86+
template <class S = shape_type>
87+
void reshape(const S& shape, layout_type l);
88+
template <class S = shape_type>
89+
void reshape(const S& shape, const strides_type& strides);
8790

8891
layout_type layout() const;
8992

@@ -219,7 +222,8 @@ namespace xt
219222
* @param shape the new shape
220223
*/
221224
template <class D>
222-
inline void pycontainer<D>::reshape(const shape_type& shape)
225+
template <class S>
226+
inline void pycontainer<D>::reshape(const S& shape)
223227
{
224228
if (shape.size() != this->dimension() || !std::equal(shape.begin(), shape.end(), this->shape().begin()))
225229
{
@@ -233,7 +237,8 @@ namespace xt
233237
* @param l the new layout
234238
*/
235239
template <class D>
236-
inline void pycontainer<D>::reshape(const shape_type& shape, layout_type l)
240+
template <class S>
241+
inline void pycontainer<D>::reshape(const S& shape, layout_type l)
237242
{
238243
strides_type strides = xtl::make_sequence<strides_type>(shape.size(), size_type(1));
239244
compute_strides(shape, l, strides);
@@ -246,7 +251,8 @@ namespace xt
246251
* @param strides the new strides
247252
*/
248253
template <class D>
249-
inline void pycontainer<D>::reshape(const shape_type& shape, const strides_type& strides)
254+
template <class S>
255+
inline void pycontainer<D>::reshape(const S& shape, const strides_type& strides)
250256
{
251257
derived_type tmp(shape, strides);
252258
*static_cast<derived_type*>(this) = std::move(tmp);

0 commit comments

Comments
 (0)