@@ -829,7 +829,7 @@ class usm_ndarray : public py::object
829829
830830 char * get_data () const
831831 {
832- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
832+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
833833
834834 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
835835 return api .UsmNDArray_GetData_ (raw_ar );
@@ -842,20 +842,29 @@ class usm_ndarray : public py::object
842842
843843 int get_ndim () const
844844 {
845- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
845+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
846846
847847 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
848848 return api .UsmNDArray_GetNDim_ (raw_ar );
849849 }
850850
851851 const py ::ssize_t * get_shape_raw () const
852852 {
853- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
853+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
854854
855855 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
856856 return api .UsmNDArray_GetShape_ (raw_ar );
857857 }
858858
859+ std ::vector < py ::ssize_t > get_shape_vector () const
860+ {
861+ auto raw_sh = get_shape_raw ();
862+ auto nd = get_ndim ();
863+
864+ std ::vector < py ::ssize_t > shape_vector (raw_sh , raw_sh + nd );
865+ return shape_vector ;
866+ }
867+
859868 py ::ssize_t get_shape (int i ) const
860869 {
861870 auto shape_ptr = get_shape_raw ();
@@ -864,15 +873,43 @@ class usm_ndarray : public py::object
864873
865874 const py ::ssize_t * get_strides_raw () const
866875 {
867- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
876+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
868877
869878 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
870879 return api .UsmNDArray_GetStrides_ (raw_ar );
871880 }
872881
882+ std ::vector < py ::ssize_t > get_strides_vector () const
883+ {
884+ auto raw_st = get_strides_raw ();
885+ auto nd = get_ndim ();
886+
887+ if (raw_st == nullptr ) {
888+ auto is_c_contig = is_c_contiguous ();
889+ auto is_f_contig = is_f_contiguous ();
890+ auto raw_sh = get_shape_raw ();
891+ if (is_c_contig ) {
892+ const auto & contig_strides = c_contiguous_strides (nd , raw_sh );
893+ return contig_strides ;
894+ }
895+ else if (is_f_contig ) {
896+ const auto & contig_strides = f_contiguous_strides (nd , raw_sh );
897+ return contig_strides ;
898+ }
899+ else {
900+ throw std ::runtime_error ("Invalid array encountered when "
901+ "building strides" );
902+ }
903+ }
904+ else {
905+ std ::vector < py ::ssize_t > st_vec (raw_st , raw_st + nd );
906+ return st_vec ;
907+ }
908+ }
909+
873910 py ::ssize_t get_size () const
874911 {
875- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
912+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
876913
877914 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
878915 int ndim = api .UsmNDArray_GetNDim_ (raw_ar );
@@ -889,7 +926,7 @@ class usm_ndarray : public py::object
889926
890927 std ::pair < py ::ssize_t , py ::ssize_t > get_minmax_offsets () const
891928 {
892- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
929+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
893930
894931 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
895932 int nd = api .UsmNDArray_GetNDim_ (raw_ar );
@@ -923,7 +960,7 @@ class usm_ndarray : public py::object
923960
924961 sycl ::queue get_queue () const
925962 {
926- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
963+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
927964
928965 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
929966 DPCTLSyclQueueRef QRef = api .UsmNDArray_GetQueueRef_ (raw_ar );
@@ -932,45 +969,45 @@ class usm_ndarray : public py::object
932969
933970 int get_typenum () const
934971 {
935- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
972+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
936973
937974 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
938975 return api .UsmNDArray_GetTypenum_ (raw_ar );
939976 }
940977
941978 int get_flags () const
942979 {
943- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
980+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
944981
945982 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
946983 return api .UsmNDArray_GetFlags_ (raw_ar );
947984 }
948985
949986 int get_elemsize () const
950987 {
951- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
988+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
952989
953990 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
954991 return api .UsmNDArray_GetElementSize_ (raw_ar );
955992 }
956993
957994 bool is_c_contiguous () const
958995 {
959- int flags = this -> get_flags ();
996+ int flags = get_flags ();
960997 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
961998 return static_cast < bool > (flags & api .USM_ARRAY_C_CONTIGUOUS_ );
962999 }
9631000
9641001 bool is_f_contiguous () const
9651002 {
966- int flags = this -> get_flags ();
1003+ int flags = get_flags ();
9671004 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
9681005 return static_cast < bool > (flags & api .USM_ARRAY_F_CONTIGUOUS_ );
9691006 }
9701007
9711008 bool is_writable () const
9721009 {
973- int flags = this -> get_flags ();
1010+ int flags = get_flags ();
9741011 auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
9751012 return static_cast < bool > (flags & api .USM_ARRAY_WRITABLE_ );
9761013 }
0 commit comments