import sys,os
import numpy as np
import time
import torch
import matplotlib.pyplot as plt


# Root directory of the JCM optimizer installation. The value looks like a
# placeholder that has to be replaced by the local installation path.
jcm_optimizer_path = r"<JCM_OPTIMIZER_PATH>"

# The Python interface is located below the installation root; put it first
# on the module search path so that the following import can be resolved.
python_interface_dir = os.path.join(jcm_optimizer_path, "interface", "python")
sys.path.insert(0, python_interface_dir)
from jcmoptimizer import Server, Client, Study, Observation

# Create the optimizer server object and a client bound to its host address.
server = Server()
client = Client(server.host)

# Search domain: five continuous design parameters b1..b5, each given with
# its (lower, upper) bounds.
design_space = [
    dict(name="b1", type="continuous", domain=(0, 10)),
    dict(name="b2", type="continuous", domain=(0.1, 4)),
    dict(name="b3", type="continuous", domain=(-4, -0.1)),
    dict(name="b4", type="continuous", domain=(0.05, 1)),
    dict(name="b5", type="continuous", domain=(0.05, 1)),
]

# Create the study object through the client. It uses the "ActiveLearning"
# driver on the search domain defined above and carries the study_id
# 'sensitivity_analysis'.
study = client.create_study(
    study_id="sensitivity_analysis",
    name="Variance-based sensitivity analysis for parameter reconstruction",
    driver="ActiveLearning",
    design_space=design_space,
)
# Vector-valued model function of the MGH17 problem.
def model(x: torch.Tensor) -> torch.Tensor:
    """Evaluate the MGH17 model at 33 integer sample positions.

    Computes ``x[0] + x[1]*exp(-s*x[3]) + x[2]*exp(-s*x[4])`` for the
    positions ``s = 0, 1, ..., 32``.

    Args:
        x: Tensor holding the five model parameters at indices 0 to 4.

    Returns:
        Tensor with one model value per sample position (33 entries for a
        one-dimensional ``x``).
    """
    offset, amp_first, amp_second = x[0], x[1], x[2]
    rate_first, rate_second = x[3], x[4]

    positions = torch.arange(33)
    first_term = amp_first * torch.exp(-positions * rate_first)
    second_term = amp_second * torch.exp(-positions * rate_second)
    return offset + first_term + second_term

# Multi-output Gaussian process that learns how the 33 entries of the model
# vector depend on the design parameters.
surrogate_specs = [
    dict(
        type="GP",
        name="model_vector",
        output_dim=33,
        correlate_outputs=False,
    ),
]

# Variable derived from the surrogate: the mean of the model vector.
variable_specs = [
    dict(
        type="LinearCombination",
        name="model_average",
        inputs=["model_vector"],
    ),
]

# The objective is to sample the model where the uncertainty of the model
# average is largest.
objective_specs = [
    dict(
        type="Explorer",
        variable="model_average",
        penalize_boundaries=True,
        min_uncertainty=1e-3,
    ),
]

study.configure(
    max_iter=100,
    surrogates=surrogate_specs,
    variables=variable_specs,
    objectives=objective_specs,
)

# Evaluator of the black-box function for one set of design parameters.
def evaluate(study: Study, b1: float, b2: float, b3: float, b4: float, b5: float) -> Observation:
    """Evaluate the model at one design point and return it as an observation.

    Args:
        study: Study used to create the new observation object.
        b1: Value of design parameter ``b1``.
        b2: Value of design parameter ``b2``.
        b3: Value of design parameter ``b3``.
        b4: Value of design parameter ``b4``.
        b5: Value of design parameter ``b5``.

    Returns:
        The observation to which the model vector has been added as a list.
    """
    # Artificial two-second delay that makes the objective expensive.
    time.sleep(2)

    result = study.new_observation()

    # Pack the design values into a tensor in the order the model indexes them.
    design_point = torch.tensor([b1, b2, b3, b4, b5])
    model_vector = model(design_point)
    result.add(model_vector.tolist())

    return result

# Run the active-learning study. The configured objective is an explorer,
# so the run samples the model rather than minimizing it.
study.set_evaluator(evaluate)
study.run()

# Ask the driver for the Sobol' indices of the "model_vector" surrogate.
sobol_indices = study.driver.get_sobol_indices(
    name="model_vector",
    object_type="surrogate",
    max_uncertainty=0.001,
)

# Convert the "variance" and "first_order" entries to tensors for the
# element-wise arithmetic used when plotting.
variances, sobol_values = (
    torch.tensor(sobol_indices[key]) for key in ("variance", "first_order")
)


# Two stacked panels sharing the x-axis: first-order Sobol' indices on top,
# scaled variances below. plt.subplots creates the figure itself, so no
# separate plt.figure call is made (it would only leave an empty figure open).
fig, (ax1, ax2) = plt.subplots(nrows=2, sharex=True, figsize=(10, 5))

for idx, info in enumerate(design_space):
    # One curve per design parameter.
    # NOTE(review): assumes sobol_values[idx] holds one value per model
    # vector entry for parameter idx -- confirm against the driver docs.
    ax1.plot(sobol_values[idx], ".-", label=info['name'])

    # Variance of a uniform distribution over the parameter's domain.
    lower, upper = info['domain']
    var_p = (upper - lower)**2 / 12

    # Output variance weighted by the parameter's first-order index and
    # normalized by the variance of the parameter itself.
    scaled_var = variances * sobol_values[idx] / var_p
    ax2.plot(scaled_var, ".-", label=info['name'])

ax1.set_ylabel("First-order Sobol' index")
ax1.legend()
ax1.grid()

ax2.set_xlabel("Model vector index")
ax2.set_ylabel("Scaled variance")
ax2.grid()

# Save through the figure handle so the written figure is unambiguous.
fig.savefig("variance_sobol.svg", transparent=True)

