import sys,os
import numpy as np
import time
import pandas as pd
import torch
import matplotlib.pyplot as plt
import corner #run "pip install corner" if not installed
import emcee #run "pip install emcee" if not installed


# Make the JCMoptimizer Python interface importable. Replace the placeholder
# path below with the JCMoptimizer installation directory before running.
jcm_optimizer_path = r"<JCM_OPTIMIZER_PATH>"
sys.path.insert(
    0,
    os.path.join(jcm_optimizer_path, "interface", "python"),
)
from jcmoptimizer import Server, Client, Study, Observation

# Create the optimization server and a client that talks to its host.
server = Server()
client = Client(server.host)

# Search domain: five continuous parameters b1..b5, each with box bounds
# given as (lower, upper).
design_space = [
    {'name': param_name, 'type': 'continuous', 'domain': bounds}
    for param_name, bounds in (
        ('b1', (0, 10)),
        ('b2', (0.1, 4)),
        ('b3', (-4, -0.1)),
        ('b4', (0.05, 1)),
        ('b5', (0.05, 1)),
    )
]
# Linear inequality coupling b2 and b3; registered under the name 'test'.
constraints = [{'name': 'test', 'expression': 'b2 + b3 <= 1.0'}]

# Create the study (study_id 'bayesian_least_squares') over the search domain
# and constraints defined above, using the BayesianLeastSquares driver.
study = client.create_study(
    design_space=design_space,
    constraints=constraints,
    driver="BayesianLeastSquares",
    name="Solution of least-square problem using Bayesian optimization",
    study_id="bayesian_least_squares",
)
# The vectorial model function of the MGH17 problem
def model(x: torch.Tensor, num_points: int = 33) -> torch.Tensor:
    """Evaluate the MGH17 model vector for the parameters in ``x``.

    Component ``i`` (``i = 0, ..., num_points - 1``) is
    ``x[0] + x[1]*exp(-i*x[3]) + x[2]*exp(-i*x[4])``.

    NOTE(review): with integer abscissae 0..32 the rate parameters b4, b5 are
    presumably ten times those of the NIST formulation (abscissae 0, 10, ...,
    320) -- consistent with the reference minimum used later in this script.

    Args:
        x: Floating-point tensor whose first five entries are
            (b1, b2, b3, b4, b5).
        num_points: Length of the returned vector. The default of 33 matches
            the number of entries in the MGH17 target vector.

    Returns:
        Tensor of shape ``(num_points,)`` with the dtype of ``x``.
    """
    # Build the abscissa directly in x's dtype rather than relying on
    # int64 -> float type promotion in the products below.
    s = torch.arange(num_points, dtype=x.dtype)
    return x[0] + x[1] * torch.exp(-s * x[3]) + x[2] * torch.exp(-s * x[4])

# Target vector of the MGH17 problem (33 data points).
# Stored as float64: for a list of Python floats torch.tensor would otherwise
# infer the default float32 dtype and round the data to ~7 significant digits
# before they are sent to the optimizer (target.tolist()) and compared with
# the analytic model in log_prob.
target = torch.tensor(
    [
        8.44E-01, 9.08E-01, 9.32E-01, 9.36E-01, 9.25E-01,
        9.08E-01, 8.81E-01, 8.50E-01, 8.18E-01, 7.84E-01,
        7.51E-01, 7.18E-01, 6.85E-01, 6.58E-01, 6.28E-01,
        6.03E-01, 5.80E-01, 5.58E-01, 5.38E-01, 5.22E-01,
        5.06E-01, 4.90E-01, 4.78E-01, 4.67E-01, 4.57E-01,
        4.48E-01, 4.38E-01, 4.31E-01, 4.24E-01, 4.20E-01,
        4.14E-01, 4.11E-01, 4.06E-01,
    ],
    dtype=torch.float64,
)

# Hand the target vector to the least-squares driver; the first optimization
# run is configured with max_iter=120.
study.configure(target_vector=target.tolist(), max_iter=120)
# Evaluation of the black-box function for specified design parameters
def evaluate(study: Study, b1: float, b2: float, b3: float, b4: float, b5: float) -> Observation:
    """Evaluate the MGH17 model at one design point for the optimizer.

    Registered via ``study.set_evaluator`` before ``study.run()``.

    Args:
        study: Study requesting the evaluation; supplies a fresh observation.
        b1, b2, b3, b4, b5: Design-parameter values (named as in
            ``design_space``).

    Returns:
        An observation holding the model vector ``model(x)`` as a list of
        Python floats.
    """
    observation = study.new_observation()
    # Build the parameter tensor in double precision. For Python floats
    # torch.tensor would otherwise infer float32 and round every model value
    # to ~7 significant digits. That is coarse compared with the chi-squared
    # resolution requested for the refinement run
    # (min_uncertainty = min_chisq * 1e-8), and inconsistent with log_prob,
    # which evaluates the model on float64 inputs.
    x = torch.tensor([b1, b2, b3, b4, b5], dtype=torch.float64)
    observation.add(model(x).tolist())
    return observation

# Run the minimization, then report each reconstructed parameter together
# with its uncertainty as estimated by the driver.
study.set_evaluator(evaluate)
study.run()
best_sample, min_chisq, uncertainties = (
    study.driver.best_sample,
    study.driver.min_objective,
    study.driver.uncertainties,
)
print(f"Reconstructed parameters with chi-squared value {min_chisq:.4e}:")
for name in [param['name'] for param in design_space]:
    print(f"  {name} = {best_sample[name]:.3f} +/- {uncertainties[name]:.3f}")

# Before the Markov-chain Monte-Carlo (MCMC) sampling, converge the surrogate
# models by sampling around the minimum. The larger scaling parameter and an
# effective number of degrees of freedom of one make the study more
# explorative. min_uncertainty is tied to the chi-squared minimum found
# above; min_val=0.0 presumably tells the driver that the (non-negative)
# chi-squared objective cannot go below zero -- confirm in the JCMoptimizer docs.
study.configure(
    scaling=10.0,
    effective_DOF=1.0,
    min_uncertainty=1e-8 * min_chisq,
    max_iter=150,
    min_val=0.0,
)
study.run()

# Sample the surrogate-based posterior by MCMC with 32 walkers
# (rel_error=0.01, at most max_iter=10000 iterations).
num_walkers = 32
max_iter = 10000
mcmc_result = study.driver.run_mcmc(
    rel_error=0.01, num_walkers=num_walkers, max_iter=max_iter
)

# Known least-squares minimum of the MGH17 problem; drawn as the "truths"
# reference lines in the corner plots.
minimum = torch.tensor([
    3.7541005211E-01, 1.9358469127E+00, -1.4646871366E+00,
    1.2867534640E-01, 2.2122699662E-01,
])

# Corner plot of the surrogate MCMC samples with 16/50/84 % quantiles.
fig = corner.corner(
    np.array(mcmc_result['samples']),
    labels=[d['name'] for d in design_space],
    truths=minimum.numpy(),
    quantiles=(0.16, 0.5, 0.84),
    levels=(1 - np.exp(-1.0), 1 - np.exp(-0.5)),
    show_titles=True,
    title_fmt=".3f",
    scale_hist=False,
)
plt.savefig("corner_surrogate.svg", transparent=True)

# As a comparison, sample the posterior of the analytic model directly with
# emcee. The walkers start in a Gaussian cloud (standard deviation 0.05)
# around the known minimum.
p0 = minimum.numpy() + 0.05 * np.random.randn(num_walkers, len(design_space))
min_chisq_cert = 5.4648946975E-05  # certified minimum chi-squared of MGH17
# Reduced standard error sqrt(chi^2 / (N - P)) for N data points and P
# parameters; log_prob divides the residuals by it (measurement uncertainty).
RSE = np.sqrt(min_chisq_cert / (len(target) - len(design_space)))

# Log-probability function sampled by emcee
def log_prob(x):
    """Return the Gaussian log-likelihood of the parameter vector ``x``.

    The residuals between the analytic model and the target vector are
    scaled by ``RSE``. A NaN result is mapped to ``-inf`` (zero probability).
    """
    scaled_residuals = (model(torch.tensor(x)) - target) / RSE
    log_likelihood = -0.5 * np.sum(scaled_residuals.numpy() ** 2)
    return -np.inf if np.isnan(log_likelihood) else log_likelihood

sampler = emcee.EnsembleSampler(
    nwalkers=num_walkers,
    ndim=len(design_space),
    log_prob_fn=log_prob,
)

# Burn-in: advance the walkers 100 steps from p0, discard those samples, and
# continue the production run from the walkers' final state.
state = sampler.run_mcmc(p0, 100)
sampler.reset()
sampler.run_mcmc(state, max_iter, progress=True)

# Corner plot of the flattened chain of the analytic-model sampling.
samples = sampler.get_chain(flat=True)
fig = corner.corner(
    samples,
    labels=[d['name'] for d in design_space],
    truths=minimum.numpy(),
    quantiles=(0.16, 0.5, 0.84),
    levels=(1 - np.exp(-1.0), 1 - np.exp(-0.5)),
    show_titles=True,
    title_fmt=".3f",
    scale_hist=False,
)
plt.savefig("corner_analytic.svg", transparent=True)

