From 93e9234ce69ef9ba930971afbc2ee39639307ec8 Mon Sep 17 00:00:00 2001 From: Daniel Lukats Date: Mon, 23 Sep 2019 09:56:21 +0200 Subject: [PATCH] Initial commit --- .gitignore | 110 +++++++++++++++++++++++++++++++++++++++++++++++++++ functions.py | 38 ++++++++++++++++++ solvers.py | 48 ++++++++++++++++++++++ test.py | 35 ++++++++++++++++ 4 files changed, 231 insertions(+) create mode 100644 .gitignore create mode 100644 functions.py create mode 100644 solvers.py create mode 100644 test.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..aab7ea0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,110 @@ + +# Created by https://www.gitignore.io/api/python +# Edit at https://www.gitignore.io/?templates=python + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# End of https://www.gitignore.io/api/python diff --git a/functions.py b/functions.py new file mode 100644 index 0000000..d88e781 --- /dev/null +++ b/functions.py @@ -0,0 +1,38 @@ +import numpy as np + +from dataclasses import dataclass +from typing import Callable, Tuple + + +@dataclass +class Function: + xlim: Tuple[float, float] + ylim: Tuple[float, float] + minimum: Tuple[float, float] + eval: Callable + + def grid(self, x_dim=256, y_dim=256): + x = np.linspace(self.xlim[0], self.xlim[1], x_dim) + y = np.linspace(self.ylim[0], self.ylim[1], y_dim) + X, Y = np.meshgrid(x, y) + return x, y, self.eval(X, Y) + +class Rastrigin(Function): + def __init__(self, A: int = 10): + super().__init__( + xlim=(-5.12, 5.12), + ylim=(-5.12, 5.12), + minimum=(0, 0), + eval=lambda x, y: self.A * 2 + \ + (x**2 - self.A * np.cos(2 * np.pi * x)) + \ + (y**2 - self.A * np.cos(2 * np.pi * y))) + self.A = A + +class Sphere(Function): + def __init__(self): + super().__init__( + xlim=(-5, 5), + ylim=(-5, 5), + minimum=(0, 0), + eval=lambda x, y: x**2 + y**2) + diff --git a/solvers.py b/solvers.py new file mode 100644 index 0000000..2ebec2e --- /dev/null +++ b/solvers.py @@ -0,0 +1,48 @@ +import numpy as np + +from abc import ABC, abstractmethod +from functions import Function +from typing import Iterable, Tuple + + +class Solver(ABC): + def __init__(self, function: Function): + self.function = function + + @abstractmethod + def rank(self, samples: Iterable[Iterable[float]], elite_size: int) -> (Iterable[Tuple], Iterable[Tuple], float): + pass + + @abstractmethod + def sample(self, n: int) -> Iterable[float]: + pass + + @abstractmethod + def update(elite: Iterable[Tuple]): + pass + + +class SimpleEvolutionStrategy(Solver): + def __init__(self, function: Function, mu: Iterable[float], sigma: Iterable[float] = None): + if sigma and len(mu) != len(sigma): + raise Exception('Length of mu and sigma must match') + super().__init__(function) + self.mu = mu + if sigma: + self.sigma = sigma + else: + self.sigma = [1] * len(mu) + + def rank(self, samples: Iterable[Iterable[float]], elite_size: int) -> (Iterable[Tuple], Iterable[float]): + fitness = self.function.eval(*samples.transpose()) + samples = samples[np.argsort(fitness)] + fitness = np.sort(fitness) + elite = samples[0:elite_size] + return elite, fitness[0] + + def sample(self, n: int) -> Iterable[Iterable[float]]: + return np.array([np.random.multivariate_normal(self.mu, np.diag(self.sigma)) for _ in range(n)]) + + def update(self, elite: Iterable[Tuple]): + self.mu = elite[0] + diff --git a/test.py b/test.py new file mode 100644 index 0000000..be33ca8 --- /dev/null +++ b/test.py @@ -0,0 +1,35 @@ +import functions +import numpy as np +import solvers + +from matplotlib import pyplot as plt + + +plot_rest = True +f = functions.Rastrigin() +# f = functions.Sphere() +s = solvers.SimpleEvolutionStrategy(f, np.array([2, 2])) +old_fitness = 100 +fitness = old_fitness * 0.9 +old = None +plt.plasma() + +while abs(old_fitness - fitness) > 0.001: + old_fitness = fitness + samples = s.sample(100) + elite, fitness = s.rank(samples, 10) + s.update(elite) + + if plot_rest: + rest = np.setdiff1d(samples, elite, assume_unique=True) + rest = rest.reshape((int(rest.shape[0]/2), 2)) + plt.pcolormesh(*f.grid()) + if plot_rest: + plt.scatter(*rest.transpose(), color="dimgrey") + if old is not None: + plt.scatter(*old.transpose(), color="lightgrey") + plt.scatter(*elite.transpose(), color="yellow") + plt.scatter(*elite[0].transpose(), color="green") + plt.show() + old = elite +