import sys,os
import numpy as np
import time
import numpy as np
import matplotlib.pyplot as plt


# Make the JCMoptimizer Python interface importable. Replace the placeholder
# below with the actual JCMoptimizer installation directory.
jcm_optimizer_path = r"<JCM_OPTIMIZER_PATH>"
python_interface_dir = os.path.join(jcm_optimizer_path, "interface", "python")
sys.path.insert(0, python_interface_dir)
from jcmoptimizer import Server, Client, Study, Observation

# Create the optimization server and a client talking to its host
server = Server()
client = Client(server.host)

# Search domain: two continuous design parameters with box bounds
design_space = [
    dict(name='x1', type='continuous', domain=(1, 4)),
    dict(name='x2', type='continuous', domain=(-1.5, 1.5)),
]

# Create the study (study_id 'multi_objective_optimization') driven by the
# "ActiveLearning" driver over the design space above
study = client.create_study(
    design_space=design_space,
    driver="ActiveLearning",
    name="Multi-objective optimization",
    study_id="multi_objective_optimization",
)

# Lower and upper reference points (f1, f2) that bound the hypervolume of
# the multi-objective minimization; both are also drawn in the final figure.
lower_ref = [0, 0]
upper_ref = [5, 50]

# Surrogate: one Gaussian process with two outputs that learns the vectorial
# model [f1, f2] as a function of the design parameters; the outputs are not
# correlated (correlate_outputs=False).
mo_surrogates = [
    dict(type="GP", name="model_vector", output_dim=2,
         correlate_outputs=False),
]

# Variables: one selector per objective, picking output component
# "model_vector<i>" of the GP and exposing it under the name f1 / f2.
mo_variables = [
    dict(type="SingleSelector", name=var_name,
         input="model_vector", select_by_name=f"model_vector{idx}")
    for idx, var_name in enumerate(("f1", "f2"))
]

# Objective: joint minimization of f1 and f2, with the hypervolume bounded
# by the lower and upper reference points.
mo_objectives = [
    dict(type="MultiMinimizer", variables=["f1", "f2"], name="objective",
         lower_reference=lower_ref, upper_reference=upper_ref),
]

study.configure(
    max_iter=50,
    surrogates=mo_surrogates,
    variables=mo_variables,
    objectives=mo_objectives,
)

# Evaluation of the black-box function for specified design parameters
def evaluate(study: Study, x1: float, x2: float) -> Observation:
    """Evaluate the two-objective test function at the design point (x1, x2).

    After an artificial 2 s delay (to mimic an expensive black box), the
    model vector [f1, f2] is recorded in a new observation of ``study``:

        f1 = x1
        f2 = g(x2) / x1,   g(x2) = 20 + x2**2 - 10*cos(2*pi*x2)

    Since g(x2) >= 10 with equality at x2 = 0, the exact Pareto front is
    f2 = 10 / f1.

    Args:
        study: Study that creates the observation.
        x1: First design parameter (divisor of f2).
        x2: Second design parameter (enters only through g).

    Returns:
        The observation holding [f1, f2].
    """
    time.sleep(2)  # stand-in for an expensive simulation

    # Rastrigin-type term in x2; keep the exact expression for identical floats
    g = 20 + (x2**2 - 10*np.cos(2*np.pi*x2))
    objective_values = [x1, g/x1]

    obs = study.new_observation()
    obs.add(objective_values)
    return obs

# Run the minimization using `evaluate` as the black-box evaluator
study.set_evaluator(evaluate)
study.run()

plt.figure(figsize=(10,5))

# Plot all observations of (f1, f2). The Pareto-front points are drawn on top
# of them afterwards, so the visible remainder are the dominated observations.
f1 = study.driver.get_observed_values("variable", "f1")
f2 = study.driver.get_observed_values("variable", "f2")
plt.plot(f1["means"], f2["means"], "o", label="Dominated observations")

# Plot the estimate of the Pareto front, i.e. the nondominated observations.
# Sort the front by ascending f1: the hypervolume staircase built below is
# only correct for that ordering, and the order in which the driver reports
# the front is not established here.
pf = study.driver.get_state("pareto_front")
nondom_f1, nondom_f2 = np.array(pf["f1"]), np.array(pf["f2"])
order = np.argsort(nondom_f1)
nondom_f1, nondom_f2 = nondom_f1[order], nondom_f2[order]
plt.plot(nondom_f1, nondom_f2, "o", label="Pareto front estimate")

# Plot the analytic Pareto front f2 = 10/f1 (attained at x2 = 0, where the
# g-term of `evaluate` takes its minimum 10) over the x1 domain [1, 4]
X = np.linspace(1, 4, 100)
plt.plot(X, 10/X, "k--", label="Analytic Pareto front")

# Plot the nondominated hypervolume: the region above the lower reference
# that lies below the staircase through the Pareto front estimate.
plt.plot(*lower_ref, "rx", label="Lower reference")
plt.plot(*upper_ref, "bx", label="Upper reference")
# Staircase outline: height upper_ref[1] from lower_ref[0] to the first front
# point, then at each front point's f1 step down to its f2, and continue at
# the last f2 up to upper_ref[0]. Each f1 is repeated twice to form the steps.
nondom_f2_aug = np.hstack((upper_ref[1], nondom_f2))
X = [lower_ref[0]] + np.repeat(nondom_f1,2).tolist() + [upper_ref[0]]
nondom_lower = [lower_ref[1]]*len(X)
nondom_upper = np.repeat(nondom_f2_aug, 2).tolist()
plt.fill_between(X, nondom_lower, nondom_upper, alpha=0.2, color="blue",
                 label="Nondominated hypervolume")

plt.xlabel("f1")
plt.ylabel("f2")
plt.legend(loc="upper center")
plt.grid()
plt.savefig("multi_objective_minimization.svg", transparent=True)
