File tree Expand file tree Collapse file tree 2 files changed +12
-7
lines changed
Expand file tree Collapse file tree 2 files changed +12
-7
lines changed Original file line number Diff line number Diff line change 3232import contextlib
3333import contextvars
3434import operator
35- import os
3635from numbers import Number
3736
3837import mkl
@@ -101,7 +100,8 @@ def _workers_to_num_threads(w):
101100 if _w == 0 :
102101 raise ValueError ("Number of workers must not be zero" )
103102 if _w < 0 :
104- _cpu_count = os .cpu_count ()
103+ # SciPy uses os.cpu_count()
104+ _cpu_count = mkl .get_max_threads () # pylint: disable=no-member
105105 _w += _cpu_count + 1
106106 if _w <= 0 :
107107 raise ValueError (
Original file line number Diff line number Diff line change 44import multiprocessing
55import os
66
7+ import mkl
78import numpy as np
89import pytest
910from numpy .testing import assert_allclose
@@ -80,22 +81,26 @@ def test_invalid_workers(x):
8081
8182
8283def test_set_get_workers ():
83- cpus = os .cpu_count ()
84+ # cpus = os.cpu_count()
85+ threads = mkl .get_max_threads () # pylint: disable=no-member
86+ # cpus and threads are usually the same but in CI, cpus = 4 and threads = 2
87+ # SciPy uses `os.cpu_count()` to get the number of workers, while
88+ # `mkl_fft.interfaces.scipy_fft` uses `mkl.get_max_threads()`
8489
8590 # default value is max number of threads unlike stock SciPy
86- assert fft .get_workers () == cpus
91+ assert fft .get_workers () == threads
8792 with fft .set_workers (4 ):
8893 assert fft .get_workers () == 4
8994
9095 with fft .set_workers (- 1 ):
91- assert fft .get_workers () == cpus
96+ assert fft .get_workers () == threads
9297
9398 assert fft .get_workers () == 4
9499
95100 # default value is max number of threads unlike stock SciPy
96- assert fft .get_workers () == cpus
101+ assert fft .get_workers () == threads
97102
98- with fft .set_workers (- cpus ):
103+ with fft .set_workers (- threads ):
99104 assert fft .get_workers () == 1
100105
101106
You can’t perform that action at this time.
0 commit comments