From 15175655e42c4053a34b4d41cd1724fb750fbc9b Mon Sep 17 00:00:00 2001
From: Ben van Werkhoven
Date: Fri, 7 Mar 2025 16:50:41 +0100
Subject: [PATCH 1/5] add support for user-defined optimization algorithms

---
 examples/cuda/vector_add_custom_strategy.py | 44 +++++++++++++++++++++
 kernel_tuner/interface.py                   | 32 +++++----------
 2 files changed, 53 insertions(+), 23 deletions(-)
 create mode 100644 examples/cuda/vector_add_custom_strategy.py

diff --git a/examples/cuda/vector_add_custom_strategy.py b/examples/cuda/vector_add_custom_strategy.py
new file mode 100644
index 000000000..29d873d5d
--- /dev/null
+++ b/examples/cuda/vector_add_custom_strategy.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python
+"""This is the minimal example from the README"""
+
+import numpy
+import kernel_tuner
+from kernel_tuner import tune_kernel
+from kernel_tuner.file_utils import store_output_file, store_metadata_file
+
+def tune():
+
+    kernel_string = """
+    __global__ void vector_add(float *c, float *a, float *b, int n) {
+        int i = blockIdx.x * block_size_x + threadIdx.x;
+        if (i<n) {
+            c[i] = a[i] + b[i];
+        }
+    }
+    """
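The diff for this example breaks off here, and the accompanying kernel_tuner/interface.py hunk is not reproduced above. For orientation, a sketch of how the body of tune() plausibly continues: the setup values mirror the README example this file cites, and CustomStrategy is a hypothetical stand-in for any object implementing the tune(searchspace, runner, tuning_options) interface that the OptAlgWrapper introduced in the next patch provides.

    # Sketch only; CustomStrategy is hypothetical, values follow the README example
    size = 10000000

    a = numpy.random.randn(size).astype(numpy.float32)
    b = numpy.random.randn(size).astype(numpy.float32)
    c = numpy.zeros_like(b)
    n = numpy.int32(size)

    args = [c, a, b, n]
    tune_params = {"block_size_x": [32, 64, 128, 256, 512]}

    # Pass the user-defined strategy object directly to tune_kernel
    results, env = tune_kernel("vector_add", kernel_string, size, args,
                               tune_params, strategy=CustomStrategy())

    store_output_file("vector_add.json", results, tune_params)
    store_metadata_file("vector_add-metadata.json")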
Date: Mon, 10 Mar 2025 10:15:23 +0100
Subject: [PATCH 2/5] add test for the optimization algorithm wrapper

---
 kernel_tuner/strategies/wrapper.py |  29 +++++
 test/test_custom_optimizer.py      | 163 +++++++++++++++++++++++++++++
 2 files changed, 192 insertions(+)
 create mode 100644 kernel_tuner/strategies/wrapper.py
 create mode 100644 test/test_custom_optimizer.py

diff --git a/kernel_tuner/strategies/wrapper.py b/kernel_tuner/strategies/wrapper.py
new file mode 100644
index 000000000..8104f7129
--- /dev/null
+++ b/kernel_tuner/strategies/wrapper.py
@@ -0,0 +1,29 @@
+"""Wrapper intended for user-defined custom optimization methods"""
+
+from kernel_tuner import util
+from kernel_tuner.searchspace import Searchspace
+from kernel_tuner.strategies.common import CostFunc
+
+
+class OptAlgWrapper:
+    """Wrapper class for user-defined optimization algorithms"""
+
+    def __init__(self, optimizer, scaling=True):
+        self.optimizer = optimizer
+        self.scaling = scaling
+
+
+    def tune(self, searchspace: Searchspace, runner, tuning_options):
+        cost_func = CostFunc(searchspace, tuning_options, runner, scaling=self.scaling)
+
+        if self.scaling:
+            # Initialize costfunc for scaling
+            cost_func.get_bounds_x0_eps()
+
+        try:
+            self.optimizer(cost_func)
+        except util.StopCriterionReached as e:
+            if tuning_options.verbose:
+                print(e)
+
+        return cost_func.results
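At this stage the wrapper accepts any callable that takes the CostFunc as its single argument. A minimal sketch of such an optimizer, assuming CostFunc.get_bounds_x0_eps() returns a (bounds, x0, eps) tuple for the scaled parameters; the function below is illustrative, not part of the patch:

    import numpy as np

    def random_search(cost_func, max_iter=100):
        # Bounds of the (scaled) search space; x0 and eps are unused here
        bounds, x0, eps = cost_func.get_bounds_x0_eps()
        best_x, best_y = None, np.inf
        for _ in range(max_iter):
            # Sample uniformly within the bounds and evaluate; when the tuning
            # budget runs out, cost_func raises StopCriterionReached, which
            # OptAlgWrapper.tune catches on our behalf
            x = np.array([np.random.uniform(lo, hi) for lo, hi in bounds])
            y = cost_func(x)
            if y < best_y:
                best_x, best_y = x, y
        return best_x, best_y

A user would then pass OptAlgWrapper(random_search) as the strategy argument of tune_kernel.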
diff --git a/test/test_custom_optimizer.py b/test/test_custom_optimizer.py
new file mode 100644
index 000000000..7c483bad4
--- /dev/null
+++ b/test/test_custom_optimizer.py
@@ -0,0 +1,163 @@
+
+### The following was generated using the LLaMEA prompt and OpenAI o1
+
+import numpy as np
+
+class HybridDELocalRefinement:
+    """
+    A two-phase differential evolution with local refinement, intended for BBOB-type
+    black box optimization problems in [-5,5]^dim.
+
+    One-line idea: A two-phase hybrid DE with local refinement that balances global
+    exploration and local exploitation under a strict function evaluation budget.
+    """
+
+    def __init__(self, budget, dim):
+        """
+        Initialize the optimizer with:
+          - budget: total number of function evaluations allowed.
+          - dim: dimensionality of the search space.
+        """
+        self.budget = budget
+        self.dim = dim
+        # You can adjust these hyperparameters based on experimentation/tuning:
+        self.population_size = min(50, 10 * dim)  # Caps for extremely large dim
+        self.F = 0.8  # Differential weight
+        self.CR = 0.9  # Crossover probability
+        self.local_search_freq = 10  # Local refinement frequency in generations
+
+    def __call__(self, func):
+        """
+        Optimize the black box function `func` in [-5,5]^dim, using
+        at most self.budget function evaluations.
+
+        Returns:
+            best_params: np.ndarray representing the best parameters found
+            best_value: float representing the best objective value found
+        """
+        # Check if we have a non-positive budget
+        if self.budget <= 0:
+            raise ValueError("Budget must be a positive integer.")
+
+        # 1. Initialize population
+        lower_bound, upper_bound = -5.0, 5.0
+        pop = np.random.uniform(lower_bound, upper_bound, (self.population_size, self.dim))
+
+        # Evaluate initial population
+        evaluations = 0
+        fitness = np.empty(self.population_size)
+        for i in range(self.population_size):
+            fitness[i] = func(pop[i])
+            evaluations += 1
+            if evaluations >= self.budget:
+                break
+
+        # Track best solution
+        best_idx = np.argmin(fitness)
+        best_params = pop[best_idx].copy()
+        best_value = fitness[best_idx]
+
+        # 2. Main evolutionary loop
+        gen = 0
+        while evaluations < self.budget:
+            gen += 1
+            for i in range(self.population_size):
+                # DE mutation: pick three distinct indices
+                idxs = np.random.choice(self.population_size, 3, replace=False)
+                a, b, c = pop[idxs]
+                mutant = a + self.F * (b - c)
+
+                # Crossover
+                trial = np.copy(pop[i])
+                crossover_points = np.random.rand(self.dim) < self.CR
+                trial[crossover_points] = mutant[crossover_points]
+
+                # Enforce bounds
+                trial = np.clip(trial, lower_bound, upper_bound)
+
+                # Evaluate trial
+                trial_fitness = func(trial)
+                evaluations += 1
+                if evaluations >= self.budget:
+                    # If out of budget, wrap up
+                    if trial_fitness < fitness[i]:
+                        pop[i] = trial
+                        fitness[i] = trial_fitness
+                        # Update global best
+                        if trial_fitness < best_value:
+                            best_value = trial_fitness
+                            best_params = trial.copy()
+                    break
+
+                # Selection
+                if trial_fitness < fitness[i]:
+                    pop[i] = trial
+                    fitness[i] = trial_fitness
+                    # Update global best
+                    if trial_fitness < best_value:
+                        best_value = trial_fitness
+                        best_params = trial.copy()
+
+            # Periodically refine best solution with a small local neighborhood search
+            if gen % self.local_search_freq == 0 and evaluations < self.budget:
+                best_params, best_value, evaluations = self._local_refinement(
+                    func, best_params, best_value, evaluations, lower_bound, upper_bound
+                )
+
+            if evaluations >= self.budget:
+                break
+
+        return best_params, best_value
+
+    def _local_refinement(self, func, best_params, best_value, evaluations, lb, ub):
+        """
+        Local refinement around the best solution found so far.
+        Uses a quick 'perturb-and-accept' approach in a shrinking neighborhood.
+        """
+        # Neighborhood size shrinks as the budget is consumed
+        frac_budget_used = evaluations / self.budget
+        step_size = 0.2 * (1.0 - frac_budget_used)
+
+        for _ in range(5):  # 5 refinements each time
+            if evaluations >= self.budget:
+                break
+            candidate = best_params + np.random.uniform(-step_size, step_size, self.dim)
+            candidate = np.clip(candidate, lb, ub)
+            cand_value = func(candidate)
+            evaluations += 1
+            if cand_value < best_value:
+                best_value = cand_value
+                best_params = candidate.copy()
+
+        return best_params, best_value, evaluations
+
+
+
+
+### Testing the Optimization Algorithm Wrapper in Kernel Tuner
+import os
+from kernel_tuner import tune_kernel
+from kernel_tuner.strategies.wrapper import OptAlgWrapper
+cache_filename = os.path.dirname(
+
+    os.path.realpath(__file__)) + "/test_cache_file.json"
+
+from .test_runners import env
+
+
+def test_OptAlgWrapper(env):
+    kernel_name, kernel_string, size, args, tune_params = env
+
+    # Instantiate the LLaMEA-generated optimization algorithm
+    budget = int(15)
+    dim = len(tune_params)
+    optimizer = HybridDELocalRefinement(budget, dim)
+
+    # Wrap the algorithm class in the OptAlgWrapper
+    # for use in Kernel Tuner
+    strategy = OptAlgWrapper(optimizer)
+
+    # Call the tuner
+    tune_kernel(kernel_name, kernel_string, size, args, tune_params,
+                strategy=strategy, cache=cache_filename,
+                simulation_mode=True, verbose=True)
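For intuition about the search step above: differential evolution builds a mutant from three distinct population members a, b, c, then crossover copies each dimension of the mutant into the trial vector with probability CR. A worked micro-example with illustrative values:

    import numpy as np

    F, CR = 0.8, 0.9  # the defaults set in HybridDELocalRefinement.__init__
    a = np.array([ 1.0, -2.0])
    b = np.array([ 0.5,  3.0])
    c = np.array([-1.5,  2.0])
    mutant = a + F * (b - c)  # [1.0 + 0.8*2.0, -2.0 + 0.8*1.0] = [2.6, -1.2]

    current = np.array([0.0, 0.0])
    crossover_points = np.random.rand(2) < CR  # True per dimension with prob. 0.9
    trial = np.where(crossover_points, mutant, current)  # same effect as the indexed assignment above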
From 5ce2495751c9b3ed5ad734c7b2b01e602ed692f2 Mon Sep 17 00:00:00 2001
From: fjwillemsen
Date: Wed, 4 Jun 2025 13:27:19 +0200
Subject: [PATCH 3/5] Implemented abstract base class for custom optimization algorithm strategies

---
 kernel_tuner/strategies/wrapper.py | 35 ++++++++++++++++++++++------
 1 file changed, 28 insertions(+), 7 deletions(-)

diff --git a/kernel_tuner/strategies/wrapper.py b/kernel_tuner/strategies/wrapper.py
index 8104f7129..a0aa56a4b 100644
--- a/kernel_tuner/strategies/wrapper.py
+++ b/kernel_tuner/strategies/wrapper.py
@@ -1,27 +1,48 @@
 """Wrapper intended for user-defined custom optimization methods"""
 
+from abc import ABC, abstractmethod
+
 from kernel_tuner import util
 from kernel_tuner.searchspace import Searchspace
 from kernel_tuner.strategies.common import CostFunc
 
 
+class OptAlg(ABC):
+    """Base class for user-defined optimization algorithms."""
+
+    def __init__(self):
+        self.costfunc_kwargs = {"scaling": True, "snap": True}
+
+    @abstractmethod
+    def __call__(self, func: CostFunc, searchspace: Searchspace, budget_spent_fraction: float) -> tuple[tuple, float]:
+        """_summary_
+
+        Args:
+            func (CostFunc): Cost function to be optimized.
+            searchspace (Searchspace): Search space containing the parameters to be optimized.
+            budget_spent_fraction (float): Fraction of the budget that has already been spent.
+
+        Returns:
+            tuple[tuple, float]: tuple of the best parameters and the corresponding cost value
+        """
+        pass
+
+
 class OptAlgWrapper:
     """Wrapper class for user-defined optimization algorithms"""
 
-    def __init__(self, optimizer, scaling=True):
-        self.optimizer = optimizer
-        self.scaling = scaling
-
+    def __init__(self, optimizer: OptAlg):
+        self.optimizer: OptAlg = optimizer
 
     def tune(self, searchspace: Searchspace, runner, tuning_options):
-        cost_func = CostFunc(searchspace, tuning_options, runner, scaling=self.scaling)
+        cost_func = CostFunc(searchspace, tuning_options, runner, **self.optimizer.costfunc_kwargs)
 
-        if self.scaling:
+        if self.optimizer.costfunc_kwargs.get('scaling', True):
             # Initialize costfunc for scaling
            cost_func.get_bounds_x0_eps()
 
         try:
-            self.optimizer(cost_func)
+            self.optimizer(cost_func, searchspace)
         except util.StopCriterionReached as e:
             if tuning_options.verbose:
                 print(e)
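A minimal sketch of a concrete subclass under the new base class. As before, it assumes CostFunc.get_bounds_x0_eps() returns (bounds, x0, eps); note that OptAlgWrapper.tune passes only (func, searchspace) to the optimizer, and patch 5 below aligns the abstract signature with that call. RandomSearch is hypothetical:

    import numpy as np
    from kernel_tuner.strategies.wrapper import OptAlg

    class RandomSearch(OptAlg):
        """Hypothetical subclass: uniform random search over the scaled space."""

        def __call__(self, func, searchspace):
            # costfunc_kwargs from OptAlg.__init__ gives scaling=True, snap=True,
            # so func snaps each sample to the nearest valid configuration
            bounds, x0, eps = func.get_bounds_x0_eps()
            best_x, best_y = None, np.inf
            for _ in range(min(100, searchspace.size)):
                x = np.array([np.random.uniform(lo, hi) for lo, hi in bounds])
                y = func(x)
                if y < best_y:
                    best_x, best_y = x, y
            return tuple(best_x), best_y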
+ """ + if "max_fevals" in to: + if len(to.unique_results) >= to.max_fevals: + raise StopCriterionReached(f"max_fevals ({to.max_fevals}) reached") + return len(to.unique_results) / to.max_fevals + if "time_limit" in to: + time_spent = (time.perf_counter() - to.start_time) + (to.simulated_time * 1e-3) + if time_spent > to.time_limit: + raise StopCriterionReached("time limit exceeded") + return time_spent / to.time_limit + def check_tune_params_list(tune_params, observers, simulation_mode=False): From a4a69ae04e0d9aa23ea75e4840499f9f50f93f44 Mon Sep 17 00:00:00 2001 From: fjwillemsen Date: Wed, 4 Jun 2025 15:24:28 +0200 Subject: [PATCH 5/5] Adjusted CostFunc and tests to use budget_spent_fraction --- kernel_tuner/strategies/wrapper.py | 7 ++--- test/test_custom_optimizer.py | 49 +++++++++++------------------- 2 files changed, 20 insertions(+), 36 deletions(-) diff --git a/kernel_tuner/strategies/wrapper.py b/kernel_tuner/strategies/wrapper.py index a0aa56a4b..1a928ab17 100644 --- a/kernel_tuner/strategies/wrapper.py +++ b/kernel_tuner/strategies/wrapper.py @@ -14,13 +14,12 @@ def __init__(self): self.costfunc_kwargs = {"scaling": True, "snap": True} @abstractmethod - def __call__(self, func: CostFunc, searchspace: Searchspace, budget_spent_fraction: float) -> tuple[tuple, float]: - """_summary_ + def __call__(self, func: CostFunc, searchspace: Searchspace) -> tuple[tuple, float]: + """Optimize the black box function `func` within the given `searchspace`. Args: - func (CostFunc): Cost function to be optimized. + func (CostFunc): Cost function to be optimized. Has a property `budget_spent_fraction` that indicates how much of the budget has been spent. searchspace (Searchspace): Search space containing the parameters to be optimized. - budget_spent_fraction (float): Fraction of the budget that has already been spent. Returns: tuple[tuple, float]: tuple of the best parameters and the corresponding cost value diff --git a/test/test_custom_optimizer.py b/test/test_custom_optimizer.py index 7c483bad4..cfc136d3c 100644 --- a/test/test_custom_optimizer.py +++ b/test/test_custom_optimizer.py @@ -3,7 +3,9 @@ import numpy as np -class HybridDELocalRefinement: +from kernel_tuner.strategies.wrapper import OptAlg + +class HybridDELocalRefinement(OptAlg): """ A two-phase differential evolution with local refinement, intended for BBOB-type black box optimization problems in [-5,5]^dim. @@ -12,21 +14,14 @@ class HybridDELocalRefinement: exploration and local exploitation under a strict function evaluation budget. """ - def __init__(self, budget, dim): - """ - Initialize the optimizer with: - - budget: total number of function evaluations allowed. - - dim: dimensionality of the search space. - """ - self.budget = budget - self.dim = dim + def __init__(self): + super().__init__() # You can adjust these hyperparameters based on experimentation/tuning: - self.population_size = min(50, 10 * dim) # Caps for extremely large dim self.F = 0.8 # Differential weight self.CR = 0.9 # Crossover probability self.local_search_freq = 10 # Local refinement frequency in generations - def __call__(self, func): + def __call__(self, func, searchspace): """ Optimize the black box function `func` in [-5,5]^dim, using at most self.budget function evaluations. 
@@ -35,9 +30,8 @@ def __call__(self, func): best_params: np.ndarray representing the best parameters found best_value: float representing the best objective value found """ - # Check if we have a non-positive budget - if self.budget <= 0: - raise ValueError("Budget must be a positive integer.") + self.dim = searchspace.num_params + self.population_size = round(min(min(50, 10 * self.dim), np.ceil(searchspace.size / 3))) # Caps for extremely large dim # 1. Initialize population lower_bound, upper_bound = -5.0, 5.0 @@ -49,8 +43,6 @@ def __call__(self, func): for i in range(self.population_size): fitness[i] = func(pop[i]) evaluations += 1 - if evaluations >= self.budget: - break # Track best solution best_idx = np.argmin(fitness) @@ -59,7 +51,7 @@ def __call__(self, func): # 2. Main evolutionary loop gen = 0 - while evaluations < self.budget: + while func.budget_spent_fraction < 1.0 and evaluations < searchspace.size: gen += 1 for i in range(self.population_size): # DE mutation: pick three distinct indices @@ -78,7 +70,7 @@ def __call__(self, func): # Evaluate trial trial_fitness = func(trial) evaluations += 1 - if evaluations >= self.budget: + if func.budget_spent_fraction > 1.0: # If out of budget, wrap up if trial_fitness < fitness[i]: pop[i] = trial @@ -99,14 +91,11 @@ def __call__(self, func): best_params = trial.copy() # Periodically refine best solution with a small local neighborhood search - if gen % self.local_search_freq == 0 and evaluations < self.budget: + if gen % self.local_search_freq == 0 and func.budget_spent_fraction < 1.0: best_params, best_value, evaluations = self._local_refinement( func, best_params, best_value, evaluations, lower_bound, upper_bound ) - if evaluations >= self.budget: - break - return best_params, best_value def _local_refinement(self, func, best_params, best_value, evaluations, lb, ub): @@ -115,11 +104,10 @@ def _local_refinement(self, func, best_params, best_value, evaluations, lb, ub): Uses a quick 'perturb-and-accept' approach in a shrinking neighborhood. """ # Neighborhood size shrinks as the budget is consumed - frac_budget_used = evaluations / self.budget - step_size = 0.2 * (1.0 - frac_budget_used) + step_size = 0.2 * (1.0 - func.budget_spent_fraction) for _ in range(5): # 5 refinements each time - if evaluations >= self.budget: + if func.budget_spent_fraction >= 1.0: break candidate = best_params + np.random.uniform(-step_size, step_size, self.dim) candidate = np.clip(candidate, lb, ub) @@ -138,26 +126,23 @@ def _local_refinement(self, func, best_params, best_value, evaluations, lb, ub): import os from kernel_tuner import tune_kernel from kernel_tuner.strategies.wrapper import OptAlgWrapper -cache_filename = os.path.dirname( - - os.path.realpath(__file__)) + "/test_cache_file.json" from .test_runners import env +cache_filename = os.path.dirname(os.path.realpath(__file__)) + "/test_cache_file.json" def test_OptAlgWrapper(env): kernel_name, kernel_string, size, args, tune_params = env # Instantiate LLaMAE optimization algorithm - budget = int(15) - dim = len(tune_params) - optimizer = HybridDELocalRefinement(budget, dim) + optimizer = HybridDELocalRefinement() # Wrap the algorithm class in the OptAlgWrapper # for use in Kernel Tuner strategy = OptAlgWrapper(optimizer) + strategy_options = { 'max_fevals': 15 } # Call the tuner tune_kernel(kernel_name, kernel_string, size, args, tune_params, - strategy=strategy, cache=cache_filename, + strategy=strategy, strategy_options=strategy_options, cache=cache_filename, simulation_mode=True, verbose=True)