Skip to content

Commit a1fbdde

Browse files
authored
Merge pull request #132 from wolfv/fix_resize
Fix resize
2 parents d2668d3 + 35f6f08 commit a1fbdde

File tree

9 files changed

+139
-60
lines changed

9 files changed

+139
-60
lines changed

.appveyor.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ install:
2323
- conda update -q conda
2424
- conda info -a
2525
- conda install gtest cmake -c conda-forge
26-
- conda install xtensor==0.14.0 pytest numpy pybind11==2.2.1 -c conda-forge
26+
- conda install xtensor==0.15.0 pytest numpy pybind11==2.2.1 -c conda-forge
2727
- "set PYTHONHOME=%MINICONDA%"
2828
- cmake -G "NMake Makefiles" -D CMAKE_INSTALL_PREFIX=%MINICONDA%\\Library -D BUILD_TESTS=ON -D PYTHON_EXECUTABLE=%MINICONDA%\\python.exe .
2929
- nmake test_xtensor_python

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ install:
9595
- conda update -q conda
9696
# Useful for debugging any issues with conda
9797
- conda info -a
98-
- conda install xtensor==0.14.0 pytest numpy pybind11==2.2.1 -c conda-forge
98+
- conda install xtensor==0.15.0 pytest numpy pybind11==2.2.1 -c conda-forge
9999
- conda install cmake gtest -c conda-forge
100100
- cmake -D BUILD_TESTS=ON -D CMAKE_INSTALL_PREFIX=$HOME/miniconda .
101101
- make -j2 test_xtensor_python

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ from the `docs` subdirectory.
187187

188188
| `xtensor-python` | `xtensor` | `pybind11` |
189189
|------------------|-----------|------------------|
190-
| master | ^0.14.0 | ~2.1.0 or ~2.2.1 |
190+
| master | ^0.15.0 | ~2.1.0 or ~2.2.1 |
191191
| 0.16.x | ^0.14.0 | ~2.1.0 or ~2.2.1 |
192192
| 0.15.x | ^0.13.1 | ~2.1.0 or ~2.2.1 |
193193
| 0.14.x | ^0.12.0 | ~2.1.0 or ~2.2.1 |

include/xtensor-python/pyarray.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ namespace xt
387387
const container_type& data_impl() const noexcept;
388388

389389
friend class xcontainer<pyarray<T>>;
390+
friend class pycontainer<pyarray<T>>;
390391
};
391392

392393
/**************************************
@@ -486,47 +487,47 @@ namespace xt
486487
inline pyarray<T>::pyarray(const value_type& t)
487488
: base_type()
488489
{
489-
base_type::reshape(xt::shape<shape_type>(t), layout_type::row_major);
490+
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
490491
nested_copy(m_data.begin(), t);
491492
}
492493

493494
template <class T>
494495
inline pyarray<T>::pyarray(nested_initializer_list_t<T, 1> t)
495496
: base_type()
496497
{
497-
base_type::reshape(xt::shape<shape_type>(t), layout_type::row_major);
498+
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
498499
nested_copy(m_data.begin(), t);
499500
}
500501

501502
template <class T>
502503
inline pyarray<T>::pyarray(nested_initializer_list_t<T, 2> t)
503504
: base_type()
504505
{
505-
base_type::reshape(xt::shape<shape_type>(t), layout_type::row_major);
506+
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
506507
nested_copy(m_data.begin(), t);
507508
}
508509

509510
template <class T>
510511
inline pyarray<T>::pyarray(nested_initializer_list_t<T, 3> t)
511512
: base_type()
512513
{
513-
base_type::reshape(xt::shape<shape_type>(t), layout_type::row_major);
514+
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
514515
nested_copy(m_data.begin(), t);
515516
}
516517

517518
template <class T>
518519
inline pyarray<T>::pyarray(nested_initializer_list_t<T, 4> t)
519520
: base_type()
520521
{
521-
base_type::reshape(xt::shape<shape_type>(t), layout_type::row_major);
522+
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
522523
nested_copy(m_data.begin(), t);
523524
}
524525

525526
template <class T>
526527
inline pyarray<T>::pyarray(nested_initializer_list_t<T, 5> t)
527528
: base_type()
528529
{
529-
base_type::reshape(xt::shape<shape_type>(t), layout_type::row_major);
530+
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
530531
nested_copy(m_data.begin(), t);
531532
}
532533

include/xtensor-python/pycontainer.hpp

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,14 @@ namespace xt
8585
static constexpr bool contiguous_layout = false;
8686

8787
template <class S = shape_type>
88-
void reshape(const S& shape);
88+
void resize(const S& shape);
8989
template <class S = shape_type>
90-
void reshape(const S& shape, layout_type l);
90+
void resize(const S& shape, layout_type l);
9191
template <class S = shape_type>
92-
void reshape(const S& shape, const strides_type& strides);
92+
void resize(const S& shape, const strides_type& strides);
93+
94+
template <class S = shape_type>
95+
void reshape(S&& shape, layout_type layout = base_type::static_layout);
9396

9497
layout_type layout() const;
9598

@@ -117,6 +120,9 @@ namespace xt
117120
static bool check_(pybind11::handle h);
118121
static PyObject* raw_array_t(PyObject* ptr);
119122

123+
derived_type& derived_cast();
124+
const derived_type& derived_cast() const;
125+
120126
PyArrayObject* python_array() const;
121127
size_type get_min_stride() const;
122128
};
@@ -257,47 +263,96 @@ namespace xt
257263
return std::max(size_type(1), std::accumulate(this->strides().cbegin(), this->strides().cend(), std::numeric_limits<size_type>::max(), min));
258264
}
259265

266+
template <class D>
267+
inline auto pycontainer<D>::derived_cast() -> derived_type&
268+
{
269+
return *static_cast<derived_type*>(this);
270+
}
271+
272+
template <class D>
273+
inline auto pycontainer<D>::derived_cast() const -> const derived_type&
274+
{
275+
return *static_cast<const derived_type*>(this);
276+
}
277+
278+
260279
/**
261-
* Reshapes the container.
280+
* resizes the container.
262281
* @param shape the new shape
263282
*/
264283
template <class D>
265284
template <class S>
266-
inline void pycontainer<D>::reshape(const S& shape)
285+
inline void pycontainer<D>::resize(const S& shape)
267286
{
268287
if (shape.size() != this->dimension() || !std::equal(std::begin(shape), std::end(shape), std::begin(this->shape())))
269288
{
270-
reshape(shape, layout_type::row_major);
289+
resize(shape, layout_type::row_major);
271290
}
272291
}
273292

274293
/**
275-
* Reshapes the container.
294+
* resizes the container.
276295
* @param shape the new shape
277296
* @param l the new layout
278297
*/
279298
template <class D>
280299
template <class S>
281-
inline void pycontainer<D>::reshape(const S& shape, layout_type l)
300+
inline void pycontainer<D>::resize(const S& shape, layout_type l)
282301
{
283302
strides_type strides = xtl::make_sequence<strides_type>(shape.size(), size_type(1));
284303
compute_strides(shape, l, strides);
285-
reshape(shape, strides);
304+
resize(shape, strides);
286305
}
287306

288307
/**
289-
* Reshapes the container.
308+
* resizes the container.
290309
* @param shape the new shape
291310
* @param strides the new strides
292311
*/
293312
template <class D>
294313
template <class S>
295-
inline void pycontainer<D>::reshape(const S& shape, const strides_type& strides)
314+
inline void pycontainer<D>::resize(const S& shape, const strides_type& strides)
296315
{
297316
derived_type tmp(xtl::forward_sequence<shape_type>(shape), strides);
298317
*static_cast<derived_type*>(this) = std::move(tmp);
299318
}
300319

320+
template <class D>
321+
template <class S>
322+
inline void pycontainer<D>::reshape(S&& shape, layout_type layout)
323+
{
324+
if (compute_size(shape) != this->size())
325+
{
326+
throw std::runtime_error("Cannot reshape with incorrect number of elements.");
327+
}
328+
329+
if (layout == layout_type::dynamic || layout == layout_type::any)
330+
{
331+
layout = DEFAULT_LAYOUT;
332+
}
333+
334+
NPY_ORDER npy_layout;
335+
if (layout == layout_type::row_major)
336+
{
337+
npy_layout = NPY_CORDER;
338+
}
339+
else if (layout == layout_type::column_major)
340+
{
341+
npy_layout = NPY_FORTRANORDER;
342+
}
343+
else
344+
{
345+
throw std::runtime_error("Cannot reshape with unknown layout_type.");
346+
}
347+
348+
PyArray_Dims dims({reinterpret_cast<npy_intp*>(shape.data()), static_cast<int>(shape.size())});
349+
auto new_ptr = PyArray_Newshape((PyArrayObject*) this->ptr(), &dims, npy_layout);
350+
auto old_ptr = this->ptr();
351+
this->ptr() = new_ptr;
352+
Py_XDECREF(old_ptr);
353+
this->derived_cast().init_from_python();
354+
}
355+
301356
/**
302357
* Return the layout_type of the container
303358
* @return layout_type of the container

include/xtensor-python/pytensor.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ namespace xt
207207
const container_type& data_impl() const noexcept;
208208

209209
friend class xcontainer<pytensor<T, N>>;
210+
friend class pycontainer<pytensor<T, N>>;
210211
};
211212

212213
/***************************
@@ -237,7 +238,7 @@ namespace xt
237238
inline pytensor<T, N>::pytensor(nested_initializer_list_t<T, N> t)
238239
: base_type()
239240
{
240-
base_type::reshape(xt::shape<shape_type>(t), layout_type::row_major);
241+
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
241242
nested_copy(m_data.begin(), t);
242243
}
243244

0 commit comments

Comments
 (0)