import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

# np.random.seed() returns None, so keep the seed value in its own variable
# (useful for logging/reproducing a run) and seed the global generator with it.
seed = 123456
np.random.seed(seed)

# 1. Define the original function
def original_function(x):
    """Return |u * sin(u)| with u = x / 1.6.

    Works element-wise, so ``x`` may be a scalar or a NumPy array; the result
    is never negative.
    """
    scale = 1.6
    u = x / scale
    return np.abs(u * np.sin(u))

# 2. Set up the x range with extension for edge effects
x_min, x_max = 0, 15
extension_width = 3  # Extend by 3 x-units (= 3 histogram bins) on each side
extended_x_min = x_min - extension_width
extended_x_max = x_max + extension_width
n_points = 1000
x = np.linspace(extended_x_min, extended_x_max, n_points)

# Compute the original function over the extended range
f_true_extended = original_function(x)

# 3. Create the Gaussian response function over the extended range
sigma_response = 1.5                   # Smearing sigma in units of x
dx = x[1] - x[0]
x_mid = (extended_x_min+extended_x_max)/2   # Midpoint of the x range
# np.convolve(..., mode='same') aligns output sample k with kernel sample
# (n_points - 1) // 2.  For an even n_points the geometric midpoint x_mid lies
# between two samples, which would shift the smeared curve by dx/2, so the
# kernel is centred on that grid sample instead (identical for odd n_points).
kernel_center = x[(n_points - 1) // 2]
response_kernel = np.exp(-0.5 * ((x - kernel_center) / sigma_response)**2)
response_kernel /= response_kernel.sum()    # Normalise to unit sum

# 4. Convolve the extended true function with the response function
# The constant factor dx only sets the overall scale, which drops out when the
# result is normalised before sampling and plotting.
f_convolved_extended = np.convolve(f_true_extended,
                                   response_kernel, mode='same')*dx

# 5. Sample data from the convolved extended data
n_samples = 10000                       # Total number of counts sampled
n_bins = 15
bins = np.linspace(x_min, x_max, n_bins + 1)    # n_bins equal-width bins
bin_centers = 0.5 * (bins[:-1] + bins[1:])
xdata = np.linspace(x_min, x_max, n_points)  # Use original range for data
f_true = original_function(xdata)
# Interpolate the convolved data onto the sampling points
interp_convolved = interp1d(x, f_convolved_extended, kind='linear')
f_convolved = interp_convolved(xdata)
p = f_convolved / np.sum(f_convolved)   # Sampling probability per grid point
sample_x = np.random.choice(xdata, size=n_samples, p=p)
bin_counts, _ = np.histogram(sample_x, bins=bins)
d = bin_counts.astype(float)

# 6. Create the response matrix R considering the extended range
# Row i holds the Gaussian response centred on bin i, evaluated on the
# extended x grid, normalised to unit sum and then scaled by dx.
R = np.exp(-0.5 * ((x[np.newaxis, :] - bin_centers[:, np.newaxis])
                   / sigma_response)**2)
R /= R.sum(axis=1, keepdims=True)
R *= dx


# 7. Define the Tikhonov regularization function
def tikhonov_inversion(R, d, alpha, regularization_operator):
    """Solve a Tikhonov-regularised least-squares problem in two passes.

    The first pass is an unweighted regularised solve.  Its model prediction
    ``R @ f_0`` is then used to build weights ``1 / (prediction + 1e-8)``
    for a second, weighted solve whose solution is returned.

    Parameters
    ----------
    R : 2-D array
        Response matrix mapping the unknown ``f`` onto the measurements.
    d : 1-D array
        Measured values, one per row of ``R``.
    alpha : float
        Regularisation strength; it enters the equations as ``alpha**2``.
    regularization_operator : 2-D array
        Operator whose output norm is penalised; it must have as many
        columns as ``R``.

    Returns
    -------
    1-D array
        The weighted, regularised solution ``f``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If either linear system is singular.
    """
    # The penalty term is the same in both passes, so build the (expensive)
    # Gram matrix of the operator only once.
    penalty = (alpha ** 2) * (regularization_operator.T @
                              regularization_operator)

    # Pass 1: unweighted solution, used only to estimate the model means.
    f_0 = np.linalg.solve(R.T @ R + penalty, R.T @ d)
    m_0 = R @ f_0     # First estimate for model mean predictions
    epsilon = 1e-8    # Used to prevent divide by zero

    # Pass 2: weights from the first-pass prediction.  Scaling the columns of
    # R.T by the weight vector is equivalent to multiplying by a diagonal
    # weight matrix, without materialising that matrix.
    weights = 1 / (m_0 + epsilon)
    weighted_Rt = R.T * weights

    A = weighted_Rt @ R + penalty
    b = weighted_Rt @ d
    return np.linalg.solve(A, b)

# 8. Define the second derivative regularization operator
def second_derivative_operator(n_points):
    """Build the (n_points - 2) x n_points second-difference matrix.

    Row ``i`` carries the stencil ``[1, -2, 1]`` in columns ``i``, ``i + 1``
    and ``i + 2``; every other entry is zero.
    """
    n_rows = n_points - 2
    stencil = np.zeros((n_rows, n_points))
    rows = np.arange(n_rows)
    # Fill the three diagonals of the stencil in one go.
    stencil[rows, rows] = 1
    stencil[rows, rows + 1] = -2
    stencil[rows, rows + 2] = 1
    return stencil

# Roughness operator acting on the solution grid (one column per grid point)
L = second_derivative_operator(n_points=n_points)

# 9. Define the chi-squared function for optimization
def objective(alpha):
    """Return Pearson's chi-squared of the regularised fit for ``alpha``.

    Uses the module-level response matrix ``R``, data ``d`` and
    regularisation operator ``L``.
    """
    solution = tikhonov_inversion(R, d, alpha, L)
    expected = R @ solution          # Model prediction for each measurement
    tiny = 1e-8                      # Keeps the denominator away from zero
    squared_residuals = (d - expected)**2
    return np.sum(squared_residuals / (expected + tiny))

# 10. Generate the L-curve to find optimal alpha
alpha_values = np.logspace(-4, 1, 50)  # Use 50 logarithmically spaced values
residuals = []
roughness = []
epsilon = 1e-8      # Keeps the Pearson weights away from division by zero

for alpha in alpha_values:
    f = tikhonov_inversion(R, d, alpha, L)
    m = R @ f
    # For vectors, it's more efficient to compute norms than squares
    residuals.append(np.linalg.norm((d - m)/np.sqrt(m + epsilon)))  # Residual
    roughness.append(np.linalg.norm(L @ f))                         # Roughness

# 11. Select the optimal alpha from max positive change in slope of L-curve
# slopes[i] is the log-log slope of the segment between points i and i+1.
slopes = np.diff(np.log(roughness)) / np.diff(np.log(residuals))

idx_opt = None
dslope_best = -1000
for i, (slope_1, slope_2) in enumerate(zip(slopes[:-1], slopes[1:])):
    # Check for inflection point of curve (NaN slopes fail both comparisons)
    if slope_2 > -10 and slope_2 - slope_1 > dslope_best:
        dslope_best = slope_2 - slope_1
        idx_opt = i + 2
if idx_opt is None:
    # Without this guard the next line would fail with a NameError (or reuse
    # a stale idx_opt from an earlier interactive run).
    raise RuntimeError("No L-curve corner found; widen or refine alpha_values")

# The selected value is used as-is: overriding it after this point would make
# the printed alpha and the marked L-curve point disagree with the solution.
alpha_opt = alpha_values[idx_opt]
print(f"Selected regularization parameter alpha: {alpha_opt}")

# 12. Compute the regularized solution with optimal alpha
f_opt = tikhonov_inversion(R, d, alpha_opt, L)

# Calculate the covariance matrix for the solution to get estimated
# uncertainty in f (reuse f_opt rather than solving the same system again)
f = f_opt
m = R @ f
W = np.diag(1/(m + epsilon))
A = R.T @ W @ R + (alpha_opt ** 2) * (L.T @ L)
cov_matrix = np.linalg.inv(A)
uncertainty = np.sqrt(np.diag(cov_matrix))

# Normalize all functions for comparison
f_true_norm = f_true/np.sum(f_true)
f_convolved_norm = f_convolved/np.sum(f_convolved)
# f_opt lives on the extended grid x, which is wider and coarser than xdata.
# To put it on the same footing as f_true_norm, normalise it over the part of
# x inside [x_min, x_max] and rescale by the ratio of the two grid spacings.
in_range = (x >= x_min) & (x <= x_max)
grid_ratio = (xdata[1] - xdata[0]) / dx
f_opt_scale = grid_ratio / np.sum(f_opt[in_range])
f_opt_norm = f_opt * f_opt_scale
uncertainty_norm = uncertainty * abs(f_opt_scale)   # Band width stays >= 0

# 13. Plot results with uncertainty
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
ax_truth, ax_data, ax_lcurve, ax_fit = axes.ravel()

# Top left: original and smeared functions on the data range
ax_truth.plot(xdata, f_true_norm, color='black', linestyle='dotted',
              label='Original Function')
ax_truth.plot(xdata, f_convolved_norm, color='black',
              label='Convolved Function')
ax_truth.set_title('Original Function', fontsize=15)
ax_truth.set_xlim(x_min, x_max)
ax_truth.set_xlabel('x', fontsize=15)
ax_truth.set_ylabel('f(x)', fontsize=15)
ax_truth.legend()

# Top right: the sampled histogram
ax_data.bar(bin_centers, d, width=(bins[1]-bins[0]), alpha=0.7,
            color='grey', label='Sampled Data')
ax_data.set_title('Sampled Data (Histogram)', fontsize=15)
ax_data.set_xlabel('x', fontsize=15)
ax_data.set_ylabel('Counts', fontsize=15)
ax_data.legend()

# Bottom left: the L-curve, with the point at idx_opt marked by a star
ax_lcurve.loglog(residuals, roughness, mec='grey', mfc='none', marker='o')
ax_lcurve.loglog(residuals[idx_opt], roughness[idx_opt], marker='*', ms=20,
                 color='black')
ax_lcurve.set_xlabel(
    r'Residual Norm ||$(\mathbf{d} - \mathbf{R}\cdot \mathbf{f})$/' +
    r'$\sqrt{\mathbf{R}\cdot \mathbf{f}}$||', fontsize=15)
ax_lcurve.set_ylabel(r'Roughness Norm ||$\mathbf{L}\cdot \mathbf{f}$||',
                     fontsize=15)
ax_lcurve.set_title('L-curve for Regularization Parameter Selection',
                    fontsize=15)

# Bottom right: the regularized solution with uncertainty band
ax_fit.plot(xdata, f_true_norm, label='Original Function',
            linestyle='dotted', color='black')
ax_fit.plot(x, f_opt_norm, label='Reconstructed Function', color='black')
ax_fit.fill_between(x, f_opt_norm - uncertainty_norm,
                    f_opt_norm + uncertainty_norm,
                    color='grey', alpha=0.2, label='Uncertainty')
# x (the solution grid) extends beyond the data range; show the data range.
ax_fit.set_xlim(x_min, x_max)
# Only one y-range is set: an earlier, immediately overridden ylim call was
# dead code and has been dropped.
ax_fit.set_ylim(-0.01, 0.01)
ax_fit.set_title('Reconstructed Function with Estimated Uncertainty',
                 fontsize=14)
ax_fit.set_xlabel('x', fontsize=15)
ax_fit.set_ylabel('f(x)', fontsize=15)
ax_fit.legend()

fig.tight_layout()
plt.show()
