import sys,os
import numpy as np
import time
import matplotlib.pyplot as plt

# Placeholder: replace with the directory of the local JCMoptimizer installation.
jcm_optimizer_path = r"<JCM_OPTIMIZER_PATH>"
# Make the bundled Python interface importable; this must run before the
# jcmoptimizer import on the next line.
sys.path.insert(0, os.path.join(jcm_optimizer_path, "interface", "python"))
from jcmoptimizer import Server, Client, Study, Observation
# Create a Server (presumably starts a local optimization server -- confirm)
# and connect a Client to its host.
server = Server()
client = Client(server.host)

#Rastrigin-like function depending on additional phase offset phi 
def rast(x1: float, x2: float, phi: float) -> float:
    """Return the 2D Rastrigin function with a phase offset on the x1 term.

    f(x1, x2, phi) = 20 + x1**2 + x2**2
                     - 10*cos(2*pi*x1 + phi) - 10*cos(2*pi*x2)

    For phi == 0 this is the standard Rastrigin function with its global
    minimum f(0, 0) = 0. A non-zero phi shifts the cosine valleys along x1.
    Works element-wise on NumPy arrays as well as on scalars.
    """
    total = 10*2 + x1**2 + x2**2
    total = total - 10*np.cos(2*np.pi*x1 + phi)
    total = total - 10*np.cos(2*np.pi*x2)
    return total

#time-dependent slowly varying phi
def current_phi() -> float:
    """Return the environment phase phi for the current wall-clock time.

    phi = 2*pi*sin(t/180) with t = time.time() in seconds, so phi stays in
    [-2*pi, 2*pi]. It oscillates with a period of 2*pi*180 s (about 19 min).
    This is slow compared with a single evaluation, but it changes noticeably
    between successive iterations.
    """
    now = time.time()
    return 2*np.pi*np.sin(now/180)

# Definition of the search domain
# Both design parameters share the same continuous interval.
design_space = [
    {'name': param, 'type': 'continuous', 'domain': (-1.5, 1.5)}
    for param in ('x1', 'x2')
]

# Definition of the environment variable "phi". Its domain [-2*pi, 2*pi]
# matches the range of values produced by current_phi().
environment = [
    dict(name='phi', type='variable', domain=(-2*np.pi, 2*np.pi)),
]

# Creation of the study object with study_id 'changing_environment'
study = client.create_study(
    study_id="changing_environment",
    name="Optimal control of a system in a changing environment",
    driver="ActiveLearning",
    design_space=design_space,
    environment=environment,
)

# In the initial training phase, the goal is to explore the
# parameter space to find the global minimum.
study.configure(
    # train with 500 data points
    max_iter=500,
    # Computing suggestions in advance is switched off because the
    # environment parameter phi can change significantly between the
    # computation of a suggestion and the evaluation of the objective.
    acquisition_optimizer=dict(compute_suggestion_in_advance=False),
)

# Evaluation of the black-box function for specified design parameters
def evaluate(study: Study, x1: float, x2: float) -> Observation:
    """Evaluate the offset Rastrigin function at (x1, x2) as a study observation.

    The function sleeps for 2 s to mimic an expensive objective, and only
    then samples the environment value phi. As a result, phi at evaluation
    time generally differs from phi at the moment the suggestion was
    requested. The phi actually used is stored as the observation's
    environment value.
    """
    time.sleep(2)  # make objective expensive
    obs = study.new_observation()
    phi_now = current_phi()  # sampled after the delay, on purpose
    value = rast(x1, x2, phi_now)
    obs.add(value, environment_value=[phi_now])
    return obs

# Run the minimization
study.set_evaluator(evaluate)
study.run()

# In the control phase the goal is to evaluate the offset Rastrigin function
# only at well-performing (x1, x2) points for the current environment value.
MAX_ITER = 500  # number of additional control iterations
study.configure(
    # presumably max_iter counts all iterations, so the 500 training
    # iterations are added on top -- confirm against JCMoptimizer docs
    max_iter=500 + MAX_ITER,
    # A small scaling penalizes parameters with large uncertainty.
    scaling=0.01,
    # Use the lower-confidence bound (LCB) strategy instead of the default
    # expected improvement (EI). LCB is easier to maximize at the risk of
    # exploring the parameter space less, which is not wanted in the
    # control phase anyway.
    objectives=[
        dict(type='Minimizer', name='objective', strategy='LCB'),
    ],
    acquisition_optimizer=dict(compute_suggestion_in_advance=False),
)


#keep track of suggested design points and phis at request time and evaluation time
# Per control iteration: the suggested (x1, x2), phi when the suggestion was
# requested, and phi when the objective was actually evaluated.
design_points: list[tuple[float, float]] = []
phis_at_request: list[float] = []
phis_at_eval: list[float] = []

# Observations whose predicted standard deviation is at or below this value
# carry little new information and are not added to the study.
STD_THRESHOLD = 0.01

for _ in range(MAX_ITER):
    if study.is_done():
        break

    phi = current_phi()
    suggestion = study.get_suggestion(environment_value=[phi])
    phis_at_request.append(phi)
    kwargs = suggestion.kwargs
    design_points.append((kwargs["x1"], kwargs["x2"]))

    # Only the evaluation and the model query can fail before the suggestion
    # is resolved. Keeping add/clear outside the try means a suggestion is
    # never resolved twice.
    try:
        obs = evaluate(study=study, **kwargs)
        # evaluate() samples phi again after its delay; read back the value
        # that was actually used from the observation.
        phi = obs.data[None][0]["env"][0]
        phis_at_eval.append(phi)

        predictions = study.driver.predict(
            points=[(kwargs["x1"], kwargs["x2"], phi)]
        )
        std = np.sqrt(predictions["variance"][0][0])
    except Exception as err:
        study.clear_suggestion(
            suggestion.id, f"Evaluator function failed with error: {err}"
        )
        raise

    print(f"Uncertainty of prediction {std}")
    # add data only if prediction has significant uncertainty
    if std > STD_THRESHOLD:
        study.add_observation(obs, suggestion.id)
    else:
        study.clear_suggestion(
            suggestion.id, f"Ignoring observation with uncertainty {std}"
        )


fig = plt.figure(figsize=(10, 5))

# Left panel: observed values reported by the study driver (training
# observations plus the control-phase observations that were added).
observed = study.driver.get_observed_values()
plt.subplot(1, 2, 1)
plt.plot(observed["means"], ".")
# dashed line marks the end of the 500-iteration training phase
plt.axvline(x=500, ls='--', color='gray')
plt.xlabel("training+control iteration")
plt.ylabel("observed value of Rastrigin function")

# Control phase: values actually observed, i.e. using phi at evaluation time.
observed_vals = [
    rast(x1, x2, phi_eval)
    for (x1, x2), phi_eval in zip(design_points, phis_at_eval)
]

# Values that would have been observed with no delay between requesting and
# evaluating a suggestion, i.e. using phi at request time.
observed_vals_at_request = [
    rast(x1, x2, phi_req)
    for (x1, x2), phi_req in zip(design_points, phis_at_request)
]

#best value of x1-parameter depending on environment
def best_x1(phi: float) -> float:
    """Return the approximate best x1 for a given environment value phi.

    The term -10*cos(2*pi*x1 + phi) is minimal at x1 = -phi/(2*pi) + k for
    integer k. For phi in [-2*pi, 2*pi] the candidate closest to zero (which
    keeps the x1**2 term small) is obtained by shifting by sign(phi) when
    |phi| > pi. The small pull of the x1**2 term towards zero is ignored.
    """
    shift = 0.0
    if np.abs(phi) > np.pi:
        shift = np.sign(phi)
    return -phi/(2*np.pi) + shift

# Approximately best achievable value for each phi seen at evaluation time.
# With x2 = 0 and x1 = best_x1(phi), both cosine terms equal -10, so the
# value reduces to best_x1(phi)**2.
best_vals = [rast(best_x1(phi), 0, phi) for phi in phis_at_eval]

# Right panel: control-phase values with and without the evaluation delay,
# compared with the best achievable value.
plt.subplot(1, 2, 2)
plt.plot(observed_vals, ".", label="observed values")
plt.plot(observed_vals_at_request, ".", label="observed values if no time delay")
plt.plot(best_vals, label="smallest possible values")
plt.ylim(1e-4, 1e1)
plt.yscale("log")
plt.xlabel("control iteration")
# label the y axis like the left panel so both panels are self-describing
plt.ylabel("observed value of Rastrigin function")
plt.legend()
# avoid the new y label overlapping the left panel's tick labels
plt.tight_layout()
plt.savefig("training_and_control.svg", transparent=True)
