@@ -30,16 +30,16 @@ cimport cython.array
3030from dpctl._backend cimport ( # noqa: E211, E402
3131 DPCTLCString_Delete,
3232 DPCTLKernel_Delete,
33- DPCTLKernel_GetFunctionName,
3433 DPCTLKernel_GetNumArgs,
35- DPCTLProgram_CreateFromOCLSource ,
36- DPCTLProgram_CreateFromSpirv ,
37- DPCTLProgram_Delete ,
38- DPCTLProgram_GetKernel ,
39- DPCTLProgram_HasKernel ,
34+ DPCTLKernelBundle_CreateFromOCLSource ,
35+ DPCTLKernelBundle_CreateFromSpirv ,
36+ DPCTLKernelBundle_Delete ,
37+ DPCTLKernelBundle_GetKernel ,
38+ DPCTLKernelBundle_HasKernel ,
4039 DPCTLSyclContextRef,
40+ DPCTLSyclDeviceRef,
41+ DPCTLSyclKernelBundleRef,
4142 DPCTLSyclKernelRef,
42- DPCTLSyclProgramRef,
4343)
4444
4545__all__ = [
@@ -51,8 +51,8 @@ __all__ = [
5151]
5252
5353cdef class SyclProgramCompilationError(Exception ):
54- """ This exception is raised when a ``sycl::program `` could not be built from
55- either a SPIR-V binary file or a string source.
54+ """ This exception is raised when a ``sycl::kernel_bundle `` could not be
55+ built from either a SPIR-V binary file or a string source.
5656 """
5757 pass
5858
@@ -61,20 +61,19 @@ cdef class SyclKernel:
6161 """
6262 """
6363 @staticmethod
64- cdef SyclKernel _create(DPCTLSyclKernelRef kref):
64+ cdef SyclKernel _create(DPCTLSyclKernelRef kref, str name ):
6565 cdef SyclKernel ret = SyclKernel.__new__ (SyclKernel)
6666 ret._kernel_ref = kref
67- ret._function_name = DPCTLKernel_GetFunctionName(kref)
67+ ret._function_name = name
6868 return ret
6969
7070 def __dealloc__ (self ):
7171 DPCTLKernel_Delete(self ._kernel_ref)
72- DPCTLCString_Delete(self ._function_name)
7372
7473 def get_function_name (self ):
7574 """ Returns the name of the ``sycl::kernel`` function.
7675 """
77- return self ._function_name.decode()
76+ return self ._function_name
7877
7978 def get_num_args (self ):
8079 """ Returns the number of arguments for this kernel function.
@@ -98,42 +97,45 @@ cdef class SyclKernel:
9897
9998
10099cdef class SyclProgram:
101- """ Wraps a ``sycl::program`` object created from an OpenCL interoperability
102- program.
100+ """ Wraps a ``sycl::kernel_bundle<sycl::bundle_state::executable>`` object
101+ created using SYCL interoperability layer with underlying backends. Only the
102+ OpenCL and Level-Zero backends are currently supported.
103103
104- SyclProgram exposes the C API from ``dpctl_sycl_program_interface .h``. A
105- SyclProgram can be created from either a source string or a SPIR-V
106- binary file.
104+ SyclProgram exposes the C API from ``dpctl_sycl_kernel_bundle_interface .h``.
105+ A SyclProgram can be created from either a source string or a SPIR-V
106+ binary file.
107107 """
108108
109109 @staticmethod
110- cdef SyclProgram _create(DPCTLSyclProgramRef pref ):
110+ cdef SyclProgram _create(DPCTLSyclKernelBundleRef KBRef ):
111111 cdef SyclProgram ret = SyclProgram.__new__ (SyclProgram)
112- ret._program_ref = pref
112+ ret._program_ref = KBRef
113113 return ret
114114
115115 def __dealloc__ (self ):
116- DPCTLProgram_Delete (self ._program_ref)
116+ DPCTLKernelBundle_Delete (self ._program_ref)
117117
118- cdef DPCTLSyclProgramRef get_program_ref(self ):
118+ cdef DPCTLSyclKernelBundleRef get_program_ref(self ):
119119 return self ._program_ref
120120
121121 cpdef SyclKernel get_sycl_kernel(self , str kernel_name):
122122 name = kernel_name.encode(' utf8' )
123- return SyclKernel._create(DPCTLProgram_GetKernel(self ._program_ref,
124- name))
123+ return SyclKernel._create(
124+ DPCTLKernelBundle_GetKernel(self ._program_ref, name),
125+ kernel_name
126+ )
125127
126128 def has_sycl_kernel (self , str kernel_name ):
127129 name = kernel_name.encode(' utf8' )
128- return DPCTLProgram_HasKernel (self ._program_ref, name)
130+ return DPCTLKernelBundle_HasKernel (self ._program_ref, name)
129131
130132 def addressof_ref (self ):
131- """ Returns the address of the C API DPCTLSyclProgramRef pointer
133+ """ Returns the address of the C API DPCTLSyclKernelBundleRef pointer
132134 as a long.
133135
134136 Returns:
135- The address of the ``DPCTLSyclProgramRef `` pointer used to create
136- this :class:`dpctl.SyclProgram` object cast to a ``size_t``.
137+ The address of the ``DPCTLSyclKernelBundleRef `` pointer used to
138+ create this :class:`dpctl.SyclProgram` object cast to a ``size_t``.
137139 """
138140 return int (< size_t> self ._program_ref)
139141
@@ -142,9 +144,10 @@ cpdef create_program_from_source(SyclQueue q, unicode src, unicode copts=""):
142144 """
143145 Creates a Sycl interoperability program from an OpenCL source string.
144146
145- We use the ``DPCTLProgram_CreateFromOCLSource()`` C API function to
146- create a ``sycl::program`` from an OpenCL source program that can
147- contain multiple kernels. Note currently only supported for OpenCL.
147+ We use the ``DPCTLKernelBundle_CreateFromOCLSource()`` C API function
148+ to create a ``sycl::kernel_bundle<sycl::bundle_state::executable>``
149+ from an OpenCL source program that can contain multiple kernels.
150+ Note: This function is currently only supported for the OpenCL backend.
148151
149152 Parameters:
150153 q (SyclQueue) : The :class:`SyclQueue` for which the
@@ -155,33 +158,37 @@ cpdef create_program_from_source(SyclQueue q, unicode src, unicode copts=""):
155158
156159 Returns:
157160 program (SyclProgram): A :class:`SyclProgram` object wrapping the
158- ``sycl::program`` returned by the C API.
161+ ``sycl::kernel_bundle<sycl::bundle_state::executable>`` returned
162+ by the C API.
159163
160164 Raises:
161- SyclProgramCompilationError: If a SYCL program could not be created.
165+ SyclProgramCompilationError: If a SYCL kernel bundle could not be
166+ created.
162167 """
163168
164- cdef DPCTLSyclProgramRef Pref
169+ cdef DPCTLSyclKernelBundleRef KBref
165170 cdef bytes bSrc = src.encode(' utf8' )
166171 cdef bytes bCOpts = copts.encode(' utf8' )
167172 cdef const char * Src = < const char * > bSrc
168173 cdef const char * COpts = < const char * > bCOpts
169174 cdef DPCTLSyclContextRef CRef = q.get_sycl_context().get_context_ref()
170- Pref = DPCTLProgram_CreateFromOCLSource(CRef, Src, COpts)
175+ cdef DPCTLSyclDeviceRef DRef = q.get_sycl_device().get_device_ref()
176+ KBref = DPCTLKernelBundle_CreateFromOCLSource(CRef, DRef, Src, COpts)
171177
172- if Pref is NULL :
178+ if KBref is NULL :
173179 raise SyclProgramCompilationError()
174180
175- return SyclProgram._create(Pref )
181+ return SyclProgram._create(KBref )
176182
177183
178184cpdef create_program_from_spirv(SyclQueue q, const unsigned char [:] IL,
179185 unicode copts = " " ):
180186 """
181187 Creates a Sycl interoperability program from an SPIR-V binary.
182188
183- We use the ``DPCTLProgram_CreateFromOCLSpirv()`` C API function to
184- create a ``sycl::program`` object from an compiled SPIR-V binary file.
189+ We use the ``DPCTLKernelBundle_CreateFromOCLSpirv()`` C API function to
190+ create a ``sycl::kernel_bundle<sycl::bundle_state::executable>`` object
191+ from an compiled SPIR-V binary file.
185192
186193 Parameters:
187194 q (SyclQueue): The :class:`SyclQueue` for which the
@@ -192,20 +199,25 @@ cpdef create_program_from_spirv(SyclQueue q, const unsigned char[:] IL,
192199
193200 Returns:
194201 program (SyclProgram): A :class:`SyclProgram` object wrapping the
195- ``sycl::program`` returned by the C API.
202+ ``sycl::kernel_bundle<sycl::bundle_state::executable>`` returned by
203+ the C API.
196204
197205 Raises:
198- SyclProgramCompilationError: If a SYCL program could not be created.
206+ SyclProgramCompilationError: If a SYCL kernel bundle could not be
207+ created.
199208 """
200209
201- cdef DPCTLSyclProgramRef Pref
210+ cdef DPCTLSyclKernelBundleRef KBref
202211 cdef const unsigned char * dIL = & IL[0 ]
203212 cdef DPCTLSyclContextRef CRef = q.get_sycl_context().get_context_ref()
213+ cdef DPCTLSyclDeviceRef DRef = q.get_sycl_device().get_device_ref()
204214 cdef size_t length = IL.shape[0 ]
205215 cdef bytes bCOpts = copts.encode(' utf8' )
206216 cdef const char * COpts = < const char * > bCOpts
207- Pref = DPCTLProgram_CreateFromSpirv(CRef, < const void * > dIL, length, COpts)
208- if Pref is NULL :
217+ KBref = DPCTLKernelBundle_CreateFromSpirv(
218+ CRef, DRef, < const void * > dIL, length, COpts
219+ )
220+ if KBref is NULL :
209221 raise SyclProgramCompilationError()
210222
211- return SyclProgram._create(Pref )
223+ return SyclProgram._create(KBref )
0 commit comments