import sys,os
import numpy as np
import time
import matplotlib.pyplot as plt

# Placeholder: replace with the path of the local JCMoptimizer installation.
jcm_optimizer_path = r"<JCM_OPTIMIZER_PATH>"
# Make the Python interface shipped with JCMoptimizer importable. This has to
# run before the jcmoptimizer import on the next line.
sys.path.insert(0, os.path.join(jcm_optimizer_path, "interface", "python"))
from jcmoptimizer import Server, Client, Study, Observation
# Create the optimization server and a client connected to its host.
# NOTE(review): Server() presumably launches a local server process as a
# module-level side effect; confirm against the JCMoptimizer documentation.
server = Server()
client = Client(server.host)

def Atrans(
        R1: float, R2: float, l: float, alpha: float, omega: float
) -> np.ndarray:
    """Return the intensity transmission through a lossy etalon.

    Every argument may be a Python float or a NumPy array. The arguments are
    combined with ordinary NumPy broadcasting. For example, array-valued
    mirror parameters can be evaluated on a 2-D grid of frequencies in one
    call.

    Args:
       R1: Reflectivity of first mirror
       R2: Reflectivity of second mirror
       l: resonator length
       alpha: Intensity loss coefficient
       omega: Angular frequency of light

    Returns:
        The transmission, with the broadcast shape of all inputs (a NumPy
        scalar when every input is a scalar).
    """
    # Intensity attenuation over the resonator length l.
    loss = np.exp(-alpha * l)
    # Geometric mean of the two mirror reflectivities.
    R = np.sqrt(R1 * R2)

    numerator = (1 - R1) * (1 - R2) * loss
    # The sin^2 term makes the transmission periodic in omega with period pi/l.
    denominator = (1 - R * loss) ** 2 + 4 * R * loss * np.sin(omega * l) ** 2
    # Divide out of place. An in-place `/=` fails whenever `numerator` is an
    # array whose shape is smaller than the broadcast shape of `denominator`.
    return numerator / denominator

# Design parameters explored by the study: the reflectivities of both
# mirrors and the resonator length.
design_space = [
    dict(name='R1', type='continuous', domain=(0.1, 0.7)),
    dict(name='R2', type='continuous', domain=(0.1, 0.7)),
    dict(name='l', type='continuous', domain=(0.5, 1.0)),
]

# Environment parameters: alpha is held at a fixed value, while omega is the
# variable over which each transmission spectrum is scanned.
environment = [
    dict(name='alpha', type='fixed', domain=0.05),
    dict(name='omega', type='variable', domain=(0, 2*np.pi)),
]
# Angular-frequency grid that defines one sampled transmission spectrum.
omegas = np.linspace(0, 2*np.pi, 50)

# Create the study on the optimization server. It is identified by the
# study_id "active_surrogate_training" and uses the ActiveLearning driver over
# the design space and environment defined above.
study = client.create_study(
    study_id="active_surrogate_training",
    name="Active learning of a global surrogate model",
    driver="ActiveLearning",
    design_space=design_space,
    environment=environment,
)

study.configure(
    max_iter=50,
    surrogates=[
        # Ensemble of 60 fully connected networks, each with 4 hidden layers
        # of 200 neurons, learning the scalar function Atrans(R1, R2, l, omega).
        {
            "type": "NN",
            "name": "Atrans",
            "output_dim": 1,
            "hidden_layers_arch": [200] * 4,
            "num_NNs": 60,
            "optimization_step_max": -1,
            "trainer": {
                "type": "full_data_trainer",
                "num_epochs": 1000,
                "num_expel_NNs": 30,
            },
        }
    ],
    variables=[
        # Scan of the surrogate prediction over every value of the omega grid.
        {
            "type": "Scan",
            "name": "omega_scan",
            "input_surrogate": "Atrans",
            "output_dim": len(omegas),
            "scan_parameters": ["omega"],
            "scan_values": omegas[:, None].tolist(),
        },
        # Average transmission of the omega scan.
        {"type": "LinearCombination", "name": "average", "inputs": ["omega_scan"]},
    ],
    objectives=[
        # Explore the design space where the average transmission is most
        # uncertain, i.e. evaluate the model function there next.
        {
            "type": "Explorer",
            "name": "objective",
            "variable": "average",
        }
    ]
)

# Black-box evaluation of one design: a complete transmission spectrum.
def evaluate(study: Study, R1: float, R2: float, l: float, alpha: float) -> Observation:
    """Build an observation holding the transmission spectrum of one design.

    For every angular frequency in the module-level ``omegas`` grid, the
    analytic transmission is added to the observation under the model name
    ``"Atrans"``, with that frequency as the environment value.

    Args:
        study: Study object used to create the new observation.
        R1: Reflectivity of the first mirror.
        R2: Reflectivity of the second mirror.
        l: Resonator length.
        alpha: Intensity loss coefficient.

    Returns:
        The observation containing one entry per frequency in ``omegas``.
    """
    # Deliberate delay so that the cheap analytic model behaves like an
    # expensive black-box evaluation.
    time.sleep(2)
    obs = study.new_observation()
    for w in omegas:
        transmission = Atrans(R1, R2, l, alpha, w)
        obs.add(transmission, environment_value=[w], model_name="Atrans")
    return obs

# Run the training loop: register `evaluate` as the study's evaluator, then
# start the study, which calls it with the design and environment values.
study.set_evaluator(evaluate)
study.run()



# Reconfigure the surrogate for the final, more accurate predictions. It uses
# the same architecture and ensemble size as during active learning, but each
# network is now trained on all data for 1500 epochs.
# NOTE(review): unlike the first configure() call, this one omits
# optimization_step_max, variables and objectives. Confirm whether configure()
# keeps or resets the earlier settings.
study.configure(
    surrogates=[
        {
            "type": "NN",
            "name": "Atrans",
            "output_dim": 1,
            "hidden_layers_arch": [200] * 4,
            "num_NNs": 60,
            "trainer": {
                "type": "full_data_trainer",
                "num_epochs": 1500,
                "num_expel_NNs": 30,
            },
        },
    ],
)

# Compare surrogate predictions with analytic values on an omega scan that is
# finer than the training grid.
omegas_fine = np.linspace(0, 2 * np.pi, 150)

# Take the fixed loss coefficient from the study's environment definition
# instead of repeating the literal, so the analytic reference cannot drift
# from the value used during training.
alpha_fixed = next(
    param["domain"] for param in environment if param["name"] == "alpha"
)

# To test the worst-case prediction, request a suggestion (presumably the
# sample with the largest uncertainty) and withdraw it right away so it is
# never evaluated.
s = study.get_suggestion()
study.clear_suggestion(s.id)

plt.figure(figsize=(10, 5))
for R1, R2, l in [
    (s.kwargs["R1"], s.kwargs["R2"], s.kwargs["l"]),
    (0.1, 0.1, 0.5),
    (0.1, 0.7, 0.75),
    (0.7, 0.7, 1.0),
]:
    prediction = study.driver.predict(
        points=[[R1, R2, l, omega] for omega in omegas_fine],
        object_type="surrogate",
        name="Atrans",
    )
    mean = np.array(prediction["mean"]).squeeze()
    std = np.sqrt(np.array(prediction["variance"])).squeeze()

    label = f"R1={float(R1):.2f}, R2={float(R2):.2f}, l={float(l):.2f}"
    (line,) = plt.plot(omegas_fine, mean, label=label)
    color = line.get_color()
    # One-standard-deviation band around the surrogate mean.
    plt.fill_between(
        omegas_fine, mean - std, mean + std, alpha=0.2, color=color,
    )
    # Atrans broadcasts over omega, so the analytic curve is a single call.
    plt.plot(
        omegas_fine,
        Atrans(R1, R2, l, alpha_fixed, omegas_fine),
        "--",
        color=color,
    )

plt.title("Solid: surrogate mean \u00b1 std, dashed: analytic transmission")
plt.xlabel("Angular frequency")
plt.ylabel("Transmission")
plt.grid()
plt.legend()
plt.savefig("etalon_predictions.svg", transparent=True)
# Release the figure once it has been written to disk.
plt.close()
