import sys,os
import numpy as np
import time
from scipy.stats import multivariate_normal, uniform
import pandas as pd
import torch
torch.set_default_dtype(torch.float64)
import matplotlib.pyplot as plt
import corner #run "pip install corner" if not installed
import emcee #run "pip install emcee" if not installed


# Root directory of the JCMoptimizer installation. This is a placeholder that
# must be replaced by the actual install location before running the script.
jcm_optimizer_path = r"<JCM_OPTIMIZER_PATH>"
# Make the bundled Python interface importable; this has to run before the
# jcmoptimizer import on the next line.
sys.path.insert(0, os.path.join(jcm_optimizer_path, "interface", "python"))
from jcmoptimizer import Server, Client, Study, Observation
# Create a Server instance (presumably starting/attaching to a JCMoptimizer
# server — confirm) and a Client talking to it via the server's host.
server = Server()
client = Client(server.host)

# Search domain: five continuous forward-model parameters b1..b5, each given
# as a (lower, upper) box bound.
design_space = [
    {'name': name, 'type': 'continuous', 'domain': bounds}
    for name, bounds in (
        ('b1', (0, 10)),
        ('b2', (0.1, 4)),
        ('b3', (-4, -0.1)),
        ('b4', (0.05, 1)),
        ('b5', (0.05, 1)),
    )
]

# Create the study with study_id 'bayesian_reconstruction', driven by the
# BayesianReconstruction driver over the search domain above.
study = client.create_study(
    design_space=design_space,
    driver="BayesianReconstruction",
    name="Bayesian parameter reconstruction using Bayesian optimization",
    study_id="bayesian_reconstruction",
)
# The vectorial model function of the MGH17 problem
def model(x: torch.Tensor, num_samples: int = 33) -> torch.Tensor:
    """Evaluate y(s) = b1 + b2*exp(-s*b4) + b3*exp(-s*b5) on s = 0..num_samples-1.

    Args:
        x: Parameter vector [b1, b2, b3, b4, b5]. Anything accepted by
            ``torch.as_tensor`` works, e.g. the NumPy vectors that the emcee
            log-probability passes in, not only tensors.
        num_samples: Number of sample points s = 0, 1, ..., num_samples - 1
            (33 by default).

    Returns:
        Tensor of shape (num_samples,) with the model values.
    """
    x = torch.as_tensor(x)
    # Build the sample grid in x's dtype/device so the result dtype does not
    # depend on torch's global default dtype.
    s = torch.arange(num_samples, dtype=x.dtype, device=x.device)
    return x[0] + x[1]*torch.exp(-s*x[3]) + x[2]*torch.exp(-s*x[4])

# Ground-truth forward-model parameters [b1, b2, b3, b4, b5] that the
# reconstruction should recover.
b_true = torch.tensor([
    3.7541005211E-01,
    1.9358469127E+00,
    -1.4646871366E+00,
    1.2867534640E-01,
    2.2122699662E-01,
])

# Ground-truth error-model parameters, stored as natural logarithms of c1, c2.
log_c1 = np.log(0.005)
log_c2 = np.log(0.01)

# The error model, i.e. the noise stddev depending on the model value y=f(b)
def error_model(log_c1: float, log_c2: float, y: torch.Tensor) -> torch.Tensor:
    """Return the noise standard deviation sqrt(c1**2 + (c2*y)**2).

    c1 = exp(log_c1) is a y-independent noise term, and c2 = exp(log_c2)
    scales a noise term proportional to the model value y.
    """
    const_part = np.exp(log_c1)
    rel_factor = np.exp(log_c2)
    variance = const_part**2 + (rel_factor*y)**2
    return torch.sqrt(variance)

# Generate a random target vector of measurements: the model output at the
# true parameters plus Gaussian noise whose stddev follows the error model.
# The fixed seed makes the synthetic data reproducible.
torch.manual_seed(0)
model_vector = model(b_true)
err = error_model(log_c1, log_c2, model_vector)
measurement_vector = err*torch.randn(model_vector.shape) + model_vector

study.configure(
    max_iter=80,
    target_vector=measurement_vector.tolist(),
    # Error model: noise stddev as a function of the model output y_model,
    # with independent normal priors on its two log-parameters.
    error_model={
        'expression': 'sqrt(exp(log_c1)^2 + (exp(log_c2)*y_model)^2)',
        'distributions': [
            {'type': 'normal', 'parameter': 'log_c1', 'mean': -5.0, 'stddev': 1},
            {'type': 'normal', 'parameter': 'log_c2', 'mean': -4.0, 'stddev': 1},
        ],
        # Starting values and box bounds for fitting log_c1 and log_c2.
        'initial_parameters': [-5.0, -4.0],
        'parameter_bounds': [(-7, -2), (-6.0, -1)],
    },
    # Multivariate normal prior on (b2, b3). Parameters without an explicit
    # prior (b1, b4, b5) are uniformly distributed over the design space.
    parameter_distribution={
        'distributions': [
            {
                'type': 'mvn',
                'parameters': ['b2', 'b3'],
                'mean': [2.25, -2.0],
                'covariance': [[0.5, -0.01], [-0.01, 0.5]],
            },
        ],
    },
)
# Evaluation of the black-box function for specified design parameters
def evaluate(study: Study, b1: float, b2: float, b3: float, b4: float, b5: float) -> Observation:
    """Evaluate the forward model at (b1, ..., b5) and wrap it in an Observation.

    The complete model output vector is added to the observation as a list of
    floats.
    """
    observation = study.new_observation()
    design_vector = torch.tensor((b1, b2, b3, b4, b5))
    model_output = model(design_vector)
    observation.add(model_output.tolist())
    return observation

# Run the minimization
study.set_evaluator(evaluate)
study.run()
# Best sample of the forward-model parameters b
best_b_sample = study.driver.best_sample
# Minimum of the negative log-probability
min_neg_log_prob = study.driver.min_objective

# Assemble the sample [b1, b2, b3, b4, b5, log_c1, log_c2] that minimizes the
# negative log-probability. The b-values are looked up by name in design-space
# order rather than relying on the iteration order of best_sample (assumes
# best_sample is keyed by the design-space parameter names — as the original
# .values() usage implies).
minimum = [best_b_sample[param['name']] for param in design_space]

# Path to the negative log-probability variable in the study history
path = "driver.acquisition_function.main_objective.variable"
neg_log_probs = study.historic_parameter_values(f"{path}.observed_value")
idx = np.argmin(neg_log_probs)

# Error-model parameters recorded for each historic observation; take the
# ones belonging to the observation with the smallest negative log-probability.
logs_c1 = study.historic_parameter_values(f"{path}.error_model_parameters.log_c1")
logs_c2 = study.historic_parameter_values(f"{path}.error_model_parameters.log_c2")

minimum += [logs_c1[idx], logs_c2[idx]]
minimum = np.array(minimum)


# Before running a Markov-chain Monte-Carlo (MCMC) sampling, converge the
# surrogate models by sampling around the minimum. To make the study more
# explorative, the scaling parameter is increased and the effective number of
# degrees of freedom is set to one. The minimal uncertainty is set relative to
# the magnitude of the minimum, i.e. 1e-8 * |min_neg_log_prob|.
study.configure(
    scaling=10.0,
    effective_DOF=1.0,
    min_uncertainty=np.abs(min_neg_log_prob)*1e-8,
    max_iter=120,
)
study.run()

# Run the MCMC sampling on the surrogate models with 32 walkers
num_walkers, max_iter = 32, 20000
mcmc_result = study.driver.run_mcmc(
    rel_error=0.01,
    num_walkers=num_walkers,
    max_iter=max_iter
)
# Corner plot of the surrogate-based samples, with the minimum found above
# marked as "truth". For a 2-D Gaussian, 1-exp(-0.5) is the probability mass
# inside the 1-sigma contour and 1-exp(-1.0) the mass inside the
# sqrt(2)-sigma contour.
fig = corner.corner(
    np.array(mcmc_result['samples']),
    quantiles=(0.16, 0.5, 0.84),
    levels=(1-np.exp(-1.0), 1-np.exp(-0.5)),
    show_titles=True, scale_hist=False,
    title_fmt=".3f",
    labels=[d['name'] for d in design_space] + ["log_c1", "log_c2"],
    truths=minimum
)
# Save the figure returned by corner explicitly instead of relying on
# pyplot's implicit "current figure".
fig.savefig("corner_surrogate.svg", transparent=True)


# As a comparison, we run the MCMC sampling directly on the analytic model.

# Uniform prior domain for the model parameters b1, b4, b5
uniform_domain = np.array([design_space[i]["domain"] for i in [0, 3, 4]])

# Initial walker positions: small Gaussian perturbations around the minimum.
p0 = 0.05*np.random.randn(num_walkers, len(design_space)+2)
p0 += minimum
# The perturbation can push b1, b4 or b5 outside the support of their uniform
# prior. For example, b4 has true value ~0.129 and lower bound 0.05, i.e. only
# ~1.6 perturbation stddevs away. Outside the support the log-probability is
# -inf, so clip these coordinates back into the domain, keeping a small margin
# from the edges.
margin = 1e-6*(uniform_domain[:, 1] - uniform_domain[:, 0])
p0[:, [0, 3, 4]] = np.clip(
    p0[:, [0, 3, 4]],
    uniform_domain[:, 0] + margin,
    uniform_domain[:, 1] - margin,
)

# log probability function
def log_prob(x: np.ndarray) -> float:
    """Log-posterior of the analytic MGH17 reconstruction problem (for emcee).

    Args:
        x: Sample [b1, b2, b3, b4, b5, log_c1, log_c2].

    Returns:
        Gaussian log-likelihood of ``measurement_vector`` under the error model
        plus the log-priors, as a Python float. Returns -inf outside the
        support of the uniform prior on b1, b4, b5.
    """
    # Log prior. It is evaluated first so that samples outside the uniform
    # prior's support are rejected without evaluating the forward model.
    lp = (
        multivariate_normal.logpdf(
            x[1:3],
            mean=[2.25, -2.0],
            cov=[[0.5, -0.01], [-0.01, 0.5]]
        ) +
        uniform.logpdf(
            [x[0], x[3], x[4]],
            loc=uniform_domain[:, 0],
            scale=uniform_domain[:, 1] - uniform_domain[:, 0]
        ).sum() +
        multivariate_normal.logpdf(
            x[5:],
            mean=[-5.0, -4.0],
            cov=[[1, 0.0], [0.0, 1]]
        )
    )
    if not np.isfinite(lp):
        return -np.inf

    y = model(x[:5])
    res = y - measurement_vector
    err = error_model(x[5], x[6], y)

    # Gaussian log-likelihood with per-point (heteroscedastic) stddev err
    ll = -0.5*(
        torch.log(2*torch.tensor(torch.pi)) +
        torch.log(err**2) +
        (res/err)**2
    ).sum()

    # Return a plain float rather than a 0-dim torch tensor.
    return float(ll) + float(lp)
# Ensemble sampler on the analytic log-probability over the 5 model parameters
# plus the 2 error-model parameters.
sampler = emcee.EnsembleSampler(
    nwalkers=num_walkers, ndim=len(design_space)+2, log_prob_fn=log_prob
)

# Burn-in phase: its final walker state seeds the production run, and the
# burn-in chain is discarded by reset().
state = sampler.run_mcmc(p0, 100)
sampler.reset()
# Actual MCMC sampling
sampler.run_mcmc(state, max_iter, progress=True)
samples = sampler.get_chain(flat=True)
# Corner plot of the analytic posterior with the same layout as the
# surrogate-based plot, for direct comparison.
fig = corner.corner(
    samples, quantiles=(0.16, 0.5, 0.84),
    levels=(1-np.exp(-1.0), 1-np.exp(-0.5)),
    show_titles=True, scale_hist=False,
    title_fmt=".3f",
    labels=[d['name'] for d in design_space] + ["log_c1", "log_c2"],
    truths=minimum
)
# Save the figure returned by corner explicitly instead of relying on
# pyplot's implicit "current figure".
fig.savefig("corner_analytic.svg", transparent=True)

