Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pufferlib/ocean/environment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import importlib
import pufferlib.emulation
import platform
import os

is_windows = platform.system() == "Windows"
if is_windows:
raylib_dll_path = os.path.abspath(r'raylib-5.5_win64_msvc16\lib')
os.add_dll_directory(raylib_dll_path)

def lazy_import(module_path, attr):
"""
Expand Down
7 changes: 6 additions & 1 deletion pufferlib/pufferl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import warnings
warnings.filterwarnings('error', category=RuntimeWarning)

import platform
is_windows = platform.system() == "Windows"
import os
import io
import sys
Expand Down Expand Up @@ -616,7 +618,10 @@ def run(self):
while not self.stopped:
self.cpu_util.append(100*psutil.cpu_percent()/psutil.cpu_count())
mem = psutil.virtual_memory()
self.cpu_mem.append(100*mem.active/mem.total)
if is_windows:
self.cpu_mem.append(100*mem.used/mem.total)
else:
self.cpu_mem.append(100*mem.active/mem.total)
if torch.cuda.is_available():
# Monitoring in distributed crashes nvml
if torch.distributed.is_initialized():
Expand Down
43 changes: 26 additions & 17 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@

# Build raylib for your platform
RAYLIB_URL = 'https://github.com/raysan5/raylib/releases/download/5.5/'
RAYLIB_NAME = 'raylib-5.5_macos' if platform.system() == "Darwin" else 'raylib-5.5_linux_amd64'
system = platform.system()
if system == 'Linux':
RAYLIB_NAME = 'raylib-5.5_linux_amd64'
elif system == 'Darwin':
RAYLIB_NAME = 'raylib-5.5_macos'
elif system == 'Windows':
RAYLIB_NAME = 'raylib-5.5_win64_msvc16'

RLIGHTS_URL = 'https://raw.githubusercontent.com/raysan5/raylib/refs/heads/master/examples/shaders/rlights.h'

def download_raylib(platform, ext):
Expand All @@ -53,7 +60,7 @@ def download_raylib(platform, ext):

if not NO_OCEAN:
download_raylib('raylib-5.5_webassembly', '.zip')
download_raylib(RAYLIB_NAME, '.tar.gz')
download_raylib(RAYLIB_NAME, '.tar.gz' if platform.system() != "Windows" else '.zip')

BOX2D_URL = 'https://github.com/capnspacehook/box2d/releases/latest/download/'
BOX2D_NAME = 'box2d-macos-arm64' if platform.system() == "Darwin" else 'box2d-linux-amd64'
Expand Down Expand Up @@ -124,7 +131,6 @@ def download_box2d(platform):
'-O3',
]

system = platform.system()
if system == 'Linux':
extra_compile_args += [
'-Wno-alloc-size-larger-than',
Expand All @@ -145,6 +151,8 @@ def download_box2d(platform):
'-framework', 'OpenGL',
'-framework', 'IOKit',
]
elif system == 'Windows':
pass
else:
raise ValueError(f'Unsupported system: {system}')

Expand Down Expand Up @@ -183,8 +191,8 @@ def run(self):
self.extensions = [e for e in self.extensions if e.name in extnames]
super().run()

RAYLIB_A = f'{RAYLIB_NAME}/lib/raylibdll.lib' if system == "Windows" else f'{RAYLIB_NAME}/lib/libraylib.a'
INCLUDE = [f'{BOX2D_NAME}/include', f'{BOX2D_NAME}/src']
RAYLIB_A = f'{RAYLIB_NAME}/lib/libraylib.a'
extension_kwargs = dict(
include_dirs=INCLUDE,
extra_compile_args=extra_compile_args,
Expand All @@ -198,7 +206,7 @@ def run(self):
c_extension_paths = glob.glob('pufferlib/ocean/**/binding.c', recursive=True)
c_extensions = [
Extension(
path.rstrip('.c').replace('/', '.'),
path.rstrip('.c').replace('/', '.').replace('\\', '.'),
sources=[path],
**extension_kwargs,
)
Expand Down Expand Up @@ -271,18 +279,19 @@ def run(self):
),
]

# Prevent Conda from injecting garbage compile flags
from distutils.sysconfig import get_config_vars
cfg_vars = get_config_vars()
for key in ('CC', 'CXX', 'LDSHARED'):
if cfg_vars[key]:
cfg_vars[key] = cfg_vars[key].replace('-B /root/anaconda3/compiler_compat', '')
cfg_vars[key] = cfg_vars[key].replace('-pthread', '')
cfg_vars[key] = cfg_vars[key].replace('-fno-strict-overflow', '')
if system != 'Windows':
# Prevent Conda from injecting garbage compile flags
from distutils.sysconfig import get_config_vars
cfg_vars = get_config_vars()
for key in ('CC', 'CXX', 'LDSHARED'):
if cfg_vars[key]:
cfg_vars[key] = cfg_vars[key].replace('-B /root/anaconda3/compiler_compat', '')
cfg_vars[key] = cfg_vars[key].replace('-pthread', '')
cfg_vars[key] = cfg_vars[key].replace('-fno-strict-overflow', '')

for key, value in cfg_vars.items():
if value and '-fno-strict-overflow' in str(value):
cfg_vars[key] = value.replace('-fno-strict-overflow', '')
for key, value in cfg_vars.items():
if value and '-fno-strict-overflow' in str(value):
cfg_vars[key] = value.replace('-fno-strict-overflow', '')

install_requires = [
'setuptools',
Expand All @@ -295,7 +304,7 @@ def run(self):

if not NO_TRAIN:
install_requires += [
'torch>=2.9',
'torch>=2.6',
'psutil',
'nvidia-ml-py',
'rich',
Expand Down