import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np

jcm_optimizer_path = r"<JCM_OPTIMIZER_PATH>"
sys.path.insert(0, os.path.join(jcm_optimizer_path, "interface", "python"))
from jcmoptimizer import Server, Client, Study, Observation
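
# Start a local optimization server and connect a client to it via the server's host address.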
server = Server()
client = Client(server.host)

# Definition of the search domain
design_space = [
    {'name': 'x1', 'type': 'continuous', 'domain': (-1.5, 1.5)},
    {'name': 'x2', 'type': 'continuous', 'domain': (-1.5, 1.5)},
]

# Definition of a fixed environment parameter
environment = [
    {'name': 'radius', 'type': 'fixed', 'domain': 1.5},
]

# Definition of a constraint on the search domain
constraints = [
    {'name': 'circle', 'expression': 'sqrt(x1^2 + x2^2) <= radius'}
]
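# The constraint keeps sampled points (x1, x2) inside a disc of radius `radius`
# around the origin; note that the expression references the fixed environment parameter.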

# Creation of studies to benchmark against each other
studies: dict[str, Study] = {}
drivers: list[str] = ["BayesianOptimization", "ActiveLearning", "CMAES", "DifferentialEvolution", "ScipyMinimizer"]
for driver in drivers:
    studies[driver] = client.create_study(
        design_space=design_space,
        environment=environment,
        constraints=constraints,
        driver=driver,
        name=driver,
        study_id=f"benchmark_{driver}",
        open_browser=False
    )
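# All studies share the same design space, environment, and constraints;
# only the optimization driver differs, so the benchmark results are directly comparable.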

# Configuration of all studies
config_kwargs = dict(max_iter=250, num_parallel=2)
min_val = 1e-3  # Stop a study once this value has been observed
for driver, study in studies.items():
    if driver == "ActiveLearning":
        study.configure(
            surrogates=[dict(type="NN")],
            objectives=[dict(type="Minimizer", min_val=min_val)],
            **config_kwargs
        )
    elif driver == "ScipyMinimizer":
        # For a more global search, 3 initial gradient-free Nelder-Mead runs are started
        study.configure(method="Nelder-Mead", num_initial=3,
                        min_val=min_val, **config_kwargs)
    else:
        study.configure(min_val=min_val, **config_kwargs)
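
# With this configuration, each study runs for at most max_iter=250 iterations and
# stops early once an observed objective value reaches min_val; num_parallel=2
# allows two evaluations to run concurrently.
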
# Evaluation of the black-box function for specified design parameters
def evaluate(study: Study, x1: float, x2: float, radius: float) -> Observation:
    time.sleep(2)  # make the objective expensive
    observation = study.new_observation()
    # 2D Rastrigin function 10*n + sum_i (x_i^2 - 10*cos(2*pi*x_i)) with n=2;
    # its global minimum is 0 at (x1, x2) = (0, 0).
    observation.add(
        10*2
        + (x1**2 - 10*np.cos(2*np.pi*x1))
        + (x2**2 - 10*np.cos(2*np.pi*x2))
    )
    return observation
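# The benchmark calls the evaluator with the study plus keyword arguments named after
# the design-space parameters (x1, x2) and the environment parameter (radius).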

# Creation of a benchmark with 6 repetitions to which all 5 studies are added
benchmark = client.create_benchmark(num_average=6)
for study in studies.values():
    benchmark.add_study(study)

# Run the benchmark - this will take a while
benchmark.set_evaluator(evaluate)
benchmark.run()

# Plot the convergence of the cumulative minimum w.r.t. number of evaluations and time
fig, (ax1, ax2) = plt.subplots(ncols=2, sharey=True, figsize=(10, 5))
plt.rc("font", family="serif")
plt.subplots_adjust(wspace=0.0)

data = benchmark.get_data(x_type="num_evaluations")
for X, Y, sdev in zip(data["X"], data["Y"], data["sdev"]):
    std_error = np.array(sdev) / np.sqrt(6)  # standard error of the mean over 6 repetitions
    p = ax1.plot(X, Y, linewidth=2.0)
    ax1.fill_between(X, Y - std_error, Y + std_error, alpha=0.2, color=p[0].get_color())
ax1.grid()
ax1.set_xlabel("Number of Evaluations", fontsize=12)
ax1.set_ylabel("Average Cumulative Minimum", fontsize=12)
ax1.set_ylim(-0.4, 10)

data = benchmark.get_data(x_type="time")
for name, X, Y, sdev in zip(data["names"], data["X"], data["Y"], data["sdev"]):
    std_error = np.array(sdev) / np.sqrt(6)
    p = ax2.plot(X, Y, linewidth=2.0, label=name)
    ax2.fill_between(X, Y - std_error, Y + std_error, alpha=0.2, color=p[0].get_color())
ax2.legend()
ax2.grid()
ax2.set_xlabel("Time (sec)", fontsize=12)
ax2.set_ylim(-0.4, 10)
plt.savefig("benchmark.svg", transparent=True)
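# plt.show()  # optionally display the figure interactively (requires an interactive matplotlib backend)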

