import sys,os
import numpy as np
import time
import torch
import numpy as np
import scipy
import requests
from dispersion import Spectrum, Material # run pip install dispersion if not installed
import tmm_fast # run pip install tmm-fast if not installed
import matplotlib.pyplot as plt
from matplotlib import patches

# Use double precision as the default dtype for floating-point tensors
# created further down (wavelengths, thicknesses, ...).
torch.set_default_dtype(torch.float64)


from jcmoptimizer import Client, Study, Observation
# Client object used below to create the study; its server is shut down at the
# very end of the script via client.shutdown_server().
# NOTE(review): Client() is called without arguments - presumably this starts
# or connects to a local optimizer server; confirm against the jcmoptimizer docs.
client = Client()


K = 50 # number of spectral sampling points
# Split the K samples into a short- and a long-wavelength range. Using floor
# division plus the remainder guarantees K_short + K_long == K for odd K as
# well, which the index ranges (0, K_short-1) and (K_short, K-1) rely on.
K_short = K // 2 # samples in the short wavelength range
K_long = K - K_short # samples in the long wavelength range
wl = torch.linspace(800, 1800, K) * 1e-9 # wavelength scan, 800-1800 nm in meters
N = 4 # number N in layer stack S(AS)^N
D = 2*N + 1 # dimensionality of search space: N+1 SiO2 and N Ag thicknesses
theta = torch.tensor([0.0]) # angle of incidence
num_layers = D + 2 # total number of layers including the two air layers
num_stacks = 1 # only one stack is simulated in each iteration

# Load material data from refractiveindex.info. Each file is downloaded only
# if it is not already present in the working directory.
material_sources = {
    "SiO2.yml": "https://refractiveindex.info/database/data/main/SiO2/nk/Malitson.yml",
    "Ag.yml": "https://refractiveindex.info/database/data/main/Ag/nk/Johnson.yml",
}
for material_file, material_url in material_sources.items():
    if os.path.exists(material_file):
        continue
    # A timeout keeps the script from hanging forever on a stalled connection.
    response = requests.get(material_url, timeout=30)
    # Fail loudly on HTTP errors: otherwise an error page would be stored as
    # the material file and, because of the existence check above, would
    # never be downloaded again.
    response.raise_for_status()
    # The context manager closes (and thereby flushes) the file deterministically.
    with open(material_file, 'wb') as material_fh:
        material_fh.write(response.content)
    
# Refractive indices of the stack S(AS)^N, one complex value per
# (stack, layer, wavelength sample); every entry starts out as 1+0j.
stack_shape = (num_stacks, num_layers, wl.shape[0])
M = torch.full(stack_shape, 1+0j)
# The first and the last layer are air.
M[0,0,:] = 1.0
M[0,-1,:] = 1.0
# Material models read from the refractiveindex.info files, and the spectrum
# (wavelengths in meters) at which they are evaluated.
SiO2 = Material(file_path="SiO2.yml", unit="um")
Ag = Material(file_path="Ag.yml", unit="um")
spm = Spectrum(wl.numpy(), spectrum_type='Wavelength', unit='meter')
# n-k data per wavelength; odd inner layer indices are SiO2, even ones are Ag.
nk_sio2 = torch.tensor(SiO2.get_nk_data(spm))
nk_ag = torch.tensor(Ag.get_nk_data(spm))
M[0,1:-1:2,:] = nk_sio2
M[0,2:-1:2,:] = nk_ag

# Design space (units nm): N+1 SiO2 thicknesses followed by N Ag thicknesses.
design_space_SiO2 = [
    dict(name=f'sio2_{layer}', domain=(100, 500)) for layer in range(1, N + 2)
]
design_space_Ag = [
    dict(name=f'ag_{layer}', domain=(2, 8)) for layer in range(1, N + 1)
]
design_space = [*design_space_SiO2, *design_space_Ag]
# Parameter names in design-space order; used to map keyword arguments and
# derivative columns to parameters.
param_names = [entry['name'] for entry in design_space]
# Create the study object with study_id 'physics_informed_bayesian_optimization'.
study = client.create_study(
    study_id="physics_informed_bayesian_optimization",
    study_name="Physics-informed Bayesian optimization",
    driver="ActiveLearning",
    design_space=design_space,
)
# To speed things up, we run a single hyperparameter optimization after
# 5 iterations (= 5*(D+1) observations of function values and derivatives).
hyper_opt_interval = 5*(D+1)

# Gaussian-process surrogate "model" with one output per wavelength sample.
gp_surrogate = {
    "type": "GP",
    "name": "model",
    "output_dim": K,
    "max_optimization_interval": hyper_opt_interval,
    "optimization_step_max": hyper_opt_interval,
    "covariance_matrix": {"max_data_hyper_derivs": hyper_opt_interval},
}

derived_variables = [
    # Selector of transmissions in the short wavelength range
    {"type": "MultiSelector", "name": "trans_short", "input": "model",
     "select_range": (0, K_short-1), "output_dim": K_short},
    # Selector of transmissions in the long wavelength range
    {"type": "MultiSelector", "name": "trans_long", "input": "model",
     "select_range": (K_short, K-1), "output_dim": K_long},
    # Mean transmission in the short wavelength range
    {"type": "LinearCombination", "name": "avg_short", "inputs": ["trans_short"]},
    # Mean transmission in the long wavelength range
    {"type": "LinearCombination", "name": "avg_long", "inputs": ["trans_long"]},
    # Chi^2 = sum of squared transmissions sum(t_i^2) in the short wavelength range
    {"type": "ChiSquaredValue", "name": "chi_squared_short", "input": "trans_short"},
    # Loss: square root of (chi_squared_short/K_short - avg_short^2), clamped
    # at zero, plus avg_long. With avg_short being the mean this is the
    # standard deviation over the short range plus the long-range average.
    {"type": "Expression", "name": "loss",
     "expression": f"sqrt(max(0, chi_squared_short/{K_short} - avg_short^2)) + avg_long"},
]

study_objectives = [
    # Minimize the loss ...
    {"type": "Minimizer", "variable": "loss", "name": "objective"},
    # ... while keeping the short-range average at or above 0.4.
    {"type": "Constrainer", "variable": "avg_short", "lower_bound": 0.4,
     "name": "min_transmission_short"},
]

study.configure(
    max_iter=30,
    surrogates=[gp_surrogate],
    variables=derived_variables,
    objectives=study_objectives,
    acquisition_optimizer={
        "num_training_samples": 5,
        "compute_suggestion_in_advance": False,
    },
)
def transmissions(D: torch.Tensor) -> torch.Tensor:
    """Return the transmission spectrum of the stack for thickness tensor D.

    Args:
        D: 1-D tensor holding the N+1 SiO2 thicknesses followed by the
            N Ag thicknesses. Values are scaled by 1e-9 before they are
            handed to the solver (the design space is given in nm).

    Returns:
        The ``"T"`` entry of the ``tmm_fast.coh_tmm`` result at index
        ``[0, 0, :]``, i.e. one value per wavelength sample in ``wl``.
    """
    sio2_count = N + 1
    thicknesses = torch.zeros((1, num_layers))
    # Interleave the materials: SiO2 on odd, Ag on even inner layer indices.
    thicknesses[0, 1:-1:2] = D[:sio2_count]
    thicknesses[0, 2:-1:2] = D[sio2_count:]
    # The two outer (air) layers are infinitely thick.
    thicknesses[0, 0] = np.inf
    thicknesses[0, -1] = np.inf
    # Solve with the transfer-matrix method.
    result = tmm_fast.coh_tmm(
        's', M, thicknesses * 1e-9, theta, wl, device='cpu'
    )
    return result["T"][0, 0, :]
    
def evaluate(study: Study, **kwargs: float) -> Observation:
    """Evaluate the black-box function for the given design parameters.

    Args:
        study: Study that provides the new observation object.
        **kwargs: Layer thicknesses keyed by the names in ``param_names``.

    Returns:
        Observation holding the transmission spectrum and, for every
        design parameter, the derivative of the spectrum with respect to it.
    """
    observation = study.new_observation()
    thicknesses = torch.tensor([kwargs[name] for name in param_names])
    spectrum = transmissions(thicknesses)
    # Jacobian rows correspond to wavelength samples, columns to parameters.
    jacobian = torch.autograd.functional.jacobian(transmissions, thicknesses)
    observation.add(spectrum.tolist())
    for name, column in zip(param_names, jacobian.unbind(dim=1)):
        observation.add(column.tolist(), derivative=name)
    return observation

# Register the evaluator and run the minimization.
study.set_evaluator(evaluate)
study.run()

# Fetch the best sample found by the driver and recompute its spectrum.
kwargs = study.get_state("driver.best_sample")
D = torch.tensor([kwargs[name] for name in param_names])
T = transmissions(D)
# Figure with the optimized transmission spectrum (top) and a sketch of the
# optimized layer stack (bottom).
fig, ax = plt.subplots(nrows=2, height_ratios=[3,1], figsize=(8, 4))

ax[0].vlines(1300, 0, 0.8, colors="black", linestyles="dashed",
             label=r"$\lambda_{\rm gap}=1300$ nm")
ax[0].plot(wl*1e9, T, ".-", label="spectral response")
ax[0].set_xlabel(r"$\lambda [nm]$")
ax[0].legend()
ax[0].grid()

# Thicknesses in physical stacking order SiO2, Ag, SiO2, ..., SiO2. Separate
# names are used so the spectrum tensor T and the tensor D are not overwritten.
sio2_thicknesses = [kwargs[name] for name in param_names[:N+1]]
ag_thicknesses = [kwargs[name] for name in param_names[N+1:]]
layer_thicknesses = [0]*(2*N + 1)
layer_thicknesses[0::2] = sio2_thicknesses
layer_thicknesses[1::2] = ag_thicknesses

# Draw one rectangle per layer. The left edge is the accumulated thickness of
# all previous layers; no wavelength offset is added, so rectangles, labels
# and the x-limits all share the same coordinate (nm measured from the first
# SiO2 layer).
left_edge = 0
for idx, thickness in enumerate(layer_thicknesses):
    is_sio2 = idx % 2 == 0
    ax[1].add_patch(patches.Rectangle(
        xy=(left_edge, 0), width=thickness, height=1,
        facecolor="lightblue" if is_sio2 else "white"
    ))
    ax[1].text(y=0.5, x=left_edge + 0.5*thickness,
               s="SiO$_2$" if is_sio2 else "Ag",
               horizontalalignment='center', verticalalignment='center')
    left_edge += thickness

# Axis decoration does not depend on the layer, so it is done once.
ax[1].set_xlim(0, sum(layer_thicknesses))
ax[1].set_xlabel("Layer Stack Thicknesses [nm]")
ax[1].get_yaxis().set_visible(False)
plt.tight_layout()
plt.show()
fig.savefig("optimized_filter.svg", transparent=True)

# As a comparison, we run a global differential evolution (DE)
# and save evaluated objective values.
Y_DE: list[float] = []


def objective(x: np.ndarray) -> float:
    """Penalized loss of the layer stack ``x`` for differential evolution.

    The loss is the standard deviation of the transmission over the short
    wavelength samples plus the mean transmission over the long wavelength
    samples. Every call also appends the corresponding entry to ``Y_DE``.

    Args:
        x: Layer thicknesses in design-space order.

    Returns:
        ``loss - 1`` if the mean short-range transmission is above 0.4 and
        0 otherwise.
    """
    T = transmissions(torch.tensor(x))
    T_short = T[:K_short]
    T_long = T[K_short:]
    # correction=0 gives the population standard deviation, matching the
    # study's loss sqrt(chi_squared_short/K_short - avg_short^2). The default
    # (Bessel-corrected) std would make the two plotted losses differ.
    loss = T_short.std(correction=0) + T_long.mean()
    # Penalized loss is zero if the constraint is not fulfilled, otherwise it
    # is loss - 1 (negative as long as the loss stays below 1).
    penalized_loss = (loss - 1) * torch.heaviside(
        T_short.mean() - 0.4, values = torch.tensor(0.0)
    )
    # Recorded value: the loss itself when the constraint holds, 1 otherwise.
    Y_DE.append(penalized_loss.item() + 1)
    return penalized_loss.item()

# A bare `import scipy` does not guarantee that the `optimize` submodule is
# loaded on every SciPy version, so it is imported explicitly here.
import scipy.optimize

# The returned result object is not needed: every evaluated objective value
# is recorded in Y_DE by `objective` itself.
scipy.optimize.differential_evolution(
    objective,
    bounds=[info["domain"] for info in design_space],
    polish=False, # Run standard DE without L-BFGS-B
    popsize=6, # Generation size = 6 * (2*N + 1) = 54
    # Max 100 generations after the initial population
    # -> at most (100 + 1) * 54 = 5454 evaluations
    maxiter=100
)

# Compare the losses observed by the physics-informed BO study with the
# values recorded during differential evolution.
Y_BO = study.driver.get_observed_values("variable", "loss")
bo_losses = list(Y_BO["means"])
# Best-so-far curve of the BO run, continued with 3500 constant samples so
# that the line extends into the range covered by the much longer DE run.
bo_best = np.minimum.accumulate(bo_losses + [min(bo_losses)]*3500)
de_best = np.minimum.accumulate(Y_DE)

fig, ax = plt.subplots(figsize=(8, 4))
# Evaluations are numbered from 1: on the logarithmic x-axis an index of 0
# cannot be displayed, which would hide the first sample of each method.
de_index = np.arange(1, len(Y_DE) + 1)
ax.plot(de_index, Y_DE, ".", color="steelblue", label="Differential evolution")
ax.plot(de_index, de_best, "-", color="steelblue")
ax.plot(np.arange(1, len(bo_losses) + 1), bo_losses, ".",
        color="darkorange", label="Physics-informed BO")
ax.plot(np.arange(1, len(bo_best) + 1), bo_best, "-", color="darkorange")
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("# Iterations")
ax.set_ylabel("loss")
ax.legend()
ax.grid()
plt.show()
fig.savefig("comparison_PIBO_DE.svg", transparent=True)


# All results have been read from the study by now, so shut down the server
# that belongs to the client created at the top of the script.
client.shutdown_server()
