Skip to content

Commit 08e57d9

Browse files
authored
Merge pull request #9 from mortacious/feature/improved_installation
Improved setup by packaging the optix headers only at wheel creation
2 parents 0807ffb + 0982849 commit 08e57d9

File tree

6 files changed

+55
-9
lines changed

6 files changed

+55
-9
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ imgui.ini
1010
*.cpp
1111
*.c
1212
.*
13+
include/

optix/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.1.2"
1+
__version__ = "0.1.3"

optix/denoiser.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ cdef class Denoiser(OptixContextObject):
319319
denoise_alpha = DenoiserAlphaMode.COPY
320320

321321
assert isinstance(denoise_alpha, DenoiserAlphaMode), "Optix >7.5 changed this from a boolean variable into an enum"
322-
params.denoiseAlpha = denoise_alpha.value
322+
params.denoiseAlpha = <OptixDenoiserAlphaMode>denoise_alpha.value
323323
ELSE:
324324
params.denoiseAlpha = 1 if denoise_alpha else 0
325325

optix/module.pyx

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from enum import IntEnum, IntFlag
44
import os
5-
from .path_utility import get_cuda_include_path, get_optix_include_path
5+
import warnings
6+
from .path_utility import get_cuda_include_path, get_optix_include_path, get_local_optix_include_path
67
from .common cimport optix_check_return, optix_init
78
from .context cimport DeviceContext
89
from .pipeline cimport PipelineCompileOptions
@@ -408,9 +409,11 @@ cdef class Module(OptixContextObject):
408409
flags = list(compile_flags)
409410
# get cuda and optix_include_paths
410411
cuda_include_path = get_cuda_include_path()
411-
print("cuda path", cuda_include_path)
412-
optix_include_path = get_optix_include_path()
413-
412+
optix_include_path = get_local_optix_include_path()
413+
if not os.path.exists(optix_include_path):
414+
warnings.warn("Local optix not found. This usually indicates some installation issue. Attempting"
415+
" to load the global optix includes instead.", RuntimeWarning)
416+
optix_include_path = get_optix_include_path()
414417
flags.extend([f'-I{cuda_include_path}', f'-I{optix_include_path}'])
415418
ptx, _ = prog.compile(flags)
416419
return ptx

optix/path_utility.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import os
2727
from itertools import chain
28+
import pathlib
2829

2930
_cuda_path_cache = 'NOT_INITIALIZED'
3031
_optix_path_cache = 'NOT_INITIALIZED'
@@ -98,7 +99,7 @@ def get_optix_path(path_hint=None, environment_variable=None):
9899
None else None)
99100
if optix_header_path is None:
100101
# search on the default path
101-
optix_header_path = search_on_path(('../optix/include/optix.h',), keys=('PATH',))
102+
optix_header_path = search_on_path(('../optix/include/optix.h',), keys=('PATH', 'OPTIX_PATH'))
102103

103104
if optix_header_path is not None:
104105
optix_header_path = os.path.normpath(os.path.join(os.path.dirname(optix_header_path), '..'))
@@ -115,6 +116,10 @@ def get_optix_path(path_hint=None, environment_variable=None):
115116
return _optix_path_cache
116117

117118

119+
def get_local_optix_include_path():
120+
local_include_path = pathlib.Path(__file__).parent / "include"
121+
return str(local_include_path) if local_include_path.exists() else None
122+
118123
def get_optix_include_path(environment_variable=None):
119124
optix_path = get_optix_path(environment_variable=environment_variable)
120125
if optix_path is None:
@@ -124,3 +129,4 @@ def get_optix_include_path(environment_variable=None):
124129
return optix_include_path
125130
else:
126131
return None
132+

setup.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from setuptools import setup, Extension, find_packages
1+
from struct import pack
2+
from setuptools import setup, Extension, find_packages, find_namespace_packages
23
from Cython.Build import cythonize
34
import re
45
import os
56
from pathlib import Path
7+
import shutil
68

79

810
# standalone import of a module (https://stackoverflow.com/a/58423785)
@@ -29,6 +31,8 @@ def import_module_from_path(path):
2931
util = import_module_from_path('optix/path_utility.py')
3032
cuda_include_path = util.get_cuda_include_path()
3133
optix_include_path = util.get_optix_include_path()
34+
print("Found cuda includes at", cuda_include_path)
35+
print("Found optix includes at", optix_include_path)
3236
if cuda_include_path is None or optix_include_path is None:
3337
raise RuntimeError("Cuda or optix not found in the system")
3438

@@ -65,6 +69,35 @@ def import_module_from_path(path):
6569

6670
version = import_module_from_path('optix/_version.py').__version__
6771

72+
package_data = {}
73+
74+
try:
75+
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
76+
77+
def glob_fix(package_name, glob):
78+
# this assumes setup.py lives in the folder that contains the package
79+
package_path = Path(f'./{package_name}').resolve()
80+
return [str(path.relative_to(package_path))
81+
for path in package_path.glob(glob)]
82+
83+
class custom_bdist_wheel(_bdist_wheel):
84+
def finalize_options(self):
85+
_bdist_wheel.finalize_options(self)
86+
87+
# create the path for the internal headers
88+
# due to optix license restrictions those headers
89+
# cannot be distributed on pypi directly so we will add this headers dynamically
90+
# upon wheel construction to install them alongside the package
91+
92+
if not os.path.exists('optix/include/optix.h'):
93+
shutil.copytree(optix_include_path, 'optix/include')
94+
self.distribution.package_data.update({
95+
'optix': [*glob_fix('optix', 'include/**/*')]
96+
})
97+
98+
except ImportError:
99+
custom_bdist_wheel = None
100+
68101
setup(
69102
name="python-optix",
70103
version=version,
@@ -87,6 +120,7 @@ def import_module_from_path(path):
87120
classifiers=[
88121
"Programming Language :: Python :: 3.8",
89122
"Programming Language :: Python :: 3.9",
123+
"Programming Language :: Python :: 3.10",
90124
"License :: OSI Approved :: MIT License",
91125
"Operating System :: POSIX :: Linux",
92126
"Operating System :: Microsoft :: Windows",
@@ -101,5 +135,7 @@ def import_module_from_path(path):
101135
'examples': ["pillow", "pyopengl", "pyglfw", "pyimgui"]
102136
},
103137
python_requires=">=3.8",
104-
zip_safe=False
138+
package_data=package_data,
139+
zip_safe=False,
140+
cmdclass={'bdist_wheel': custom_bdist_wheel}
105141
)

0 commit comments

Comments
 (0)