@@ -829,7 +829,7 @@ class usm_ndarray : public py::object
829829
830830 char * get_data () const
831831 {
832- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
832+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
833833
834834 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
835835 return api .UsmNDArray_GetData_ (raw_ar );
@@ -842,20 +842,29 @@ class usm_ndarray : public py::object
842842
843843 int get_ndim () const
844844 {
845- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
845+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
846846
847847 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
848848 return api .UsmNDArray_GetNDim_ (raw_ar );
849849 }
850850
851851 const py ::ssize_t * get_shape_raw () const
852852 {
853- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
853+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
854854
855855 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
856856 return api .UsmNDArray_GetShape_ (raw_ar );
857857 }
858858
859+ std ::vector < py ::ssize_t > get_shape_vector () const
860+ {
861+ auto raw_sh = get_shape_raw ();
862+ auto nd = get_ndim ();
863+
864+ std ::vector < py ::ssize_t > shape_vector (raw_sh , raw_sh + nd );
865+ return shape_vector ;
866+ }
867+
859868 py ::ssize_t get_shape (int i ) const
860869 {
861870 auto shape_ptr = get_shape_raw ();
@@ -864,15 +873,43 @@ class usm_ndarray : public py::object
864873
865874 const py ::ssize_t * get_strides_raw () const
866875 {
867- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
876+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
868877
869878 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
870879 return api .UsmNDArray_GetStrides_ (raw_ar );
871880 }
872881
882+ std ::vector < py ::ssize_t > get_strides_vector () const
883+ {
884+ auto raw_st = get_strides_raw ();
885+ auto nd = get_ndim ();
886+
887+ if (raw_st == nullptr ) {
888+ auto is_c_contig = is_c_contiguous ();
889+ auto is_f_contig = is_f_contiguous ();
890+ auto raw_sh = get_shape_raw ();
891+ if (is_c_contig ) {
892+ const auto & contig_strides = c_contiguous_strides (nd , raw_sh );
893+ return contig_strides ;
894+ }
895+ else if (is_f_contig ) {
896+ const auto & contig_strides = f_contiguous_strides (nd , raw_sh );
897+ return contig_strides ;
898+ }
899+ else {
900+ throw std ::runtime_error ("Invalid array encountered when "
901+ "building strides" );
902+ }
903+ }
904+ else {
905+ std ::vector < py ::ssize_t > st_vec (raw_st , raw_st + nd );
906+ return st_vec ;
907+ }
908+ }
909+
873910 py ::ssize_t get_size () const
874911 {
875- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
912+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
876913
877914 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
878915 int ndim = api .UsmNDArray_GetNDim_ (raw_ar );
@@ -889,7 +926,7 @@ class usm_ndarray : public py::object
889926
890927 std ::pair < py ::ssize_t , py ::ssize_t > get_minmax_offsets () const
891928 {
892- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
929+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
893930
894931 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
895932 int nd = api .UsmNDArray_GetNDim_ (raw_ar );
@@ -906,8 +943,6 @@ class usm_ndarray : public py::object
906943 }
907944 }
908945 else {
909- offset_min = api .UsmNDArray_GetOffset_ (raw_ar );
910- offset_max = offset_min ;
911946 for (int i = 0 ; i < nd ; ++ i ) {
912947 py ::ssize_t delta = strides [i ] * (shape [i ] - 1 );
913948 if (strides [i ] > 0 ) {
@@ -923,7 +958,7 @@ class usm_ndarray : public py::object
923958
924959 sycl ::queue get_queue () const
925960 {
926- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
961+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
927962
928963 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
929964 DPCTLSyclQueueRef QRef = api .UsmNDArray_GetQueueRef_ (raw_ar );
@@ -932,45 +967,45 @@ class usm_ndarray : public py::object
932967
933968 int get_typenum () const
934969 {
935- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
970+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
936971
937972 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
938973 return api .UsmNDArray_GetTypenum_ (raw_ar );
939974 }
940975
941976 int get_flags () const
942977 {
943- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
978+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
944979
945980 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
946981 return api .UsmNDArray_GetFlags_ (raw_ar );
947982 }
948983
949984 int get_elemsize () const
950985 {
951- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
986+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
952987
953988 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
954989 return api .UsmNDArray_GetElementSize_ (raw_ar );
955990 }
956991
957992 bool is_c_contiguous () const
958993 {
959- int flags = this -> get_flags ();
994+ int flags = get_flags ();
960995 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
961996 return static_cast < bool > (flags & api .USM_ARRAY_C_CONTIGUOUS_ );
962997 }
963998
964999 bool is_f_contiguous () const
9651000 {
966- int flags = this -> get_flags ();
1001+ int flags = get_flags ();
9671002 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
9681003 return static_cast < bool > (flags & api .USM_ARRAY_F_CONTIGUOUS_ );
9691004 }
9701005
9711006 bool is_writable () const
9721007 {
973- int flags = this -> get_flags ();
1008+ int flags = get_flags ();
9741009 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
9751010 return static_cast < bool > (flags & api .USM_ARRAY_WRITABLE_ );
9761011 }
0 commit comments