|
10 | 10 | "import pandas as pd\n", |
11 | 11 | "import matplotlib.pyplot as plt\n", |
12 | 12 | "import seaborn as sns \n", |
13 | | - "from modcma.c_maes import sampling\n", |
| 13 | + "from modcma.c_maes import (\n", |
| 14 | + " sampling\n", |
| 15 | + " ModularCMAES, parameters, options, constants, utils, es\n", |
| 16 | + ")\n", |
14 | 17 | "\n", |
15 | 18 | "\n", |
16 | 19 | "import matplotlib\n", |
|
93 | 96 | "plt.savefig(\"figures/distributions_z.pdf\")\n" |
94 | 97 | ] |
95 | 98 | }, |
96 | | - { |
97 | | - "cell_type": "code", |
98 | | - "execution_count": 7, |
99 | | - "metadata": {}, |
100 | | - "outputs": [ |
101 | | - { |
102 | | - "data": { |
103 | | - "text/plain": [ |
104 | | - "(0.12156862745098039, 0.4666666666666667, 0.7058823529411765)" |
105 | | - ] |
106 | | - }, |
107 | | - "execution_count": 7, |
108 | | - "metadata": {}, |
109 | | - "output_type": "execute_result" |
110 | | - } |
111 | | - ], |
112 | | - "source": [] |
113 | | - }, |
114 | 99 | { |
115 | 100 | "cell_type": "code", |
116 | 101 | "execution_count": 4, |
|
228 | 213 | }, |
229 | 214 | { |
230 | 215 | "cell_type": "code", |
231 | | - "execution_count": 6, |
| 216 | + "execution_count": null, |
| 217 | + "metadata": {}, |
| 218 | + "outputs": [], |
| 219 | + "source": [ |
| 220 | + "d = 2 \n", |
| 221 | + "\n", |
| 222 | + "def sphere(x):\n", |
| 223 | + " x = np.asarray(x)\n", |
| 224 | + " return x.dot(x)\n", |
| 225 | + "\n", |
| 226 | + "def get_meshgrid(objective_function, lb, ub, delta: float = 0.025):\n", |
| 227 | + " x = np.arange(lb, ub + delta, delta)\n", |
| 228 | + " y = np.arange(lb, ub + delta, delta)\n", |
| 229 | + "\n", |
| 230 | + " if hasattr(objective_function, \"optimum\"):\n", |
| 231 | + " xo, yo = objective_function.optimum.x\n", |
| 232 | + " x = np.sort(np.r_[x, xo])\n", |
| 233 | + " y = np.sort(np.r_[y, yo])\n", |
| 234 | + "\n", |
| 235 | + " X, Y = np.meshgrid(x, y)\n", |
| 236 | + "\n", |
| 237 | + " Z = np.zeros(X.shape)\n", |
| 238 | + " for idx1 in range(X.shape[0]):\n", |
| 239 | + " for idx2 in range(X.shape[1]):\n", |
| 240 | + " Z[idx1, idx2] = objective_function([X[idx1, idx2], Y[idx1, idx2]])\n", |
| 241 | + " return X, Y, Z\n", |
| 242 | + "\n", |
| 243 | + "\n", |
| 244 | + "X, Y, Z = get_meshgrid(sphere, -5, 3)\n", |
| 245 | + "\n", |
| 246 | + "x0 = np.array([-4, -4])\n", |
| 247 | + "\n", |
| 248 | + "\n", |
| 249 | + "modules = parameters.Modules()\n", |
| 250 | + "modules.sample_transformation = options.SampleTranformerType(1)\n", |
| 251 | + "settings = parameters.Settings(dim=2, modules=modules, x0=x0, sigma0=2)\n", |
| 252 | + "\n", |
| 253 | + "utils.set_seed(10)\n", |
| 254 | + "cma = ModularCMAES(settings)\n", |
| 255 | + "\n", |
| 256 | + "\n", |
| 257 | + "f, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 5), sharex=True, sharey=True)\n", |
| 258 | + "\n", |
| 259 | + "ax1.contourf(\n", |
| 260 | + " X, Y, np.log10(Z), levels=200, cmap=\"Spectral\", zorder=-1, vmin=-1, vmax=2.5\n", |
| 261 | + ")\n", |
| 262 | + "ax2.contourf(\n", |
| 263 | + " X, Y, np.log10(Z), levels=200, cmap=\"Spectral\", zorder=-1, vmin=-1, vmax=2.5\n", |
| 264 | + ")\n", |
| 265 | + "\n", |
| 266 | + "for i in range(3):\n", |
| 267 | + " m = cma.p.adaptation.m.copy()\n", |
| 268 | + " C = cma.p.adaptation.C.copy()\n", |
| 269 | + " sigma = cma.p.mutation.sigma\n", |
| 270 | + " theta = np.degrees(np.arctan2(C[1, 0], C[0, 0]))\n", |
| 271 | + " \n", |
| 272 | + " color = 'black'\n", |
| 273 | + " p = ax1.scatter(*m, label=i, color=color)\n", |
| 274 | + " \n", |
| 275 | + " current = Ellipse(\n", |
| 276 | + " m,\n", |
| 277 | + " *(sigma * np.diag(C)),\n", |
| 278 | + " angle=theta,\n", |
| 279 | + " facecolor=\"none\",\n", |
| 280 | + " edgecolor=p.get_edgecolor(),\n", |
| 281 | + " linewidth=2,\n", |
| 282 | + " linestyle=\"dashed\",\n", |
| 283 | + " zorder=0,\n", |
| 284 | + " )\n", |
| 285 | + " ax1.add_patch(current)\n", |
| 286 | + " cma.step(sphere)\n", |
| 287 | + " \n", |
| 288 | + " \n", |
| 289 | + " \n", |
| 290 | + "modules = parameters.Modules()\n", |
| 291 | + "modules.sample_transformation = options.SampleTranformerType(2)\n", |
| 292 | + "settings = parameters.Settings(dim=2, modules=modules, x0=x0, sigma0=2)\n", |
| 293 | + "\n", |
| 294 | + "utils.set_seed(10)\n", |
| 295 | + "cma = ModularCMAES(settings)\n", |
| 296 | + "\n", |
| 297 | + "for i in range(3):\n", |
| 298 | + " m = cma.p.adaptation.m.copy()\n", |
| 299 | + " C = cma.p.adaptation.C.copy()\n", |
| 300 | + " sigma = cma.p.mutation.sigma\n", |
| 301 | + " theta = np.degrees(np.arctan2(C[1, 0], C[0, 0]))\n", |
| 302 | + " \n", |
| 303 | + " color = 'black'\n", |
| 304 | + " p = ax2.scatter(*m, label=i, color=color)\n", |
| 305 | + " \n", |
| 306 | + " \n", |
| 307 | + " width = sigma * C[0, 0]\n", |
| 308 | + " height = sigma * C[1, 1]\n", |
| 309 | + " \n", |
| 310 | + " current = Rectangle(\n", |
| 311 | + " (-width / 2, -height / 2), width, height,\n", |
| 312 | + " facecolor=\"none\",\n", |
| 313 | + " edgecolor=p.get_edgecolor(), \n", |
| 314 | + " linewidth=2,\n", |
| 315 | + " linestyle=\"dashed\",\n", |
| 316 | + " zorder=0, \n", |
| 317 | + " )\n", |
| 318 | + " transformation = (\n", |
| 319 | + " Affine2D()\n", |
| 320 | + " .rotate_deg(theta) \n", |
| 321 | + " .translate(*m)\n", |
| 322 | + " + ax2.transData \n", |
| 323 | + " )\n", |
| 324 | + " \n", |
| 325 | + " current.set_transform(transformation)\n", |
| 326 | + " \n", |
| 327 | + " ax2.add_patch(current)\n", |
| 328 | + " cma.step(sphere)\n", |
| 329 | + " \n", |
| 330 | + "\n", |
| 331 | + "for ax in ax1, ax2:\n", |
| 332 | + " ax.set_aspect(\"equal\")\n", |
| 333 | + " ax.set_ylim(-5, 0)\n", |
| 334 | + " ax.set_xticks([])\n", |
| 335 | + " ax.set_yticks([])\n", |
| 336 | + "plt.tight_layout()\n", |
| 337 | + "plt.savefig(\"figures/adaptation.pdf\")" |
| 338 | + ] |
| 339 | + }, |
| 340 | + { |
| 341 | + "cell_type": "code", |
| 342 | + "execution_count": null, |
232 | 343 | "metadata": {}, |
233 | 344 | "outputs": [], |
234 | 345 | "source": [ |
235 | | - "vector = np.array([0, -1])\n", |
| 346 | + "def get_one_plus_one(problem, dim, sampler):\n", |
| 347 | + " modules = parameters.Modules()\n", |
| 348 | + " modules.sample_transformation = options.SampleTranformerType(sampler)\n", |
| 349 | + " \n", |
| 350 | + " x0 = np.random.uniform(-5, 5, size=dim)\n", |
| 351 | + " return es.OnePlusOneES(\n", |
| 352 | + " dim,\n", |
| 353 | + " x0=x0,\n", |
| 354 | + " f0=problem(x0),\n", |
| 355 | + " sigma0=1,\n", |
| 356 | + " modules=modules,\n", |
| 357 | + " )\n", |
| 358 | + " \n", |
| 359 | + "names = sorted([options.SampleTranformerType(sampler).name.title().replace(\"Scaled_\", \"\").replace(\"Double_\", \"d\") \n", |
| 360 | + " for sampler in range(1, 7)], key=lambda x:x.lower())\n", |
| 361 | + "colors = dict(zip(names, sns.color_palette(\"tab10\")))\n", |
| 362 | + "\n", |
| 363 | + "n_evals = 10000\n", |
| 364 | + "n_runs = 1000\n", |
| 365 | + "\n", |
| 366 | + "linestyle = {\n", |
| 367 | + " 2: \"solid\",\n", |
| 368 | + " 10: \"dashed\",\n", |
| 369 | + " 50: \"dotted\"\n", |
| 370 | + "}\n", |
| 371 | + "\n", |
| 372 | + "np.random.seed(1)\n", |
| 373 | + "utils.set_seed(1)\n", |
| 374 | + "linewidth = 2\n", |
236 | 375 | "\n", |
237 | | - "\n" |
| 376 | + "f, ax = plt.subplots(figsize=(7, 4))\n", |
| 377 | + "for sampler in range(1, 7):\n", |
| 378 | + " for d in (2, 10, 50):\n", |
| 379 | + " f = np.zeros(n_evals)\n", |
| 380 | + " s = np.zeros(n_evals)\n", |
| 381 | + " for r in range(n_runs):\n", |
| 382 | + " alg = get_one_plus_one(sphere, d, sampler)\n", |
| 383 | + " for e in range(n_evals):\n", |
| 384 | + " f[e] += alg.f\n", |
| 385 | + " s[e] += alg.sigma\n", |
| 386 | + " alg.step(sphere)\n", |
| 387 | + " \n", |
| 388 | + " f /= n_runs\n", |
| 389 | + " s /= n_runs\n", |
| 390 | + " sampler_name = get_name(alg.sampler.__class__)\n", |
| 391 | + " ax.plot(s, color=colors[sampler_name], linestyle=linestyle[d], label=sampler_name, linewidth=linewidth)\n", |
| 392 | + "\n", |
| 393 | + "ax.legend()\n", |
| 394 | + "ax.set_yscale(\"log\")\n", |
| 395 | + "ax.set_xscale(\"log\")\n", |
| 396 | + "\n", |
| 397 | + "from matplotlib.lines import Line2D\n", |
| 398 | + "\n", |
| 399 | + "handles = [Line2D([0], [0], linestyle='')]\n", |
| 400 | + "labels = [\"$\\\\bf{Sampler}$\"]\n", |
| 401 | + "\n", |
| 402 | + "for alg, color in colors.items():\n", |
| 403 | + " handles.append(Line2D([0], [0], label=alg, color=color, linewidth=linewidth))\n", |
| 404 | + " labels.append(alg)\n", |
| 405 | + "\n", |
| 406 | + "\n", |
| 407 | + "handles.append(Line2D([0], [0], linestyle=''))\n", |
| 408 | + "labels.append(\"$\\\\mathbf{n}$\")\n", |
| 409 | + "\n", |
| 410 | + "for alg, color in linestyle.items():\n", |
| 411 | + " handles.append(Line2D([0], [0], label=alg, color='black', linestyle=color, linewidth=linewidth))\n", |
| 412 | + " labels.append(alg)\n", |
| 413 | + "\n", |
| 414 | + "handles.append(Line2D([0], [0], linestyle=''))\n", |
| 415 | + "labels.append(\"\") \n", |
| 416 | + "handles.append(Line2D([0], [0], linestyle=''))\n", |
| 417 | + "labels.append(\"\") \n", |
| 418 | + "\n", |
| 419 | + "ax.legend(handles, labels, loc='lower left', fancybox=True, shadow=True, fontsize=13, ncol=2)\n", |
| 420 | + "ax.set_ylim(1e-10, 10)\n", |
| 421 | + "ax.grid(which=\"both\", axis=\"both\")\n", |
| 422 | + "ax.set_ylabel(r\"$\\sigma$\")\n", |
| 423 | + "ax.set_xlabel(r\"Evaluations\")\n", |
| 424 | + "plt.tight_layout()\n", |
| 425 | + "plt.savefig(\"figures/1p1_sigma.pdf\")" |
| 426 | + ] |
| 427 | + }, |
| 428 | + { |
| 429 | + "cell_type": "code", |
| 430 | + "execution_count": null, |
| 431 | + "metadata": {}, |
| 432 | + "outputs": [], |
| 433 | + "source": [ |
| 434 | + "samplers = [\n", |
| 435 | + " lambda a: stats.cauchy(2.0, scale=1).rvs(size=a),\n", |
| 436 | + " lambda a: stats.dweibull(2.0, scale=1).rvs(size=a),\n", |
| 437 | + " lambda a: stats.norm().rvs(size=a),\n", |
| 438 | + " lambda a: stats.laplace().rvs(size=a),\n", |
| 439 | + " lambda a: stats.logistic().rvs(size=a),\n", |
| 440 | + " lambda a: stats.uniform().rvs(size=a),\n", |
| 441 | + "]\n", |
| 442 | + "\n", |
| 443 | + "labels=[\"Cauchy\", \"dWeibull\", \"Gaussian\", \"Laplace\", \"Logistic\", \"Uniform\"]\n", |
| 444 | + "\n", |
| 445 | + "def time_sampler(sampler, n = 1_000_000):\n", |
| 446 | + " start = perf_counter()\n", |
| 447 | + " sampler(n)\n", |
| 448 | + " return perf_counter() - start\n", |
| 449 | + "\n", |
| 450 | + "t = []\n", |
| 451 | + "for label, sampler in zip(labels, samplers):\n", |
| 452 | + " times = [time_sampler(sampler) for _ in range(1000)]\n", |
| 453 | + " t.append((label, np.mean(times), np.std(times)))\n", |
| 454 | + " \n", |
| 455 | + "time_data = pl.DataFrame(t, schema=['sampler', 'mean', 'std'], orient='row')" |
| 456 | + ] |
| 457 | + }, |
| 458 | + { |
| 459 | + "cell_type": "code", |
| 460 | + "execution_count": null, |
| 461 | + "metadata": {}, |
| 462 | + "outputs": [], |
| 463 | + "source": [ |
| 464 | + "plt.figure(figsize=(6, 3))\n", |
| 465 | + "p = plt.errorbar(\n", |
| 466 | + " time_data['sampler'], time_data['mean'], time_data['std'], \n", |
| 467 | + " marker='_', markersize=20, \n", |
| 468 | + " capsize=5,\n", |
| 469 | + " markeredgewidth=2, \n", |
| 470 | + " elinewidth=2, \n", |
| 471 | + " linestyle=''\n", |
| 472 | + ")\n", |
| 473 | + "\n", |
| 474 | + "\n", |
| 475 | + "plt.grid()\n", |
| 476 | + "plt.xticks(rotation=25);\n", |
| 477 | + "plt.ylabel(\"Time [s]\", color=p[0].get_color())\n", |
| 478 | + "plt.yticks(color=p[0].get_color())\n", |
| 479 | + "ax1 = plt.gca()\n", |
| 480 | + "ax2 = plt.twinx()\n", |
| 481 | + "g_time = time_data.filter(sampler='Gaussian')['mean'] \n", |
| 482 | + "\n", |
| 483 | + "\n", |
| 484 | + "ax2.plot(time_data['sampler'], (time_data['mean'] - g_time) / g_time, color='red', marker='x', linestyle='', markersize=7, markeredgewidth=2)\n", |
| 485 | + "\n", |
| 486 | + "ax2.set_ylabel(\"vs. Gaussian\", color='red')\n", |
| 487 | + "ax2.plot(time_data['sampler'], np.zeros(6), linestyle='dashed', zorder=-100, color='grey')\n", |
| 488 | + "plt.yticks(color='red')\n", |
| 489 | + "ax2.set_ylim(*((np.array(ax1.get_ylim()) - g_time[0]) / g_time[0]))\n", |
| 490 | + "plt.tight_layout()\n", |
| 491 | + "plt.savefig(\"figures/time.pdf\")" |
238 | 492 | ] |
239 | 493 | } |
240 | 494 | ], |
|
0 commit comments