Skip to content

Commit fc0456f

Browse files
committed
single source of truth for bounds
1 parent 56de7fa commit fc0456f

File tree

6 files changed

+133
-112
lines changed

6 files changed

+133
-112
lines changed

include/bounds.hpp

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,27 @@ struct Population;
99
namespace parameters
1010
{
1111
struct Parameters;
12+
struct Settings;
1213
}
1314

1415
namespace bounds
1516
{
1617
using Mask = Eigen::Array<bool, Eigen::Dynamic, 1>;
1718

18-
Mask is_out_of_bounds(const Vector& xi, const Vector& lb, const Vector& ub);
19-
bool any_out_of_bounds(const Vector& xi, const Vector& lb, const Vector& ub);
19+
Mask is_out_of_bounds(const Vector &xi, const Vector &lb, const Vector &ub);
20+
bool any_out_of_bounds(const Vector &xi, const Vector &lb, const Vector &ub);
2021

2122
struct BoundCorrection
2223
{
2324
virtual ~BoundCorrection() = default;
24-
Vector lb, ub, db;
25+
Vector db;
2526
Float diameter;
2627
size_t n_out_of_bounds = 0;
2728
bool has_bounds;
2829

29-
BoundCorrection(const Vector& lb, const Vector& ub) : lb(lb), ub(ub), db(ub - lb),
30-
diameter((ub - lb).norm()),
31-
has_bounds(true)
30+
BoundCorrection(const Vector &lb, const Vector &ub) : db(ub - lb),
31+
diameter((ub - lb).norm()),
32+
has_bounds(true)
3233
{
3334
//! find a better way
3435
if (!std::isfinite(diameter))
@@ -38,13 +39,22 @@ namespace bounds
3839
}
3940
}
4041

41-
void correct(const Eigen::Index i, parameters::Parameters& p);
42+
void correct(const Eigen::Index i, parameters::Parameters &p);
4243

43-
virtual Vector correct_x(const Vector& xi, const Mask& oob, const Float sigma) = 0;
44+
virtual Vector correct_x(
45+
const Vector &xi,
46+
const Mask &oob,
47+
const Float sigma,
48+
const parameters::Settings &settings) = 0;
4449

45-
[[nodiscard]] Mask is_out_of_bounds(const Vector& xi) const;
50+
[[nodiscard]] Mask is_out_of_bounds(
51+
const Vector &xi,
52+
const parameters::Settings &settings) const;
4653

47-
[[nodiscard]] Vector delta_out_of_bounds(const Vector& xi, const Mask& oob) const;
54+
[[nodiscard]] Vector delta_out_of_bounds(
55+
const Vector &xi,
56+
const Mask &oob,
57+
const parameters::Settings &settings) const;
4858

4959
[[nodiscard]] bool any_out_of_bounds() const
5060
{
@@ -56,7 +66,7 @@ namespace bounds
5666
{
5767
using BoundCorrection::BoundCorrection;
5868

59-
Vector correct_x(const Vector& xi, const Mask& oob, const Float sigma) override
69+
Vector correct_x(const Vector &xi, const Mask &oob, const Float sigma, const parameters::Settings &settings) override
6070
{
6171
return xi;
6272
}
@@ -73,14 +83,14 @@ namespace bounds
7383

7484
COTN(Eigen::Ref<const Vector> lb, Eigen::Ref<const Vector> ub) : BoundCorrection(lb, ub), sampler(static_cast<size_t>(lb.size()), rng::normal<Float>(0, 1.0 / 3.)) {}
7585

76-
Vector correct_x(const Vector& xi, const Mask& oob, const Float sigma) override;
86+
Vector correct_x(const Vector &xi, const Mask &oob, const Float sigma, const parameters::Settings &settings) override;
7787
};
7888

7989
struct Mirror final : BoundCorrection
8090
{
8191
using BoundCorrection::BoundCorrection;
8292

83-
Vector correct_x(const Vector& xi, const Mask& oob, const Float sigma) override;
93+
Vector correct_x(const Vector &xi, const Mask &oob, const Float sigma, const parameters::Settings &settings) override;
8494
};
8595

8696
struct UniformResample final : BoundCorrection
@@ -89,24 +99,24 @@ namespace bounds
8999

90100
UniformResample(Eigen::Ref<const Vector> lb, Eigen::Ref<const Vector> ub) : BoundCorrection(lb, ub), sampler(static_cast<size_t>(lb.size())) {}
91101

92-
Vector correct_x(const Vector& xi, const Mask& oob, const Float sigma) override;
102+
Vector correct_x(const Vector &xi, const Mask &oob, const Float sigma, const parameters::Settings &settings) override;
93103
};
94104

95105
struct Saturate final : BoundCorrection
96106
{
97107
using BoundCorrection::BoundCorrection;
98108

99-
Vector correct_x(const Vector& xi, const Mask& oob, const Float sigma) override;
109+
Vector correct_x(const Vector &xi, const Mask &oob, const Float sigma, const parameters::Settings &settings) override;
100110
};
101111

102112
struct Toroidal final : BoundCorrection
103113
{
104114
using BoundCorrection::BoundCorrection;
105115

106-
Vector correct_x(const Vector& xi, const Mask& oob, const Float sigma) override;
116+
Vector correct_x(const Vector &xi, const Mask &oob, const Float sigma, const parameters::Settings &settings) override;
107117
};
108118

109-
inline std::shared_ptr<BoundCorrection> get(const parameters::CorrectionMethod& m, const Vector& lb, const Vector& ub)
119+
inline std::shared_ptr<BoundCorrection> get(const parameters::CorrectionMethod &m, const Vector &lb, const Vector &ub)
110120
{
111121
using namespace parameters;
112122
switch (m)

include/es.hpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "sampling.hpp"
44
#include "stats.hpp"
55
#include "bounds.hpp"
6+
#include "settings.hpp"
67

78
namespace es
89
{
@@ -20,8 +21,15 @@ namespace es
2021
x(x0), f(f0), t(1), budget(budget), target(target),
2122
rejection_sampling(modules.bound_correction == parameters::CorrectionMethod::RESAMPLE),
2223
sampler(sampling::get(d, modules, 1)),
23-
corrector(bounds::get(modules.bound_correction, Vector::Ones(d) * -5.0, Vector::Ones(d) * 5.0))
24+
corrector(bounds::get(modules.bound_correction, Vector::Ones(d) * -5.0, Vector::Ones(d) * 5.0)),
25+
settings{d}
2426
{
27+
settings.modules = modules;
28+
settings.x0 = x0;
29+
settings.budget = budget;
30+
settings.target = target;
31+
settings.lb = Vector::Ones(d) * -5.0;
32+
settings.ub = Vector::Ones(d) * 5.0;
2533
}
2634

2735
Vector sample();
@@ -40,6 +48,7 @@ namespace es
4048

4149
std::shared_ptr<sampling::Sampler> sampler;
4250
std::shared_ptr<bounds::BoundCorrection> corrector;
51+
parameters::Settings settings;
4352
};
4453

4554
struct MuCommaLambdaES
@@ -65,10 +74,18 @@ namespace es
6574
sampler(sampling::get(d, modules, lambda)),
6675
sigma_sampler(std::make_shared<sampling::Gaussian>(d)),
6776
rejection_sampling(modules.bound_correction == parameters::CorrectionMethod::RESAMPLE),
68-
corrector(bounds::get(modules.bound_correction, Vector::Ones(d) * -5.0, Vector::Ones(d) * 5.0))
77+
corrector(bounds::get(modules.bound_correction, Vector::Ones(d) * -5.0, Vector::Ones(d) * 5.0)),
78+
settings(d)
6979
{
7080
// tau = 1.0 / sampler->expected_length();
7181
// tau_i = 1.0 / std::sqrt(sampler->expected_length());
82+
83+
settings.modules = modules;
84+
settings.x0 = x0;
85+
settings.budget = budget;
86+
settings.target = target;
87+
settings.lb = Vector::Ones(d) * -5.0;
88+
settings.ub = Vector::Ones(d) * 5.0;
7289
}
7390

7491
Vector sample(const Vector si);
@@ -100,5 +117,6 @@ namespace es
100117

101118
bool rejection_sampling;
102119
std::shared_ptr<bounds::BoundCorrection> corrector;
120+
parameters::Settings settings;
103121
};
104122
}

src/bounds.cpp

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ static Float modulo2(const int x)
99

1010
namespace bounds
1111
{
12-
13-
Mask is_out_of_bounds(const Vector &xi, const Vector &lb, const Vector &ub)
12+
13+
Mask is_out_of_bounds(const Vector &xi, const Vector &lb, const Vector &ub)
1414
{
1515
return xi.array() < lb.array() || xi.array() > ub.array();
1616
}
@@ -20,69 +20,66 @@ namespace bounds
2020
return bounds::is_out_of_bounds(xi, lb, ub).any();
2121
}
2222

23-
24-
Mask BoundCorrection::is_out_of_bounds(const Vector& xi) const
23+
Mask BoundCorrection::is_out_of_bounds(const Vector &xi, const parameters::Settings &settings) const
2524
{
26-
return bounds::is_out_of_bounds(xi, lb, ub);
25+
return bounds::is_out_of_bounds(xi, settings.lb, settings.ub);
2726
}
2827

29-
Vector BoundCorrection::delta_out_of_bounds(const Vector& xi, const Mask& oob) const
28+
Vector BoundCorrection::delta_out_of_bounds(const Vector &xi, const Mask &oob, const parameters::Settings &settings) const
3029
{
31-
return (oob).select((xi - lb).cwiseQuotient(db), xi);;
30+
return (oob).select((xi - settings.lb).cwiseQuotient(db), xi);
31+
;
3232
}
3333

34-
void BoundCorrection::correct(const Eigen::Index i, parameters::Parameters& p)
34+
void BoundCorrection::correct(const Eigen::Index i, parameters::Parameters &p)
3535
{
3636
if (!has_bounds)
3737
return;
3838

39-
const auto oob = is_out_of_bounds(p.pop.X.col(i));
39+
const auto oob = is_out_of_bounds(p.pop.X.col(i), p.settings);
4040
if (oob.any())
4141
{
4242
n_out_of_bounds++;
4343
if (p.settings.modules.bound_correction == parameters::CorrectionMethod::NONE)
4444
return;
4545

46-
p.pop.X.col(i) = correct_x(p.pop.X.col(i), oob, p.mutation->sigma);
46+
p.pop.X.col(i) = correct_x(p.pop.X.col(i), oob, p.mutation->sigma, p.settings);
4747
p.pop.Y.col(i) = p.adaptation->invert_x(p.pop.X.col(i), p.pop.s(i));
4848
p.pop.Z.col(i) = p.adaptation->invert_y(p.pop.Y.col(i));
4949
}
5050
}
5151

52-
Vector COTN::correct_x(const Vector& xi, const Mask& oob, const Float sigma)
52+
Vector COTN::correct_x(const Vector &xi, const Mask &oob, const Float sigma, const parameters::Settings &settings)
5353
{
54-
const Vector y = delta_out_of_bounds(xi, oob);
54+
const Vector y = delta_out_of_bounds(xi, oob, settings);
5555
return (oob).select(
56-
lb.array() + db.array() * ((y.array() > 0).cast<Float>() - (sigma * sampler().array().abs())).abs(), y);
56+
settings.lb.array() + db.array() * ((y.array() > 0).cast<Float>() - (sigma * sampler().array().abs())).abs(), y);
5757
}
5858

59-
60-
Vector Mirror::correct_x(const Vector& xi, const Mask& oob, const Float sigma)
59+
Vector Mirror::correct_x(const Vector &xi, const Mask &oob, const Float sigma, const parameters::Settings &settings)
6160
{
62-
const Vector y = delta_out_of_bounds(xi, oob);
61+
const Vector y = delta_out_of_bounds(xi, oob, settings);
6362
return (oob).select(
64-
lb.array() + db.array() * (y.array() - y.array().floor() - y.array().floor().unaryExpr(&modulo2)).
65-
abs(),
63+
settings.lb.array() + db.array() * (y.array() - y.array().floor() - y.array().floor().unaryExpr(&modulo2)).abs(),
6664
y);
6765
}
6866

69-
70-
Vector UniformResample::correct_x(const Vector& xi, const Mask& oob, const Float sigma)
67+
Vector UniformResample::correct_x(const Vector &xi, const Mask &oob, const Float sigma, const parameters::Settings &settings)
7168
{
72-
return (oob).select(lb + sampler().cwiseProduct(db), xi);
69+
return (oob).select(settings.lb + sampler().cwiseProduct(db), xi);
7370
}
7471

75-
Vector Saturate::correct_x(const Vector& xi, const Mask& oob, const Float sigma)
72+
Vector Saturate::correct_x(const Vector &xi, const Mask &oob, const Float sigma, const parameters::Settings &settings)
7673
{
77-
const Vector y = delta_out_of_bounds(xi, oob);
74+
const Vector y = delta_out_of_bounds(xi, oob, settings);
7875
return (oob).select(
79-
lb.array() + db.array() * (y.array() > 0).cast<Float>(), y);
76+
settings.lb.array() + db.array() * (y.array() > 0).cast<Float>(), y);
8077
}
8178

82-
Vector Toroidal::correct_x(const Vector& xi, const Mask& oob, const Float sigma)
79+
Vector Toroidal::correct_x(const Vector &xi, const Mask &oob, const Float sigma, const parameters::Settings &settings)
8380
{
84-
const Vector y = delta_out_of_bounds(xi, oob);
81+
const Vector y = delta_out_of_bounds(xi, oob, settings);
8582
return (oob).select(
86-
lb.array() + db.array() * (y.array() - y.array().floor()).abs(), y);
83+
settings.lb.array() + db.array() * (y.array() - y.array().floor()).abs(), y);
8784
}
8885
}

src/es.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ namespace es
1313
const Vector z = (*sampler)();
1414
x1 = x + sigma * z;
1515

16-
const auto mask = corrector->is_out_of_bounds(x1);
16+
const auto mask = corrector->is_out_of_bounds(x, settings);
1717
if (mask.any())
18-
x1 = corrector->correct_x(x1, mask, sigma);
18+
x1 = corrector->correct_x(x1, mask, sigma, settings);
1919

20-
} while (rejection_sampling && n_rej++ < 5*d && bounds::any_out_of_bounds(x1, corrector->lb, corrector->ub) );
20+
} while (rejection_sampling && n_rej++ < 5*d && bounds::any_out_of_bounds(x1, settings.lb, settings.ub) );
2121
return x1;
2222
}
2323

@@ -49,11 +49,11 @@ namespace es
4949
const Vector z = (*sampler)();
5050
x = m.array() + (si.array() * z.array());
5151

52-
const auto mask = corrector->is_out_of_bounds(x);
52+
const auto mask = corrector->is_out_of_bounds(x, settings);
5353
if (mask.any())
54-
x = corrector->correct_x(x, mask, si.mean());
54+
x = corrector->correct_x(x, mask, si.mean(), settings);
5555

56-
} while (rejection_sampling && n_rej++ < 5*d && bounds::any_out_of_bounds(x, corrector->lb, corrector->ub));
56+
} while (rejection_sampling && n_rej++ < 5*d && bounds::any_out_of_bounds(x, settings.lb, settings.ub));
5757
return x;
5858
}
5959

src/interface.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -745,16 +745,15 @@ void define_bounds(py::module& main)
745745
using namespace bounds;
746746

747747
py::class_<BoundCorrection, std::shared_ptr<BoundCorrection>>(m, "BoundCorrection")
748-
.def_readwrite("lb", &BoundCorrection::lb)
749-
.def_readwrite("ub", &BoundCorrection::ub)
750748
.def_readwrite("db", &BoundCorrection::db)
751749
.def_readwrite("diameter", &BoundCorrection::diameter)
752750
.def_readwrite("has_bounds", &BoundCorrection::has_bounds)
753751
.def_readonly("n_out_of_bounds", &BoundCorrection::n_out_of_bounds)
754752
.def("correct", &BoundCorrection::correct,
755-
py::arg("population"), py::arg("m"))
756-
.def("delta_out_of_bounds", &BoundCorrection::delta_out_of_bounds, py::arg("xi"), py::arg("oob"))
757-
.def("is_out_of_bounds", &BoundCorrection::is_out_of_bounds, py::arg("xi"))
753+
py::arg("index"), py::arg("parameters"))
754+
.def("correct_x", &BoundCorrection::correct_x, py::arg("xi"), py::arg("oob"), py::arg("sigma"), py::arg("settings"))
755+
.def("delta_out_of_bounds", &BoundCorrection::delta_out_of_bounds, py::arg("xi"), py::arg("oob"), py::arg("settings"))
756+
.def("is_out_of_bounds", &BoundCorrection::is_out_of_bounds, py::arg("xi"), py::arg("settings"))
758757
;
759758

760759
py::class_<Resample, BoundCorrection, std::shared_ptr<Resample>>(m, "Resample")

0 commit comments

Comments
 (0)