Skip to content
44 changes: 44 additions & 0 deletions examples/cuda/vector_add_custom_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/usr/bin/env python
"""This is the minimal example from the README"""

import numpy
import kernel_tuner
from kernel_tuner import tune_kernel
from kernel_tuner.file_utils import store_output_file, store_metadata_file

def tune():
    """Tune the thread-block size of a CUDA vector-add kernel and save the outcome.

    Creates float32 input/output vectors of ten million elements, lets
    ``tune_kernel`` search over ``block_size_x`` using the
    ``kernel_tuner.strategies.minimize`` module as the strategy object, and
    writes the results to ``vector_add.json`` and the run metadata to
    ``vector_add-metadata.json``.

    Returns:
        The tuning results (first element of what ``tune_kernel`` returns).
    """
    # block_size_x is not declared in the kernel source itself; presumably
    # Kernel Tuner supplies it at compile time from the tunable parameters.
    kernel_string = """
__global__ void vector_add(float *c, float *a, float *b, int n) {
int i = blockIdx.x * block_size_x + threadIdx.x;
if (i<n) {
c[i] = a[i] + b[i];
}
}
"""

    problem_size = 10000000

    # Inputs are drawn in the order a, b so the use of the global random
    # state is the same on every run of this example.
    a = numpy.random.randn(problem_size).astype(numpy.float32)
    b = numpy.random.randn(problem_size).astype(numpy.float32)

    # Kernel argument order: output vector, the two inputs, element count
    kernel_args = [numpy.zeros_like(b), a, b, numpy.int32(problem_size)]

    # Candidate block sizes: 128, 192, ..., 1024 (15 values, step 64)
    tune_params = {"block_size_x": list(range(128, 1025, 64))}

    results, _ = tune_kernel(
        "vector_add",
        kernel_string,
        problem_size,
        kernel_args,
        tune_params,
        strategy=kernel_tuner.strategies.minimize,
        verbose=True,
    )

    # Persist the tuning results, then the metadata describing this run
    store_output_file("vector_add.json", results, tune_params)
    store_metadata_file("vector_add-metadata.json")

    return results


# Run the tuning example only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    tune()
32 changes: 9 additions & 23 deletions kernel_tuner/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,29 +621,15 @@ def tune_kernel(
if strategy in strategy_map:
strategy = strategy_map[strategy]
else:
raise ValueError(f"Unkown strategy {strategy}, must be one of: {', '.join(list(strategy_map.keys()))}")

# make strategy_options into an Options object
if tuning_options.strategy_options:
if not isinstance(strategy_options, Options):
tuning_options.strategy_options = Options(strategy_options)

# select strategy based on user options
if "fraction" in tuning_options.strategy_options and not tuning_options.strategy == "random_sample":
raise ValueError(
'It is not possible to use fraction in combination with strategies other than "random_sample". '
'Please set strategy="random_sample", when using "fraction" in strategy_options'
)

# check if method is supported by the selected strategy
if "method" in tuning_options.strategy_options:
method = tuning_options.strategy_options.method
if method not in strategy.supported_methods:
raise ValueError("Method %s is not supported for strategy %s" % (method, tuning_options.strategy))

# if no strategy_options dict has been passed, create empty dictionary
else:
tuning_options.strategy_options = Options({})
# check for user-defined strategy
if hasattr(strategy, "tune") and callable(strategy.tune):
# user-defined strategy
pass
else:
raise ValueError(f"Unkown strategy {strategy}, must be one of: {', '.join(list(strategy_map.keys()))}")

# ensure strategy_options is an Options object
tuning_options.strategy_options = Options(strategy_options or {})

# if no strategy selected
else:
Expand Down
3 changes: 2 additions & 1 deletion kernel_tuner/strategies/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self, searchspace: Searchspace, tuning_options, runner, *, scaling=
self.scaling = scaling
self.searchspace = searchspace
self.results = []
self.budget_spent_fraction = 0.0

def __call__(self, x, check_restrictions=True):
"""Cost function used by almost all strategies."""
Expand All @@ -70,7 +71,7 @@ def __call__(self, x, check_restrictions=True):
logging.debug('x: ' + str(x))

# check if max_fevals is reached or time limit is exceeded
util.check_stop_criterion(self.tuning_options)
self.budget_spent_fraction = util.check_stop_criterion(self.tuning_options)

# snap values in x to nearest actual value for each parameter, unscale x if needed
if self.snap:
Expand Down
49 changes: 49 additions & 0 deletions kernel_tuner/strategies/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Wrapper intended for user-defined custom optimization methods"""

from abc import ABC, abstractmethod

from kernel_tuner import util
from kernel_tuner.searchspace import Searchspace
from kernel_tuner.strategies.common import CostFunc


class OptAlg(ABC):
    """Abstract base for optimization algorithms supplied by the user.

    Subclasses implement ``__call__``. The ``costfunc_kwargs`` attribute set
    here holds the keyword arguments that are forwarded to ``CostFunc`` when
    the cost function for this optimizer is constructed.
    """

    def __init__(self):
        # Defaults for the cost function: scaling and snapping both enabled
        self.costfunc_kwargs = dict(scaling=True, snap=True)

    @abstractmethod
    def __call__(self, func: CostFunc, searchspace: Searchspace) -> tuple[tuple, float]:
        """Minimize the black box function `func` over `searchspace`.

        Args:
            func (CostFunc): The cost function to optimize. Its `budget_spent_fraction` attribute tells how much of the budget has been used so far.
            searchspace (Searchspace): The search space holding the parameters to optimize.

        Returns:
            tuple[tuple, float]: the best parameters found and their cost value
        """


class OptAlgWrapper:
    """Adapter exposing a user-defined ``OptAlg`` through a ``tune`` method."""

    def __init__(self, optimizer: OptAlg):
        self.optimizer: OptAlg = optimizer

    def tune(self, searchspace: Searchspace, runner, tuning_options):
        """Run the wrapped optimizer and return every result it produced.

        Args:
            searchspace (Searchspace): search space handed to the cost function and the optimizer.
            runner: runner object forwarded unchanged to ``CostFunc``.
            tuning_options: tuning options forwarded to ``CostFunc``; its ``verbose`` attribute decides whether the stop message is printed.

        Returns:
            The ``results`` collected by the cost function, also when the
            optimizer was interrupted by ``StopCriterionReached``.
        """
        kwargs = self.optimizer.costfunc_kwargs
        cost_func = CostFunc(searchspace, tuning_options, runner, **kwargs)

        # Called only for its side effect of initializing the cost function
        # for scaling; the returned values are not used here.
        if kwargs.get('scaling', True):
            cost_func.get_bounds_x0_eps()

        try:
            self.optimizer(cost_func, searchspace)
        except util.StopCriterionReached as stop:
            # Running out of budget is the normal way for a run to end
            if tuning_options.verbose:
                print(stop)

        return cost_func.results
28 changes: 22 additions & 6 deletions kernel_tuner/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,28 @@ def check_argument_list(kernel_name, kernel_string, args):
warnings.warn(errors[0], UserWarning)


def check_stop_criterion(to: dict) -> float:
    """Check if a stop criterion is reached and report the budget spent.

    Both criteria are evaluated when both are configured, so a time limit is
    still enforced when ``max_fevals`` is also set.

    Args:
        to (dict): tuning options, accessed by key (``"max_fevals" in to``)
            and by attribute (``to.max_fevals``). ``unique_results`` is read
            when ``max_fevals`` is set; ``start_time`` and ``simulated_time``
            are read when ``time_limit`` is set.

    Raises:
        StopCriterionReached: if the max_fevals is reached or time limit is exceeded.

    Returns:
        float: fraction of budget spent. When both budgets are configured
        this is the larger of the two fractions; when no budget is
        configured it is 0.0.
    """
    fractions = []

    if "max_fevals" in to:
        if len(to.unique_results) >= to.max_fevals:
            raise StopCriterionReached(f"max_fevals ({to.max_fevals}) reached")
        fractions.append(len(to.unique_results) / to.max_fevals)

    if "time_limit" in to:
        # simulated_time is scaled by 1e-3 before being added to the elapsed
        # perf_counter seconds (presumably it is kept in milliseconds).
        time_spent = (time.perf_counter() - to.start_time) + (to.simulated_time * 1e-3)
        if time_spent > to.time_limit:
            raise StopCriterionReached("time limit exceeded")
        fractions.append(time_spent / to.time_limit)

    # Always hand back a float so callers can compare the fraction directly
    return max(fractions, default=0.0)



def check_tune_params_list(tune_params, observers, simulation_mode=False):
Expand Down
148 changes: 148 additions & 0 deletions test/test_custom_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@

### The following was generated using the LLaMEA prompt and OpenAI o1

import numpy as np

from kernel_tuner.strategies.wrapper import OptAlg

class HybridDELocalRefinement(OptAlg):
    """
    Differential evolution (DE) with periodic local refinement of the best point.

    One-line idea: A two-phase hybrid DE with local refinement that balances global
    exploration and local exploitation under a strict function evaluation budget.

    The algorithm was written for BBOB-type black box optimization problems and
    keeps every candidate inside the box [-5,5]^dim.

    NOTE(review): whether [-5,5] matches the coordinate range the cost function
    expects is not visible from this file -- confirm against the cost function.
    """

    def __init__(self):
        super().__init__()
        # You can adjust these hyperparameters based on experimentation/tuning:
        self.F = 0.8  # Differential weight
        self.CR = 0.9  # Crossover probability
        self.local_search_freq = 10  # Local refinement frequency in generations

    def __call__(self, func, searchspace):
        """
        Optimize the black box function `func` in [-5,5]^dim.

        The budget is not stored on this object: the main loop keeps going
        while `func.budget_spent_fraction` is below 1.0 and fewer than
        `searchspace.size` evaluations have been made. Exceptions raised by
        `func` are not caught here and propagate to the caller.

        Args:
            func: callable cost function exposing `budget_spent_fraction`.
            searchspace: object exposing `num_params` and `size`.

        Returns:
            best_params: np.ndarray representing the best parameters found
            best_value: float representing the best objective value found
        """
        self.dim = searchspace.num_params
        # Cap the population for large dims and small search spaces, but never
        # go below 3: the mutation step samples three distinct members and
        # np.random.choice(..., replace=False) raises for a smaller population.
        self.population_size = int(
            max(3, min(50, 10 * self.dim, np.ceil(searchspace.size / 3)))
        )

        # 1. Initialize population
        lower_bound, upper_bound = -5.0, 5.0
        pop = np.random.uniform(lower_bound, upper_bound, (self.population_size, self.dim))

        # Evaluate initial population
        evaluations = 0
        fitness = np.empty(self.population_size)
        for i in range(self.population_size):
            fitness[i] = func(pop[i])
            evaluations += 1

        # Track best solution
        best_idx = np.argmin(fitness)
        best_params = pop[best_idx].copy()
        best_value = fitness[best_idx]

        # 2. Main evolutionary loop
        gen = 0
        while func.budget_spent_fraction < 1.0 and evaluations < searchspace.size:
            gen += 1
            for i in range(self.population_size):
                # DE mutation: pick three distinct indices
                idxs = np.random.choice(self.population_size, 3, replace=False)
                a, b, c = pop[idxs]
                mutant = a + self.F * (b - c)

                # Crossover
                trial = np.copy(pop[i])
                crossover_points = np.random.rand(self.dim) < self.CR
                trial[crossover_points] = mutant[crossover_points]

                # Enforce bounds
                trial = np.clip(trial, lower_bound, upper_bound)

                # Evaluate trial
                trial_fitness = func(trial)
                evaluations += 1

                # Selection
                if trial_fitness < fitness[i]:
                    pop[i] = trial
                    fitness[i] = trial_fitness
                    # Update global best
                    if trial_fitness < best_value:
                        best_value = trial_fitness
                        best_params = trial.copy()

                # If out of budget, wrap up (the last trial is already accounted for)
                if func.budget_spent_fraction > 1.0:
                    break

            # Periodically refine best solution with a small local neighborhood search
            if gen % self.local_search_freq == 0 and func.budget_spent_fraction < 1.0:
                best_params, best_value, evaluations = self._local_refinement(
                    func, best_params, best_value, evaluations, lower_bound, upper_bound
                )

        return best_params, best_value

    def _local_refinement(self, func, best_params, best_value, evaluations, lb, ub):
        """
        Local refinement around the best solution found so far.
        Uses a quick 'perturb-and-accept' approach in a shrinking neighborhood.

        Returns:
            The possibly improved best parameters and best value, and the
            updated evaluation counter.
        """
        # Neighborhood size shrinks as the budget is consumed
        step_size = 0.2 * (1.0 - func.budget_spent_fraction)

        for _ in range(5):  # 5 refinements each time
            if func.budget_spent_fraction >= 1.0:
                break
            candidate = best_params + np.random.uniform(-step_size, step_size, self.dim)
            candidate = np.clip(candidate, lb, ub)
            cand_value = func(candidate)
            evaluations += 1
            if cand_value < best_value:
                best_value = cand_value
                best_params = candidate.copy()

        return best_params, best_value, evaluations




### Testing the Optimization Algorithm Wrapper in Kernel Tuner
import os
from kernel_tuner import tune_kernel
from kernel_tuner.strategies.wrapper import OptAlgWrapper

from .test_runners import env

# Cache file located in the same directory as this test module
_tests_dir = os.path.dirname(os.path.realpath(__file__))
cache_filename = _tests_dir + "/test_cache_file.json"

def test_OptAlgWrapper(env):
    """Smoke-test a user-defined optimizer wrapped as a Kernel Tuner strategy.

    Calls ``tune_kernel`` in simulation mode with the test cache file and a
    ``max_fevals`` of 15; the test passes when the call completes without
    raising.
    """
    kernel_name, kernel_string, size, args, tune_params = env

    # Wrap the LLaMEA-generated optimizer in the OptAlgWrapper
    # so it can be passed to Kernel Tuner as a strategy
    strategy = OptAlgWrapper(HybridDELocalRefinement())

    # Call the tuner
    tune_kernel(
        kernel_name,
        kernel_string,
        size,
        args,
        tune_params,
        strategy=strategy,
        strategy_options=dict(max_fevals=15),
        cache=cache_filename,
        simulation_mode=True,
        verbose=True,
    )