2323
2424namespace xt
2525{
26- template <class T >
26+ template <class T , layout_type L = layout_type::dynamic >
2727 class pyarray ;
2828}
2929
3030namespace pybind11
3131{
3232 namespace detail
3333 {
34- template <class T >
35- struct handle_type_name <xt::pyarray<T>>
34+ template <class T , xt::layout_type L >
35+ struct handle_type_name <xt::pyarray<T, L >>
3636 {
3737 static PYBIND11_DESCR name ()
3838 {
3939 return _ (" numpy.ndarray[" ) + npy_format_descriptor<T>::name () + _ (" ]" );
4040 }
4141 };
4242
43- template <typename T>
44- struct pyobject_caster <xt::pyarray<T>>
43+ template <typename T, xt::layout_type L >
44+ struct pyobject_caster <xt::pyarray<T, L >>
4545 {
46- using type = xt::pyarray<T>;
46+ using type = xt::pyarray<T, L >;
4747
4848 bool load (handle src, bool convert)
4949 {
@@ -72,10 +72,10 @@ namespace pybind11
7272 };
7373
7474 // Type caster for casting ndarray to xexpression<pyarray>
75- template <typename T>
76- struct type_caster <xt::xexpression<xt::pyarray<T>>> : pyobject_caster<xt::pyarray<T>>
75+ template <typename T, xt::layout_type L >
76+ struct type_caster <xt::xexpression<xt::pyarray<T, L >>> : pyobject_caster<xt::pyarray<T, L >>
7777 {
78- using Type = xt::xexpression<xt::pyarray<T>>;
78+ using Type = xt::xexpression<xt::pyarray<T, L >>;
7979
8080 operator Type&()
8181 {
@@ -89,8 +89,8 @@ namespace pybind11
8989 };
9090
9191 // Type caster for casting xarray to ndarray
92- template <class T >
93- struct type_caster <xt::xarray<T>> : xtensor_type_caster_base<xt::xarray<T>>
92+ template <class T , xt::layout_type L >
93+ struct type_caster <xt::xarray<T, L >> : xtensor_type_caster_base<xt::xarray<T, L >>
9494 {
9595 };
9696 }
@@ -282,24 +282,24 @@ namespace xt
282282 const array_type* p_a;
283283 };
284284
285- template <class T >
286- struct xiterable_inner_types <pyarray<T>>
287- : xcontainer_iterable_types<pyarray<T>>
285+ template <class T , layout_type L >
286+ struct xiterable_inner_types <pyarray<T, L >>
287+ : xcontainer_iterable_types<pyarray<T, L >>
288288 {
289289 };
290290
291- template <class T >
292- struct xcontainer_inner_types <pyarray<T>>
291+ template <class T , layout_type L >
292+ struct xcontainer_inner_types <pyarray<T, L >>
293293 {
294294 using storage_type = xbuffer_adaptor<T*>;
295295 using shape_type = std::vector<typename storage_type::size_type>;
296296 using strides_type = shape_type;
297- using backstrides_type = pyarray_backstrides<pyarray<T>>;
297+ using backstrides_type = pyarray_backstrides<pyarray<T, L >>;
298298 using inner_shape_type = xbuffer_adaptor<std::size_t *>;
299299 using inner_strides_type = pystrides_adaptor<sizeof (T)>;
300300 using inner_backstrides_type = backstrides_type;
301- using temporary_type = pyarray<T>;
302- static constexpr layout_type layout = layout_type::dynamic ;
301+ using temporary_type = pyarray<T, L >;
302+ static constexpr layout_type layout = L ;
303303 };
304304
305305 /* *
@@ -312,13 +312,13 @@ namespace xt
312312 * @tparam T The type of the element stored in the pyarray.
313313 * @sa pytensor
314314 */
315- template <class T >
316- class pyarray : public pycontainer <pyarray<T>>,
317- public xcontainer_semantic<pyarray<T>>
315+ template <class T , layout_type L >
316+ class pyarray : public pycontainer <pyarray<T, L >>,
317+ public xcontainer_semantic<pyarray<T, L >>
318318 {
319319 public:
320320
321- using self_type = pyarray<T>;
321+ using self_type = pyarray<T, L >;
322322 using semantic_base = xcontainer_semantic<self_type>;
323323 using base_type = pycontainer<self_type>;
324324 using storage_type = typename base_type::storage_type;
@@ -386,8 +386,8 @@ namespace xt
386386 storage_type& storage_impl () noexcept ;
387387 const storage_type& storage_impl () const noexcept ;
388388
389- friend class xcontainer <pyarray<T>>;
390- friend class pycontainer <pyarray<T>>;
389+ friend class xcontainer <pyarray<T, L >>;
390+ friend class pycontainer <pyarray<T, L >>;
391391 };
392392
393393 /* *************************************
@@ -469,8 +469,8 @@ namespace xt
469469 * @name Constructors
470470 */
471471 // @{
472- template <class T >
473- inline pyarray<T>::pyarray()
472+ template <class T , layout_type L >
473+ inline pyarray<T, L >::pyarray()
474474 : base_type()
475475 {
476476 // TODO: avoid allocation
@@ -483,70 +483,70 @@ namespace xt
483483 /* *
484484 * Allocates a pyarray with nested initializer lists.
485485 */
486- template <class T >
487- inline pyarray<T>::pyarray(const value_type& t)
486+ template <class T , layout_type L >
487+ inline pyarray<T, L >::pyarray(const value_type& t)
488488 : base_type()
489489 {
490490 base_type::resize (xt::shape<shape_type>(t), layout_type::row_major);
491491 nested_copy (m_storage.begin (), t);
492492 }
493493
494- template <class T >
495- inline pyarray<T>::pyarray(nested_initializer_list_t <T, 1 > t)
494+ template <class T , layout_type L >
495+ inline pyarray<T, L >::pyarray(nested_initializer_list_t <T, 1 > t)
496496 : base_type()
497497 {
498498 base_type::resize (xt::shape<shape_type>(t), layout_type::row_major);
499499 nested_copy (m_storage.begin (), t);
500500 }
501501
502- template <class T >
503- inline pyarray<T>::pyarray(nested_initializer_list_t <T, 2 > t)
502+ template <class T , layout_type L >
503+ inline pyarray<T, L >::pyarray(nested_initializer_list_t <T, 2 > t)
504504 : base_type()
505505 {
506506 base_type::resize (xt::shape<shape_type>(t), layout_type::row_major);
507507 nested_copy (m_storage.begin (), t);
508508 }
509509
510- template <class T >
511- inline pyarray<T>::pyarray(nested_initializer_list_t <T, 3 > t)
510+ template <class T , layout_type L >
511+ inline pyarray<T, L >::pyarray(nested_initializer_list_t <T, 3 > t)
512512 : base_type()
513513 {
514514 base_type::resize (xt::shape<shape_type>(t), layout_type::row_major);
515515 nested_copy (m_storage.begin (), t);
516516 }
517517
518- template <class T >
519- inline pyarray<T>::pyarray(nested_initializer_list_t <T, 4 > t)
518+ template <class T , layout_type L >
519+ inline pyarray<T, L >::pyarray(nested_initializer_list_t <T, 4 > t)
520520 : base_type()
521521 {
522522 base_type::resize (xt::shape<shape_type>(t), layout_type::row_major);
523523 nested_copy (m_storage.begin (), t);
524524 }
525525
526- template <class T >
527- inline pyarray<T>::pyarray(nested_initializer_list_t <T, 5 > t)
526+ template <class T , layout_type L >
527+ inline pyarray<T, L >::pyarray(nested_initializer_list_t <T, 5 > t)
528528 : base_type()
529529 {
530530 base_type::resize (xt::shape<shape_type>(t), layout_type::row_major);
531531 nested_copy (m_storage.begin (), t);
532532 }
533533
534- template <class T >
535- inline pyarray<T>::pyarray(pybind11::handle h, pybind11::object::borrowed_t b)
534+ template <class T , layout_type L >
535+ inline pyarray<T, L >::pyarray(pybind11::handle h, pybind11::object::borrowed_t b)
536536 : base_type(h, b)
537537 {
538538 init_from_python ();
539539 }
540540
541- template <class T >
542- inline pyarray<T>::pyarray(pybind11::handle h, pybind11::object::stolen_t s)
541+ template <class T , layout_type L >
542+ inline pyarray<T, L >::pyarray(pybind11::handle h, pybind11::object::stolen_t s)
543543 : base_type(h, s)
544544 {
545545 init_from_python ();
546546 }
547547
548- template <class T >
549- inline pyarray<T>::pyarray(const pybind11::object& o)
548+ template <class T , layout_type L >
549+ inline pyarray<T, L >::pyarray(const pybind11::object& o)
550550 : base_type(o)
551551 {
552552 init_from_python ();
@@ -558,8 +558,8 @@ namespace xt
558558 * @param shape the shape of the pyarray
559559 * @param l the layout of the pyarray
560560 */
561- template <class T >
562- inline pyarray<T>::pyarray(const shape_type& shape, layout_type l)
561+ template <class T , layout_type L >
562+ inline pyarray<T, L >::pyarray(const shape_type& shape, layout_type l)
563563 : base_type()
564564 {
565565 strides_type strides (shape.size ());
@@ -574,8 +574,8 @@ namespace xt
574574 * @param value the value of the elements
575575 * @param l the layout of the pyarray
576576 */
577- template <class T >
578- inline pyarray<T>::pyarray(const shape_type& shape, const_reference value, layout_type l)
577+ template <class T , layout_type L >
578+ inline pyarray<T, L >::pyarray(const shape_type& shape, const_reference value, layout_type l)
579579 : base_type()
580580 {
581581 strides_type strides (shape.size ());
@@ -591,8 +591,8 @@ namespace xt
591591 * @param strides the strides of the pyarray
592592 * @param value the value of the elements
593593 */
594- template <class T >
595- inline pyarray<T>::pyarray(const shape_type& shape, const strides_type& strides, const_reference value)
594+ template <class T , layout_type L >
595+ inline pyarray<T, L >::pyarray(const shape_type& shape, const strides_type& strides, const_reference value)
596596 : base_type()
597597 {
598598 init_array (shape, strides);
@@ -604,8 +604,8 @@ namespace xt
604604 * @param shape the shape of the pyarray
605605 * @param strides the strides of the pyarray
606606 */
607- template <class T >
608- inline pyarray<T>::pyarray(const shape_type& shape, const strides_type& strides)
607+ template <class T , layout_type L >
608+ inline pyarray<T, L >::pyarray(const shape_type& shape, const strides_type& strides)
609609 : base_type()
610610 {
611611 init_array (shape, strides);
@@ -619,8 +619,8 @@ namespace xt
619619 /* *
620620 * The copy constructor.
621621 */
622- template <class T >
623- inline pyarray<T>::pyarray(const self_type& rhs)
622+ template <class T , layout_type L >
623+ inline pyarray<T, L >::pyarray(const self_type& rhs)
624624 : base_type(), semantic_base(rhs)
625625 {
626626 auto tmp = pybind11::reinterpret_steal<pybind11::object>(
@@ -639,8 +639,8 @@ namespace xt
639639 /* *
640640 * The assignment operator.
641641 */
642- template <class T >
643- inline auto pyarray<T>::operator =(const self_type& rhs) -> self_type&
642+ template <class T , layout_type L >
643+ inline auto pyarray<T, L >::operator =(const self_type& rhs) -> self_type&
644644 {
645645 self_type tmp (rhs);
646646 *this = std::move (tmp);
@@ -656,9 +656,9 @@ namespace xt
656656 /* *
657657 * The extended copy constructor.
658658 */
659- template <class T >
659+ template <class T , layout_type L >
660660 template <class E >
661- inline pyarray<T>::pyarray(const xexpression<E>& e)
661+ inline pyarray<T, L >::pyarray(const xexpression<E>& e)
662662 : base_type()
663663 {
664664 // TODO: prevent intermediary shape allocation
@@ -672,28 +672,28 @@ namespace xt
672672 /* *
673673 * The extended assignment operator.
674674 */
675- template <class T >
675+ template <class T , layout_type L >
676676 template <class E >
677- inline auto pyarray<T>::operator =(const xexpression<E>& e) -> self_type&
677+ inline auto pyarray<T, L >::operator =(const xexpression<E>& e) -> self_type&
678678 {
679679 return semantic_base::operator =(e);
680680 }
681681 // @}
682682
683- template <class T >
684- inline auto pyarray<T>::ensure(pybind11::handle h) -> self_type
683+ template <class T , layout_type L >
684+ inline auto pyarray<T, L >::ensure(pybind11::handle h) -> self_type
685685 {
686686 return base_type::ensure (h);
687687 }
688688
689- template <class T >
690- inline bool pyarray<T>::check_(pybind11::handle h)
689+ template <class T , layout_type L >
690+ inline bool pyarray<T, L >::check_(pybind11::handle h)
691691 {
692692 return base_type::check_ (h);
693693 }
694694
695- template <class T >
696- inline void pyarray<T>::init_array(const shape_type& shape, const strides_type& strides)
695+ template <class T , layout_type L >
696+ inline void pyarray<T, L >::init_array(const shape_type& shape, const strides_type& strides)
697697 {
698698 strides_type adapted_strides (strides);
699699
@@ -722,8 +722,8 @@ namespace xt
722722 init_from_python ();
723723 }
724724
725- template <class T >
726- inline void pyarray<T>::init_from_python()
725+ template <class T , layout_type L >
726+ inline void pyarray<T, L >::init_from_python()
727727 {
728728 m_shape = inner_shape_type (reinterpret_cast <size_type*>(PyArray_SHAPE (this ->python_array ())),
729729 static_cast <size_type>(PyArray_NDIM (this ->python_array ())));
@@ -734,20 +734,20 @@ namespace xt
734734 this ->get_min_stride () * static_cast <size_type>(PyArray_SIZE (this ->python_array ())));
735735 }
736736
737- template <class T >
738- inline auto pyarray<T>::shape_impl() const noexcept -> const inner_shape_type&
737+ template <class T , layout_type L >
738+ inline auto pyarray<T, L >::shape_impl() const noexcept -> const inner_shape_type&
739739 {
740740 return m_shape;
741741 }
742742
743- template <class T >
744- inline auto pyarray<T>::strides_impl() const noexcept -> const inner_strides_type&
743+ template <class T , layout_type L >
744+ inline auto pyarray<T, L >::strides_impl() const noexcept -> const inner_strides_type&
745745 {
746746 return m_strides;
747747 }
748748
749- template <class T >
750- inline auto pyarray<T>::backstrides_impl() const noexcept -> const inner_backstrides_type&
749+ template <class T , layout_type L >
750+ inline auto pyarray<T, L >::backstrides_impl() const noexcept -> const inner_backstrides_type&
751751 {
752752 // m_backstrides wraps the numpy array backstrides, which is a raw pointer.
753753 // The address of the raw pointer stored in the wrapper would be invalidated when the pyarray is copied.
@@ -756,14 +756,14 @@ namespace xt
756756 return m_backstrides;
757757 }
758758
759- template <class T >
760- inline auto pyarray<T>::storage_impl() noexcept -> storage_type&
759+ template <class T , layout_type L >
760+ inline auto pyarray<T, L >::storage_impl() noexcept -> storage_type&
761761 {
762762 return m_storage;
763763 }
764764
765- template <class T >
766- inline auto pyarray<T>::storage_impl() const noexcept -> const storage_type&
765+ template <class T , layout_type L >
766+ inline auto pyarray<T, L >::storage_impl() const noexcept -> const storage_type&
767767 {
768768 return m_storage;
769769 }
0 commit comments