Skip to content

Commit 35f6f08

Browse files
committed
fix shape initialization
1 parent dd4d875 commit 35f6f08

File tree

4 files changed

+27
-1
lines changed

4 files changed

+27
-1
lines changed

include/xtensor-python/pyarray.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ namespace xt
387387
const container_type& data_impl() const noexcept;
388388

389389
friend class xcontainer<pyarray<T>>;
390+
friend class pycontainer<pyarray<T>>;
390391
};
391392

392393
/**************************************

include/xtensor-python/pycontainer.hpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ namespace xt
120120
static bool check_(pybind11::handle h);
121121
static PyObject* raw_array_t(PyObject* ptr);
122122

123+
derived_type& derived_cast();
124+
const derived_type& derived_cast() const;
125+
123126
PyArrayObject* python_array() const;
124127
size_type get_min_stride() const;
125128
};
@@ -260,6 +263,19 @@ namespace xt
260263
return std::max(size_type(1), std::accumulate(this->strides().cbegin(), this->strides().cend(), std::numeric_limits<size_type>::max(), min));
261264
}
262265

266+
template <class D>
267+
inline auto pycontainer<D>::derived_cast() -> derived_type&
268+
{
269+
return *static_cast<derived_type*>(this);
270+
}
271+
272+
template <class D>
273+
inline auto pycontainer<D>::derived_cast() const -> const derived_type&
274+
{
275+
return *static_cast<const derived_type*>(this);
276+
}
277+
278+
263279
/**
264280
* resizes the container.
265281
* @param shape the new shape
@@ -330,7 +346,11 @@ namespace xt
330346
}
331347

332348
PyArray_Dims dims({reinterpret_cast<npy_intp*>(shape.data()), static_cast<int>(shape.size())});
333-
PyArray_Newshape((PyArrayObject*) this->ptr(), &dims, npy_layout);
349+
auto new_ptr = PyArray_Newshape((PyArrayObject*) this->ptr(), &dims, npy_layout);
350+
auto old_ptr = this->ptr();
351+
this->ptr() = new_ptr;
352+
Py_XDECREF(old_ptr);
353+
this->derived_cast().init_from_python();
334354
}
335355

336356
/**

include/xtensor-python/pytensor.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ namespace xt
207207
const container_type& data_impl() const noexcept;
208208

209209
friend class xcontainer<pytensor<T, N>>;
210+
friend class pycontainer<pytensor<T, N>>;
210211
};
211212

212213
/***************************

test/test_pyarray.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,12 @@ namespace xt
220220
pyarray<int> a = {{1,2,3}, {4,5,6}};
221221
auto ptr = a.raw_data();
222222
a.reshape({1, 6});
223+
std::vector<std::size_t> sc1({1, 6});
224+
EXPECT_TRUE(std::equal(sc1.begin(), sc1.end(), a.shape().begin()) && a.shape().size() == 2);
223225
EXPECT_EQ(ptr, a.raw_data());
224226
a.reshape({6});
227+
std::vector<std::size_t> sc2 = {6};
228+
EXPECT_TRUE(std::equal(sc2.begin(), sc2.end(), a.shape().begin()) && a.shape().size() == 1);
225229
EXPECT_EQ(ptr, a.raw_data());
226230
}
227231
}

0 commit comments

Comments
 (0)