def get_grid_dimensions(current_problem_size, params, grid_div, block_size_names):
    """Compute grid dims based on problem sizes and listed grid divisors.

    For every entry ``grid_div[i]`` a numeric divisor is derived and the
    grid size in that dimension is ``ceil(current_problem_size[i] / divisor)``.

    Each entry of ``grid_div`` may be:

    * ``None`` -- use ``params[block_size_names[i]]`` when that key is
      present in ``params``, otherwise 1;
    * an integer (Python ``int`` or NumPy integer scalar) -- used as is;
    * a callable -- called with ``params``, its result is the divisor;
    * a ``str`` -- passed through ``replace_param_occurrences`` (presumably
      substituting tunable parameter names by their values; confirm against
      that helper) and then evaluated as an integer expression;
    * any other iterable -- the product of its entries, each interpreted
      recursively by the rules above.

    :param current_problem_size: Indexable problem size, one entry per
        entry of ``grid_div``.
    :param params: Dict of tunable parameter names to values.
    :param grid_div: Iterable of divisor specifications, one per dimension.
    :param block_size_names: Indexable parameter names used as the default
        divisor for the matching dimension when its entry is ``None``.
    :returns: Tuple of ints, one grid dimension per entry of ``grid_div``.
    :raises ValueError: If a divisor specification has an unsupported type.
    """

    def get_dimension_divisor(divisor, default, params):
        """Reduce a single divisor specification to one number."""
        if divisor is None:
            return params.get(default, 1)
        # np.integer is checked explicitly: NumPy integer scalars are not
        # subclasses of the builtin int and would otherwise be rejected below.
        if isinstance(divisor, (int, np.integer)):
            return divisor
        if callable(divisor):
            return divisor(params)
        if isinstance(divisor, str):
            # NOTE(security): the expression is eval'ed; divisor strings must
            # come from trusted tuning configuration, never from external input.
            return int(eval(replace_param_occurrences(divisor, params)))
        if np.iterable(divisor):
            # Strings were handled above, so this is a real collection whose
            # entries are multiplied together. Nested entries are resolved
            # with 1 as the default key instead of a block size name.
            divisor_num = 1
            for div in divisor:
                divisor_num *= get_dimension_divisor(div, 1, params)
            return divisor_num
        raise ValueError("Error: unrecognized type in grid divisor list, should be any of int, str, callable, or iterable")

    divisors = [get_dimension_divisor(d, block_size_names[i], params) for i, d in enumerate(grid_div)]
    return tuple(int(np.ceil(float(current_problem_size[i]) / float(d))) for i, d in enumerate(divisors))
assert grid[0] == 1024 + assert grid[1] == 14 + assert grid[2] == 1 + def test_get_grid_dimensions2(): problem_size = (1024, 1024, 1)