import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm, uniform
from iminuit import Minuit
from iminuit.cost import ExtendedUnbinnedNLL

# Seed the global NumPy RNG so that every run draws the same toy sample
np.random.seed(1234)

##### Generate dataset #####
# Number of signal and background events to generate
Nsig_true, Nbkg_true = 1000, 2000

# x is the variable the fit is performed on:
# Gaussian (mean 5, width 1) for signal, flat between 0 and 10 for background
x_sig = np.random.normal(5.0, 1.0, Nsig_true)
x_bkg = np.random.uniform(0, 10, Nbkg_true)

# y is the variable whose shape is studied with the sWeights:
# unit-width Gaussians centred at 0 (signal) and at 3 (background)
y_sig = np.random.normal(0.0, 1.0, Nsig_true)
y_bkg = np.random.normal(3.0, 1.0, Nbkg_true)

# Merge both species into one sample; signal events come first in both
# arrays, so x_all[i] and y_all[i] always belong to the same event
x_all = np.hstack((x_sig, x_bkg))
y_all = np.hstack((y_sig, y_bkg))

##### Define PDFs #####
def pdf_signal(x, mu=5.0, sigma=1.0):
    """Return the signal density in x: a Gaussian with mean mu and width sigma.

    The value is computed by scipy.stats.norm, so x may be anything that
    accepts (a scalar or an array-like).
    """
    gaussian = norm(loc=mu, scale=sigma)
    return gaussian.pdf(x)

def pdf_background(x, xmin=0, xmax=10):
    """Return the background density in x: flat between xmin and xmax.

    The value is computed by scipy.stats.uniform with loc=xmin and
    scale=xmax - xmin, so it is 1 / (xmax - xmin) inside the interval and
    0 outside it.
    """
    flat = uniform(loc=xmin, scale=xmax - xmin)
    return flat.pdf(x)

def total_pdf(x, Nsig, Nbkg):
    """Return the yield-weighted sum of the signal and background densities.

    Both shapes are evaluated with their default parameters, so the only
    free quantities are the two yields Nsig and Nbkg. Because the shapes
    are scaled by yields, the result is an event density rather than a
    unit-normalised PDF.
    """
    signal_part = Nsig * pdf_signal(x)
    background_part = Nbkg * pdf_background(x)
    return signal_part + background_part

##### Define and minimize Extended Unbinned Negative Log Likelihood #####

def model(x, Nsig, Nbkg):
    """Model function handed to ExtendedUnbinnedNLL to build the likelihood.

    Returns a 2-tuple: the total expected number of events (Nsig + Nbkg)
    followed by the event density total_pdf(x, Nsig, Nbkg) evaluated at x.
    """
    expected_total = Nsig + Nbkg
    density = total_pdf(x, Nsig, Nbkg)
    return expected_total, density

# Cost function: extended unbinned NLL of the x sample under `model`
cost = ExtendedUnbinnedNLL(x_all, model)

# Minimiser seeded with starting values for the two yields
minuit = Minuit(cost, Nsig=500, Nbkg=2500)
for yield_name in ("Nsig", "Nbkg"):
    minuit.limits[yield_name] = (0, None)  # Lower bound of 0, no upper bound
minuit.migrad()  # Run the minimisation

# Fitted yields
Nsig_fit = minuit.values["Nsig"]
Nbkg_fit = minuit.values["Nbkg"]

# Covariance of the fitted yields, copied into a plain 2x2 array with
# row/column 0 = Nsig and row/column 1 = Nbkg
cov = minuit.covariance
cov_array = np.array([
    [cov[row, col] for col in ("Nsig", "Nbkg")]
    for row in ("Nsig", "Nbkg")
])

##### Compute sWeights using covariance matrix #####
# Per-event shape densities in x. These are the f_s and f_b of the sWeights
# formula; they were previously used below without ever being defined,
# which raised a NameError.
fs = pdf_signal(x_all)
fb = pdf_background(x_all)

# Denominator: fitted total event density Nsig*f_s + Nbkg*f_b for each event
norm_pdf = Nsig_fit * fs + Nbkg_fit * fb
norm_pdf_safe = np.where(norm_pdf > 0, norm_pdf, 1e-10)  # Prevent division by 0

# sWeights formula using covariance matrix directly:
#   w_n(x) = (V[n, s] * f_s(x) + V[n, b] * f_b(x)) / (Nsig * f_s(x) + Nbkg * f_b(x))
# with V the covariance of the fitted yields (index 0 = Nsig, 1 = Nbkg)
sweight_signal = (cov_array[0, 0] * fs + cov_array[0, 1] * fb) / norm_pdf_safe
sweight_background = (cov_array[1, 0] * fs + cov_array[1, 1] * fb) / norm_pdf_safe

##### Plot x and sWeighted y distributions #####
# First figure: generated x spectra of the two species, drawn as outlines
bins = np.linspace(0, 10, 50)
plt.xlim(0, 10)

outline = dict(histtype='step', linewidth=2, color='black')
plt.hist(x_sig, bins, label='Signal x', **outline)
plt.hist(x_bkg, bins, label='Background x', linestyle='dotted', **outline)

plt.xlabel("Discriminant x", fontsize=15)
plt.ylabel("Events", fontsize=15)
plt.legend()
plt.title("Discriminant x distributions for signal and background",
          fontsize=15)
plt.show()

# Second figure: y spectrum of the full sample weighted by each set of
# sWeights (filled), overlaid with the generated y spectra (outlines)
bins = np.linspace(-4, 7, 50)
plt.xlim(-4, 7)

# Filled histograms: every event enters, weighted by its sWeight
for sw_values, sw_label, sw_colour in (
    (sweight_signal, 'Signal (sWeighted)', 'grey'),
    (sweight_background, 'Background (sWeighted)', 'lightgrey'),
):
    plt.hist(y_all, bins, weights=sw_values, histtype='stepfilled',
             alpha=0.6, label=sw_label, color=sw_colour)

# Outlines: the generated signal and background y distributions
truth_outline = dict(histtype='step', linewidth=2, color='black')
plt.hist(y_sig, bins, label='True Signal y', **truth_outline)
plt.hist(y_bkg, bins, label='True Background y', linestyle='dotted',
         **truth_outline)

plt.xlabel("Discriminant y", fontsize=15)
plt.ylabel("Events (sWeighted)", fontsize=15)
plt.legend()
plt.title("Plot of discriminant y using sWeights from discriminant x",
          fontsize=15)
plt.grid(True)
plt.show()
