Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions kernel_tuner/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,8 @@ def get_kernel_string(kernel_source, params=None):
kernel_string = read_file(kernel_source)
elif isinstance(kernel_source, str):
if looks_like_a_filename(kernel_source):
kernel_string = read_file(kernel_source) or kernel_source
with open(kernel_source, "r") as f:
kernel_string = f.read()
else:
kernel_string = kernel_source
else:
Expand Down Expand Up @@ -1115,6 +1116,10 @@ def compile_restrictions(
noncompiled_restrictions.append((r, [], r))
return noncompiled_restrictions + compiled_restrictions

def check_matching_problem_size(cached_problem_size, problem_size):
    """Check if the requested problem size matches the problem size in the cache.

    :param cached_problem_size: problem size recorded in the cache file; a scalar
        or any sequence of dimension sizes.
    :param problem_size: problem size requested by the caller, same forms.

    :raises ValueError: when the two problem sizes are not element-wise equal.
    """
    # Promote scalars to 1-element arrays so a cached `1000` matches a requested
    # `(1000,)`. Using `np.array_equal` (rather than `==` + `.all()`) makes a
    # length mismatch like [42, 1] vs [42, 1, 2] report False cleanly instead of
    # surfacing NumPy's broadcasting error with an unrelated message.
    if not np.array_equal(np.atleast_1d(cached_problem_size), np.atleast_1d(problem_size)):
        raise ValueError(f"Cannot load cache which contains results for different problem_size, cache: {cached_problem_size}, requested: {problem_size}")

def process_cache(cache, kernel_options, tuning_options, runner):
"""Cache file for storing tuned configurations.
Expand Down Expand Up @@ -1185,18 +1190,7 @@ def process_cache(cache, kernel_options, tuning_options, runner):
f"Cannot load cache which contains results for different kernel (cache: {cached_data['kernel_name']}, actual: {kernel_options.kernel_name})"
)
if "problem_size" in cached_data and not callable(kernel_options.problem_size):
# if it's a single value, convert to an array
if isinstance(cached_data["problem_size"], int):
cached_data["problem_size"] = [cached_data["problem_size"]]
# if problem_size is not iterable, compare directly
if not hasattr(kernel_options.problem_size, "__iter__"):
if cached_data["problem_size"] != kernel_options.problem_size:
raise ValueError("Cannot load cache which contains results for different problem_size")
# else (problem_size is iterable)
# cache returns list, problem_size is likely a tuple. Therefore, the next check
# checks the equality of all items in the list/tuples individually
elif not all([i == j for i, j in zip(cached_data["problem_size"], kernel_options.problem_size)]):
raise ValueError("Cannot load cache which contains results for different problem_size")
check_matching_problem_size(cached_data["problem_size"], kernel_options.problem_size)
if cached_data["tune_params_keys"] != list(tuning_options.tune_params.keys()):
if all(key in tuning_options.tune_params for key in cached_data["tune_params_keys"]):
raise ValueError(
Expand Down
24 changes: 21 additions & 3 deletions test/test_util_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,10 +525,10 @@ def gen_kernel(params):

def test_get_kernel_string_filename_not_found():
    """get_kernel_string must raise when a filename-like string has no file behind it."""
    # when the string looks like a filename, but the file does not exist,
    # check that an exception is thrown
    # NOTE: the pasted diff kept both the old body (returning the string
    # unchanged) and the new one; only the new pytest.raises form is intended.
    bogus_filename = "filename_3456789.cu"
    with pytest.raises(FileNotFoundError):
        get_kernel_string(bogus_filename)


def test_looks_like_a_filename1():
Expand Down Expand Up @@ -726,6 +726,24 @@ def test_parse_restrictions():
assert all(param in tune_params for param in params)


def test_check_matching_problem_size():
    """Exercise check_matching_problem_size on mismatching and matching sizes."""
    # every pair here mixes scalar/sequence forms that must NOT be considered equal
    mismatching = [
        (42, 1000),
        ([42, 1], 42),
        ([42, 0], 42),
        (None, 42),
    ]
    for cached, requested in mismatching:
        with pytest.raises(ValueError):
            check_matching_problem_size(cached, requested)

    # scalar and one-element sequence forms of the same size must be accepted
    matching = [
        (1000, (1000,)),
        ([1000], 1000),
        (1000, 1000),
        (1000, [1000]),
        ([1000], 1000),
    ]
    for cached, requested in matching:
        check_matching_problem_size(cached, requested)


def test_convert_constraint_lambdas():

restrictions = [lambda p: 32 <= p["block_size_x"]*p["block_size_y"] <= 1024,
Expand Down