Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions kernel_tuner/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,16 +437,24 @@ def compact_number(v):
def get_grid_dimensions(current_problem_size, params, grid_div, block_size_names):
"""Compute grid dims based on problem sizes and listed grid divisors."""

def get_dimension_divisor(divisor_list, default, params):
if divisor_list is None:
if default in params:
divisor_list = [default]
else:
return 1
if callable(divisor_list):
return divisor_list(params)
def get_dimension_divisor(divisor, default, params):
divisor_num = 1

if divisor is None:
divisor_num = params.get(default, 1)
elif isinstance(divisor, int):
divisor_num = divisor
elif callable(divisor):
divisor_num = divisor(params)
elif isinstance(divisor, str):
divisor_num = int(eval(replace_param_occurrences(divisor, params)))
elif np.iterable(divisor):
for div in divisor:
divisor_num *= get_dimension_divisor(div, 1, params)
else:
return np.prod([int(eval(replace_param_occurrences(s, params))) for s in divisor_list])
raise ValueError("Error: unrecognized type in grid divisor list, should be any of int, str, callable, or iterable")

return divisor_num

divisors = [get_dimension_divisor(d, block_size_names[i], params) for i, d in enumerate(grid_div)]
return tuple(int(np.ceil(float(current_problem_size[i]) / float(d))) for i, d in enumerate(divisors))
Expand Down
16 changes: 16 additions & 0 deletions test/test_util_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ def test_get_grid_dimensions1():
assert grid[1] == 25
assert grid[2] == 1

grid = get_grid_dimensions(
problem_size, params, ("41", 37, None), block_size_names
)

assert grid[0] == 25
assert grid[1] == 28
assert grid[2] == 1

grid = get_grid_dimensions(
problem_size, params, (None, [2, "block_y"], None), block_size_names
)

assert grid[0] == 1024
assert grid[1] == 14
assert grid[2] == 1


def test_get_grid_dimensions2():
problem_size = (1024, 1024, 1)
Expand Down