@@ -85,11 +85,14 @@ namespace xt
8585 static constexpr bool contiguous_layout = false ;
8686
8787 template <class S = shape_type>
88- void reshape (const S& shape);
88+ void resize (const S& shape);
8989 template <class S = shape_type>
90- void reshape (const S& shape, layout_type l);
90+ void resize (const S& shape, layout_type l);
9191 template <class S = shape_type>
92- void reshape (const S& shape, const strides_type& strides);
92+ void resize (const S& shape, const strides_type& strides);
93+
94+ template <class S = shape_type>
95+ void reshape (S&& shape, layout_type layout = base_type::static_layout);
9396
9497 layout_type layout () const ;
9598
@@ -117,6 +120,9 @@ namespace xt
117120 static bool check_ (pybind11::handle h);
118121 static PyObject* raw_array_t (PyObject* ptr);
119122
123+ derived_type& derived_cast ();
124+ const derived_type& derived_cast () const ;
125+
120126 PyArrayObject* python_array () const ;
121127 size_type get_min_stride () const ;
122128 };
@@ -257,47 +263,96 @@ namespace xt
257263 return std::max (size_type (1 ), std::accumulate (this ->strides ().cbegin (), this ->strides ().cend (), std::numeric_limits<size_type>::max (), min));
258264 }
259265
266+ template <class D >
267+ inline auto pycontainer<D>::derived_cast() -> derived_type&
268+ {
269+ return *static_cast <derived_type*>(this );
270+ }
271+
272+ template <class D >
273+ inline auto pycontainer<D>::derived_cast() const -> const derived_type&
274+ {
275+ return *static_cast <const derived_type*>(this );
276+ }
277+
278+
260279 /* *
261- * Reshapes the container.
280+ * resizes the container.
262281 * @param shape the new shape
263282 */
264283 template <class D >
265284 template <class S >
266- inline void pycontainer<D>::reshape (const S& shape)
285+ inline void pycontainer<D>::resize (const S& shape)
267286 {
268287 if (shape.size () != this ->dimension () || !std::equal (std::begin (shape), std::end (shape), std::begin (this ->shape ())))
269288 {
270- reshape (shape, layout_type::row_major);
289+ resize (shape, layout_type::row_major);
271290 }
272291 }
273292
274293 /* *
275- * Reshapes the container.
294+ * resizes the container.
276295 * @param shape the new shape
277296 * @param l the new layout
278297 */
279298 template <class D >
280299 template <class S >
281- inline void pycontainer<D>::reshape (const S& shape, layout_type l)
300+ inline void pycontainer<D>::resize (const S& shape, layout_type l)
282301 {
283302 strides_type strides = xtl::make_sequence<strides_type>(shape.size (), size_type (1 ));
284303 compute_strides (shape, l, strides);
285- reshape (shape, strides);
304+ resize (shape, strides);
286305 }
287306
288307 /* *
289- * Reshapes the container.
308+ * resizes the container.
290309 * @param shape the new shape
291310 * @param strides the new strides
292311 */
293312 template <class D >
294313 template <class S >
295- inline void pycontainer<D>::reshape (const S& shape, const strides_type& strides)
314+ inline void pycontainer<D>::resize (const S& shape, const strides_type& strides)
296315 {
297316 derived_type tmp (xtl::forward_sequence<shape_type>(shape), strides);
298317 *static_cast <derived_type*>(this ) = std::move (tmp);
299318 }
300319
320+ template <class D >
321+ template <class S >
322+ inline void pycontainer<D>::reshape(S&& shape, layout_type layout)
323+ {
324+ if (compute_size (shape) != this ->size ())
325+ {
326+ throw std::runtime_error (" Cannot reshape with incorrect number of elements." );
327+ }
328+
329+ if (layout == layout_type::dynamic || layout == layout_type::any)
330+ {
331+ layout = DEFAULT_LAYOUT;
332+ }
333+
334+ NPY_ORDER npy_layout;
335+ if (layout == layout_type::row_major)
336+ {
337+ npy_layout = NPY_CORDER;
338+ }
339+ else if (layout == layout_type::column_major)
340+ {
341+ npy_layout = NPY_FORTRANORDER;
342+ }
343+ else
344+ {
345+ throw std::runtime_error (" Cannot reshape with unknown layout_type." );
346+ }
347+
348+ PyArray_Dims dims ({reinterpret_cast <npy_intp*>(shape.data ()), static_cast <int >(shape.size ())});
349+ auto new_ptr = PyArray_Newshape ((PyArrayObject*) this ->ptr (), &dims, npy_layout);
350+ auto old_ptr = this ->ptr ();
351+ this ->ptr () = new_ptr;
352+ Py_XDECREF (old_ptr);
353+ this ->derived_cast ().init_from_python ();
354+ }
355+
301356 /* *
302357 * Return the layout_type of the container
303358 * @return layout_type of the container
0 commit comments