@@ -856,6 +856,15 @@ class usm_ndarray : public py::object
856856 return api .UsmNDArray_GetShape_ (raw_ar );
857857 }
858858
859+ const std ::vector < py ::ssize_t > get_shape_vector () const
860+ {
861+ auto raw_sh = this -> get_shape_raw ();
862+ auto nd = this -> get_ndim ();
863+
864+ std ::vector < py ::ssize_t > shape_vector (raw_sh , raw_sh + nd );
865+ return shape_vector ;
866+ }
867+
859868 py ::ssize_t get_shape (int i ) const
860869 {
861870 auto shape_ptr = get_shape_raw ();
@@ -870,6 +879,34 @@ class usm_ndarray : public py::object
870879 return api .UsmNDArray_GetStrides_ (raw_ar );
871880 }
872881
882+ const std ::vector < py ::ssize_t > get_strides_vector () const
883+ {
884+ auto raw_st = this -> get_strides_raw ();
885+ auto nd = this -> get_ndim ();
886+
887+ if (raw_st == nullptr) {
888+ auto is_c_contig = this -> is_c_contiguous ();
889+ auto is_f_contig = this -> is_f_contiguous ();
890+ auto raw_sh = this -> get_shape_raw ();
891+ if (is_c_contig ) {
892+ const auto & contig_strides = c_contiguous_strides (nd , raw_sh );
893+ return contig_strides ;
894+ }
895+ else if (is_f_contig ) {
896+ const auto & contig_strides = f_contiguous_strides (nd , raw_sh );
897+ return contig_strides ;
898+ }
899+ else {
900+ throw std ::runtime_error ("Invalid array encountered when "
901+ "building strides" );
902+ }
903+ }
904+ else {
905+ std ::vector < py ::ssize_t > st_vec (raw_st , raw_st + nd );
906+ return st_vec ;
907+ }
908+ }
909+
873910 py ::ssize_t get_size () const
874911 {
875912 PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
0 commit comments