import random

# Run parameters for the bit-string genetic algorithm.
MAX_GEN = 100                # generations per run
POP_SIZE = 50                # individuals per generation
IND_LEN = 25                 # genes (bits) per individual
MUT_FLIP_PROB = 1 / IND_LEN  # per-gene flip probability used by mutate()
CROSS_PROB = 0.8             # probability that a mating pair is recombined
MUT_PROB = 0.2               # probability that an individual is mutated

def selection(pop, fits, k=None):
    """Fitness-proportional (roulette-wheel) selection with replacement.

    Args:
        pop: the current population (list of individuals).
        fits: non-negative fitness of each individual, aligned with ``pop``.
        k: number of individuals to draw; defaults to ``POP_SIZE``.

    Returns:
        A list of ``k`` individuals drawn from ``pop``. The same individual
        may appear several times; entries are references, not copies.
    """
    if k is None:
        k = POP_SIZE
    # random.choices raises ValueError when the weights sum to zero (e.g. a
    # population made only of all-zero individuals); every individual is then
    # equally fit, so fall back to uniform selection.
    if sum(fits) <= 0:
        return random.choices(pop, k=k)
    return random.choices(pop, fits, k=k)

def fitness(ind):
    """Score an individual as the sum of its genes (the count of 1s for 0/1 genes)."""
    score = sum(ind)
    return score

def cross(p1, p2):
    """One-point crossover of two equal-length parents.

    The cut point is drawn uniformly from ``1 .. len(p1) - 1``, so each child
    receives at least one gene from each parent.

    Args:
        p1: first parent (list of genes).
        p2: second parent, expected to have the same length as ``p1``.

    Returns:
        ``(o1, o2)`` with ``o1 = p1[:cut] + p2[cut:]`` and
        ``o2 = p2[:cut] + p1[cut:]``. Both are new lists and the parents are
        left untouched. Parents shorter than two genes cannot be cut, so
        copies of them are returned.
    """
    # Use the real individual length rather than the global IND_LEN so the
    # operator stays correct for individuals of any size.
    length = len(p1)
    if length < 2:
        return p1[:], p2[:]
    point = random.randrange(1, length)
    return p1[:point] + p2[point:], p2[:point] + p1[point:]

def crossover(pop):
    """Apply pairwise one-point crossover to a mating pool.

    Consecutive individuals ``(pop[0], pop[1]), (pop[2], pop[3]), ...`` are
    paired. Each pair is recombined with ``cross`` with probability
    ``CROSS_PROB``; otherwise both parents are copied unchanged.

    Args:
        pop: the mating pool (list of individuals); not modified.

    Returns:
        A new list of offspring with the same length as ``pop``. When ``pop``
        has odd length, the unpaired last individual is copied through so the
        population does not shrink.
    """
    offspring = []
    for p1, p2 in zip(pop[::2], pop[1::2]):
        if random.random() < CROSS_PROB:
            offspring.extend(cross(p1, p2))
        else:
            offspring.extend([p1[:], p2[:]])
    if len(pop) % 2:
        # zip() silently drops the trailing unpaired parent; keep it.
        offspring.append(pop[-1][:])
    return offspring

def mutation(pop):
    """Mutate each individual with probability MUT_PROB; copy the rest unchanged."""
    result = []
    for individual in pop:
        if random.random() < MUT_PROB:
            result.append(mutate(individual))
        else:
            result.append(individual[:])
    return result

def mutate(ind, flip_prob=None):
    """Bit-flip mutation of a single individual.

    Each gene ``v`` is independently replaced by ``1 - v`` with probability
    ``flip_prob``; for 0/1 genes this flips the bit.

    Args:
        ind: the individual to mutate (list of 0/1 genes); not modified.
        flip_prob: per-gene flip probability; defaults to ``MUT_FLIP_PROB``.

    Returns:
        A new list holding the (possibly) mutated genes.
    """
    if flip_prob is None:
        flip_prob = MUT_FLIP_PROB
    return [1 - v if random.random() < flip_prob else v for v in ind]

def evolutionary_algorithm(pop, elitism=False):
    """Run a generational genetic algorithm for ``MAX_GEN`` generations.

    Each generation evaluates fitness, records the best value, and selects
    a fitness-proportional mating pool. It then applies crossover and
    mutation. With ``elitism`` the best individual of the current generation
    replaces the first offspring, so the best fitness never decreases from
    one generation to the next.

    Args:
        pop: the initial population (list of individuals); not modified.
        elitism: whether to carry the current best individual over unchanged.

    Returns:
        ``(pop, log)``: the population after the last generation, and a list
        of ``MAX_GEN`` best-fitness values, one per evaluated generation. The
        returned final population itself is not evaluated into ``log``.
    """
    log = []
    for _ in range(MAX_GEN):
        fits = [fitness(ind) for ind in pop]
        best_fit = max(fits)
        log.append(best_fit)
        mating = selection(pop, fits)
        offspring = mutation(crossover(mating))
        if elitism:
            # Reuse the fitness values computed above rather than re-evaluating
            # the whole population; fits.index picks the same (first) maximum
            # that max(pop, key=fitness) would. Copy to avoid aliasing.
            offspring[0] = pop[fits.index(best_fit)][:]
        pop = offspring
    return pop, log

def random_initial_population():
    """Create a fresh population of ``POP_SIZE`` independently random individuals."""
    return [random_individual() for _member in range(POP_SIZE)]

def random_individual():
    """Return ``IND_LEN`` genes, each independently 0 or 1 with equal probability."""
    # random() >= 0.5 is exactly the complement of the "< 0.5 -> 0" rule.
    return [int(random.random() >= 0.5) for _gene in range(IND_LEN)]

# Demonstrate one crossover followed by a mutation on two random parents.
p1, p2 = random_individual(), random_individual()
o1, o2 = cross(p1, p2)
o = mutate(o2)

for line in (f'{p1=}', f'{p2=}', f'{o1=}', f'{o2=}', f' {o=}'):
    print(line)


pop = random_initial_population()
init_fit = max(fitness(ind) for ind in pop)

# Compare the GA without and with elitism: 10 runs each, all starting from
# the same initial population so that the two settings are comparable.
logs = []
for _ in range(10):
    final_pop, log = evolutionary_algorithm(pop, elitism=False)
    logs.append(log)

logs_el = []
for _ in range(10):
    final_pop, log = evolutionary_algorithm(pop, elitism=True)
    logs_el.append(log)

# final_pop comes from the last run with elitism.
final_fit = max(fitness(ind) for ind in final_pop)

print(f'{init_fit=}')
print(f'{final_fit=}')

import matplotlib.pyplot as plt
import numpy as np


def _plot_band(runs, label):
    """Plot the mean best-fitness curve of ``runs`` with its 25-75 % band.

    ``runs`` is a list of equal-length per-run logs (one best-fitness value per
    generation); it is stacked into a 2-D array of shape (runs, generations).
    """
    runs = np.asarray(runs)
    gens = np.arange(runs.shape[1])
    line, = plt.plot(gens, runs.mean(axis=0), label=label)
    # Shade the inter-quartile range in the same colour as its mean curve.
    plt.fill_between(gens,
                     np.percentile(runs, 25, axis=0),
                     np.percentile(runs, 75, axis=0),
                     color=line.get_color(), alpha=0.5)


_plot_band(logs, 'no elitism')
_plot_band(logs_el, 'elitism')
plt.xlabel('generation')
plt.ylabel('best fitness')
plt.legend()
plt.show()