Initial commit

This commit is contained in:
Daniel Lukats
2019-09-23 09:56:21 +02:00
commit 93e9234ce6
4 changed files with 231 additions and 0 deletions

110
.gitignore vendored Normal file
View File

@@ -0,0 +1,110 @@
# Created by https://www.gitignore.io/api/python
# Edit at https://www.gitignore.io/?templates=python
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# Mr Developer
.mr.developer.cfg
.project
.pydevproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# End of https://www.gitignore.io/api/python

38
functions.py Normal file
View File

@@ -0,0 +1,38 @@
import numpy as np
from dataclasses import dataclass
from typing import Callable, Tuple
@dataclass
class Function:
xlim: Tuple[float, float]
ylim: Tuple[float, float]
minimum: Tuple[float, float]
eval: Callable
def grid(self, x_dim=256, y_dim=256):
x = np.linspace(self.xlim[0], self.xlim[1], x_dim)
y = np.linspace(self.ylim[0], self.ylim[1], y_dim)
X, Y = np.meshgrid(x, y)
return x, y, self.eval(X, Y)
class Rastrigin(Function):
def __init__(self, A: int = 10):
super().__init__(
xlim=(-5.12, 5.12),
ylim=(-5.12, 5.12),
minimum=(0, 0),
eval=lambda x, y: self.A * 2 + \
(x**2 - self.A * np.cos(2 * np.pi * x)) + \
(y**2 - self.A * np.cos(2 * np.pi * y)))
self.A = A
class Sphere(Function):
def __init__(self):
super().__init__(
xlim=(-5, 5),
ylim=(-5, 5),
minimum=(0, 0),
eval=lambda x, y: x**2 + y**2)

48
solvers.py Normal file
View File

@@ -0,0 +1,48 @@
import numpy as np
from abc import ABC, abstractmethod
from functions import Function
from typing import Iterable, Tuple
class Solver(ABC):
def __init__(self, function: Function):
self.function = function
@abstractmethod
def rank(self, samples: Iterable[Iterable[float]], elite_size: int) -> (Iterable[Tuple], Iterable[Tuple], float):
pass
@abstractmethod
def sample(self, n: int) -> Iterable[float]:
pass
@abstractmethod
def update(elite: Iterable[Tuple]):
pass
class SimpleEvolutionStrategy(Solver):
def __init__(self, function: Function, mu: Iterable[float], sigma: Iterable[float] = None):
if sigma and len(mu) != len(sigma):
raise Exception('Length of mu and sigma must match')
super().__init__(function)
self.mu = mu
if sigma:
self.sigma = sigma
else:
self.sigma = [1] * len(mu)
def rank(self, samples: Iterable[Iterable[float]], elite_size: int) -> (Iterable[Tuple], Iterable[float]):
fitness = self.function.eval(*samples.transpose())
samples = samples[np.argsort(fitness)]
fitness = np.sort(fitness)
elite = samples[0:elite_size]
return elite, fitness[0]
def sample(self, n: int) -> Iterable[Iterable[float]]:
return np.array([np.random.multivariate_normal(self.mu, np.diag(self.sigma)) for _ in range(n)])
def update(self, elite: Iterable[Tuple]):
self.mu = elite[0]

35
test.py Normal file
View File

@@ -0,0 +1,35 @@
import functions
import numpy as np
import solvers
from matplotlib import pyplot as plt
plot_rest = True
f = functions.Rastrigin()
# f = functions.Sphere()
s = solvers.SimpleEvolutionStrategy(f, np.array([2, 2]))
old_fitness = 100
fitness = old_fitness * 0.9
old = None
plt.plasma()
while abs(old_fitness - fitness) > 0.001:
old_fitness = fitness
samples = s.sample(100)
elite, fitness = s.rank(samples, 10)
s.update(elite)
if plot_rest:
rest = np.setdiff1d(samples, elite, assume_unique=True)
rest = rest.reshape((int(rest.shape[0]/2), 2))
plt.pcolormesh(*f.grid())
if plot_rest:
plt.scatter(*rest.transpose(), color="dimgrey")
if old is not None:
plt.scatter(*old.transpose(), color="lightgrey")
plt.scatter(*elite.transpose(), color="yellow")
plt.scatter(*elite[0].transpose(), color="green")
plt.show()
old = elite