@@ -744,14 +744,18 @@ static struct PyUSMArrayObject *PyUSMNdArray_ARRAYOBJ(PyObject *obj)
744744 DPEXRT_DEBUG (
745745 drt_debug_print ("DPEXRT-DEBUG: usm array was passed directly\n" ));
746746 arrayobj = obj ;
747+ Py_INCREF (arrayobj );
747748 }
748749 else if (PyObject_HasAttrString (obj , "_array_obj" )) {
750+ // PyObject_GetAttrString gives reference
749751 arrayobj = PyObject_GetAttrString (obj , "_array_obj" );
750752
751753 if (!arrayobj )
752754 return NULL ;
753- if (!PyObject_TypeCheck (arrayobj , & PyUSMArrayType ))
755+ if (!PyObject_TypeCheck (arrayobj , & PyUSMArrayType )) {
756+ Py_DECREF (arrayobj );
754757 return NULL ;
758+ }
755759 }
756760
757761 struct PyUSMArrayObject * pyusmarrayobj =
@@ -803,17 +807,13 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
803807 PyGILState_STATE gstate ;
804808 npy_intp itemsize = 0 ;
805809
806- // Increment the ref count on obj to prevent CPython from garbage
807- // collecting the array.
808- // TODO: add extra description why do we need this
809- Py_IncRef (obj );
810-
811810 DPEXRT_DEBUG (drt_debug_print (
812811 "DPEXRT-DEBUG: In DPEXRT_sycl_usm_ndarray_from_python at %s, line %d\n" ,
813812 __FILE__ , __LINE__ ));
814813
815814 // Check if the PyObject obj has an _array_obj attribute that is of
816815 // dpctl.tensor.usm_ndarray type.
816+ // arrayobj is a new reference, reference of obj is borrowed
817817 if (!(arrayobj = PyUSMNdArray_ARRAYOBJ (obj ))) {
818818 DPEXRT_DEBUG (drt_debug_print (
819819 "DPEXRT-ERROR: PyUSMNdArray_ARRAYOBJ check failed at %s, line %d\n" ,
@@ -832,6 +832,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
832832 data = (void * )UsmNDArray_GetData (arrayobj );
833833 nitems = product_of_shape (shape , ndim );
834834 itemsize = (npy_intp )UsmNDArray_GetElementSize (arrayobj );
835+
835836 if (!(qref = UsmNDArray_GetQueueRef (arrayobj ))) {
836837 DPEXRT_DEBUG (drt_debug_print (
837838 "DPEXRT-ERROR: UsmNDArray_GetQueueRef returned NULL at "
@@ -841,7 +842,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
841842 }
842843
843844 if (!(arystruct -> meminfo = NRT_MemInfo_new_from_usmndarray (
844- obj , data , nitems , itemsize , qref )))
845+ arrayobj , data , nitems , itemsize , qref )))
845846 {
846847 DPEXRT_DEBUG (drt_debug_print (
847848 "DPEXRT-ERROR: NRT_MemInfo_new_from_usmndarray failed "
@@ -854,7 +855,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
854855 arystruct -> sycl_queue = qref ;
855856 arystruct -> nitems = nitems ;
856857 arystruct -> itemsize = itemsize ;
857- arystruct -> parent = obj ;
858+ arystruct -> parent = arrayobj ;
858859
859860 p = arystruct -> shape_and_strides ;
860861
@@ -906,7 +907,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
906907 __FILE__ , __LINE__ ));
907908 gstate = PyGILState_Ensure ();
908909 // decref the python object
909- Py_DECREF ( obj );
910+ Py_XDECREF ( arrayobj );
910911 // release the GIL
911912 PyGILState_Release (gstate );
912913
0 commit comments