import sys,os
import numpy as np
import time

# Location of the JCMoptimizer installation; the value is a placeholder that
# must be replaced with the real path before running this script.
jcm_optimizer_path = r"<JCM_OPTIMIZER_PATH>"
# Make the bundled Python interface importable. This has to happen before the
# `from jcmoptimizer import ...` line below.
sys.path.insert(0, os.path.join(jcm_optimizer_path, "interface", "python"))
from jcmoptimizer import Server, Client, Study, Observation
# Create a Server instance and a Client connected to its host address.
# NOTE(review): presumably Server() launches/attaches to the optimization
# server process — confirm against the JCMoptimizer docs.
server = Server()
client = Client(server.host)

#Amplitude of a driven harmonic oscillator. 
def amplitude(omega: float, F: float, omega0: float, gamma: float) -> float:
    """Return the amplitude of a driven harmonic oscillator.

    Computes ``F / sqrt((2*omega*omega0*gamma)**2 + (omega0**2 - omega**2)**2)``.
    Only arithmetic and ``np.sqrt`` are involved, so NumPy arrays are accepted
    as well and are evaluated elementwise.

    Args:
        omega: Driving frequency.
        F: Driving strength; the amplitude scales linearly with it.
        omega0: Natural frequency of the oscillator.
        gamma: Damping parameter, entering through ``2*omega*omega0*gamma``.

    Returns:
        The oscillation amplitude.
    """
    damping = 2*omega*omega0*gamma
    detuning = omega0**2 - omega**2
    return F/np.sqrt(damping**2 + detuning**2)

# Search domain for tuning the oscillator: the driving strength F, the natural
# frequency omega0 and the damping gamma, each a continuous parameter.
design_space = [
    {'name': param_name, 'type': 'continuous', 'domain': bounds}
    for param_name, bounds in (
        ('F', (0.1, 40.0)),
        ('omega0', (1.0, 4.0)),
        ('gamma', (0.03, 0.3)),
    )
]

# The amplitude is scanned at 10 equidistant driving frequencies from 1 to 3.
omegas = np.linspace(start=1, stop=3, num=10)

# The system depends on the driving frequency, but that frequency is not a
# design parameter. It is therefore declared as an environment parameter that
# covers the scanned frequency range.
environment = [
    {'name': 'omega', 'type': 'variable', 'domain': (omegas[0], omegas[-1])},
]

# Create the study with study_id 'harmonic_oscillator_fit'. It uses the
# ActiveLearning driver on the design space and environment defined above.
study = client.create_study(
    study_id="harmonic_oscillator_fit",
    name="Optimization of resonant system based on Gaussian fit",
    driver="ActiveLearning",
    design_space=design_space,
    environment=environment,
)

# Configuration of the study.
study.configure(
    max_iter=30,
    surrogates=[
        # A single-output Gaussian process learns how the amplitude depends on
        # the design parameters and on the environment parameter omega.
        {"type": "GP", "name": "amplitude", "output_dim": 1},
    ],
    variables=[
        # Scan the "amplitude" Gaussian process over every driving frequency in
        # omegas, giving one output per frequency.
        {
            "type": "Scan",
            "name": "amplitude_scan",
            "output_dim": len(omegas),
            "scan_parameters": ["omega"],
            "scan_values": omegas.reshape(-1, 1).tolist(),
            "input_surrogate": "amplitude",
        },
        # Fit the scanned amplitudes to a Gaussian peak plus a linear background.
        # The fit parameters are peak height A, resonance frequency tau, peak
        # width sigma, linear gradient B and constant offset C. The outputs are
        # the fitted parameters followed by the mean-squared error (MSE) of the
        # fit. initial_parameters gives start values for A, B, C, tau, sigma
        # (presumably matched to the order of output_names — confirm in the
        # JCMoptimizer Fit documentation).
        {
            "type": "Fit",
            "name": "fit",
            "input": "amplitude_scan",
            "expression": "A*exp(-0.5*(omega-tau)^2/sigma^2) + B*omega + C",
            "output_names": ["A", "B", "C", "tau", "sigma", "MSE"],
            "output_dim": 6,
            "model_variables": ["omega"],
            "variable_values": omegas.reshape(-1, 1).tolist(),
            "initial_parameters": [1.0, 0.0, 0.0, 1.0, 0.5],
        },
        # Loss expression. It pulls the peak height towards 20, the resonance
        # towards omega = 2 and the width towards zero, with a small weight on
        # the fit error.
        {
            "type": "Expression",
            "name": "loss",
            "expression": "(A-20)^2 + (tau - 2)^2 + sigma^2 + 0.001*MSE",
        },
    ],
    objectives=[
        # The study's only objective is to minimize the loss.
        {"type": "Minimizer", "variable": "loss"},
    ],
    acquisition_optimizer={
        # Computing suggestions is expensive when fits are involved, so
        # suggestions are not computed in advance.
        "compute_suggestion_in_advance": False,
    },
)

# Evaluation of the black-box function for specified design parameters
def evaluate(study: Study, F: float, omega0: float, gamma: float) -> Observation:
    """Observe the oscillator amplitude of one design at every scanned frequency.

    A single observation is created. It receives one amplitude value per entry
    of the module-level ``omegas``. Each value is tagged with its driving
    frequency as environment value and is assigned to the model named
    ``"amplitude"``, so the values serve as training data for that Gaussian
    process.

    Args:
        study: The study for which the observation is created.
        F: Driving strength of the oscillator.
        omega0: Natural frequency of the oscillator.
        gamma: Damping parameter of the oscillator.

    Returns:
        The observation holding one amplitude value per driving frequency.
    """
    observation = study.new_observation()
    for drive_freq in omegas:
        value = amplitude(drive_freq, F, omega0, gamma)
        observation.add(value, environment_value=[drive_freq], model_name="amplitude")
    return observation

# Register `evaluate` as the study's evaluator, then run the minimization.
study.set_evaluator(evaluate)
study.run()

# Report the best design found and the minimal value of the loss expression.
best = study.get_state("driver.best_sample")
print(
    f"Best design parameters: F={best['F']:.3f}, omega0={best['omega0']:.3f}, "
    f"gamma={best['gamma']:.3f}"
)
print(
    "Objective: (A-20)^2 + (tau - 2)^2 + sigma^2 + 0.001*MSE = "
    f"{study.get_state('driver.min_objective'):.3f}"
)