Skip to content

Commit c0da58b

Browse files
Memory class stores SyclQueue inside to enable copying to host
1 parent 26b2a8d commit c0da58b

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

dppl/_memory.pyx

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@ from cpython cimport Py_buffer
88
cdef class Memory:
99
cdef DPPLSyclUSMRef memory_ptr
1010
cdef Py_ssize_t nbytes
11-
cdef SyclContext context
11+
cdef SyclQueue queue
1212

1313
cdef _cinit(self, Py_ssize_t nbytes, ptr_type, SyclQueue queue):
1414
cdef DPPLSyclUSMRef p
1515

1616
self.memory_ptr = NULL
1717
self.nbytes = 0
18-
self.context = None
18+
self.queue = None
1919

2020
if (nbytes > 0):
2121
if queue is None:
@@ -34,19 +34,19 @@ cdef class Memory:
3434
if (p):
3535
self.memory_ptr = p
3636
self.nbytes = nbytes
37-
self.context = queue.get_sycl_context()
37+
self.queue = queue
3838
else:
3939
raise RuntimeError("Null memory pointer returned")
4040
else:
4141
raise ValueError("Non-positive number of bytes found.")
4242

4343
def __dealloc__(self):
4444
if (self.memory_ptr):
45-
DPPLfree_with_context(self.memory_ptr,
46-
self.context.get_context_ref())
45+
DPPLfree_with_queue(self.memory_ptr,
46+
self.queue.get_queue_ref())
4747
self.memory_ptr = NULL
4848
self.nbytes = 0
49-
self.context = None
49+
self.queue = None
5050

5151
cdef _getbuffer(self, Py_buffer *buffer, int flags):
5252
# memory_ptr is Ref which is pointer to SYCL type. For USM it is void*.
@@ -68,21 +68,31 @@ cdef class Memory:
6868

6969
property _context:
7070
def __get__(self):
71-
return self.context
71+
return self.queue.get_sycl_context()
72+
73+
property _queue:
74+
def __get__(self):
75+
return self.queue
7276

7377
def __repr__(self):
7478
return "<Intel(R) USM allocated memory block of {} bytes at {}>" \
7579
.format(self.nbytes, hex(<object>(<Py_ssize_t>self.memory_ptr)))
7680

77-
def _usm_type(self, sycl_context=None):
81+
def _usm_type(self, context=None):
7882
cdef const char* kind
7983
cdef SyclContext ctx
80-
if sycl_context is None:
81-
ctx = self.context
84+
cdef SyclQueue q
85+
if context is None:
86+
ctx = self._context
87+
kind = DPPLUSM_GetPointerType(self.memory_ptr,
88+
ctx.get_context_ref())
89+
elif isinstance(context, SyclContext):
90+
ctx = <SyclContext>(context)
8291
kind = DPPLUSM_GetPointerType(self.memory_ptr,
8392
ctx.get_context_ref())
84-
elif isinstance(sycl_context, SyclContext):
85-
ctx = <SyclContext>(sycl_context)
93+
elif isinstance(context, SyclQueue):
94+
q = <SyclQueue>(context)
95+
ctx = q.get_sycl_context()
8696
kind = DPPLUSM_GetPointerType(self.memory_ptr,
8797
ctx.get_context_ref())
8898
else:

0 commit comments

Comments
 (0)