diff --git a/pufferlib/config/atari.ini b/pufferlib/config/atari.ini index 75d1557ab..be9b54e1d 100644 --- a/pufferlib/config/atari.ini +++ b/pufferlib/config/atari.ini @@ -8,7 +8,7 @@ rnn_name = Recurrent [vec] num_envs = 128 -num_workers = 16 +num_workers = auto batch_size = 64 [train] diff --git a/pufferlib/environments/atari/environment.py b/pufferlib/environments/atari/environment.py index 92295ecaa..6af5c5457 100644 --- a/pufferlib/environments/atari/environment.py +++ b/pufferlib/environments/atari/environment.py @@ -16,7 +16,7 @@ def make(name, obs_type='grayscale', frameskip=4, repeat_action_probability=0.0, render_mode='rgb_array', buf=None, seed=0): '''Atari creation function''' - pufferlib.environments.try_import('ale_py', 'AtariEnv') + pufferlib.environments.try_import('ale_py') ale_render_mode = render_mode if render_mode == 'human': diff --git a/pufferlib/vector.py b/pufferlib/vector.py index 78614f4d6..393237964 100644 --- a/pufferlib/vector.py +++ b/pufferlib/vector.py @@ -640,7 +640,8 @@ def make(env_creator_or_creators, env_args=None, env_kwargs=None, backend=Puffer if 'num_workers' in kwargs: if kwargs['num_workers'] == 'auto': - kwargs['num_workers'] = num_envs + import psutil + kwargs['num_workers'] = min(psutil.cpu_count(logical=False), num_envs) # TODO: None? envs_per_worker = num_envs / kwargs['num_workers'] diff --git a/pyproject.toml b/pyproject.toml index 797707449..e6562b16b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ avalon = [ atari = [ 'gymnasium[accept-rom-license]==0.29.1', 'opencv-python==3.4.17.63', - 'ale_py==0.9.0', + 'ale_py==0.10.1', ] box2d = [