"""GP tutorial for CSCE790

References (on which this tutorial is based):
- https://github.com/SheffieldML/notebook
- http://www.cs.ubc.ca/~nando/540-2013/lectures.html
- http://nbviewer.jupyter.org/github/gpschool/gprs15a/blob/master/GPy%20introduction%20covariance%20functions.ipynb
  
Author: Alberto Quattrini Li

"""

import numpy as np # Library for matrix operations
import matplotlib.pyplot as plt # Library for plotting

# Create training data
x = np.linspace(0.05, 0.95, 10)[:, None] # 10 input locations in [0.05, 0.95]
# Noisy observations of a nonlinear function (additive Gaussian noise, std 0.1)
y = -np.cos(np.pi*x) + np.sin(4*np.pi*x) + np.random.normal(loc=0.0, scale=0.1, size=(10, 1))

# Plot of the training data
plt.plot(x, y, 'rx')
plt.show()

# Create prediction points
num_pred_data = 50 # how many points to use for plotting predictions
x_pred = np.linspace(0.0, 1.0, num_pred_data)[:, None] # input locations for predictions

# Definition of Kernel function (from http://nbviewer.jupyter.org/github/gpschool/gprs15a/blob/master/GPy%20introduction%20covariance%20functions.ipynb)
def exponentiated_quadratic(x, x_prime, variance, lengthscale):
    """Exponentiated quadratic (RBF) kernel between two input points."""
    squared_distance = ((x - x_prime)**2).sum()
    return variance*np.exp(-0.5*squared_distance/lengthscale**2)
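
# A quick sanity check (illustrative, not part of the original tutorial): at
# zero distance the kernel returns `variance`, and it decays as inputs move apart.
assert np.isclose(exponentiated_quadratic(np.zeros(1), np.zeros(1), 1.0, 0.1), 1.0)
assert exponentiated_quadratic(np.zeros(1), 0.2*np.ones(1), 1.0, 0.1) < 0.15  # exp(-2) ~ 0.135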

def compute_kernel(X, X2, kernel, **kwargs):
    """Compute the kernel matrix between all pairs of rows of X and X2."""
    K = np.zeros((X.shape[0], X2.shape[0]))
    for i in np.arange(X.shape[0]):
        for j in np.arange(X2.shape[0]):
            K[i, j] = kernel(X[i, :], X2[j, :], **kwargs)
    return K
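
# An equivalent vectorized computation (a sketch, not used below): for the
# exponentiated quadratic kernel the full squared-distance matrix can be built
# by broadcasting, avoiding the double Python loop.
def exponentiated_quadratic_matrix(X, X2, variance, lengthscale):
    # ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 x_i . x_j
    squared_distances = (X**2).sum(1)[:, None] + (X2**2).sum(1)[None, :] - 2*np.dot(X, X2.T)
    return variance*np.exp(-0.5*squared_distances/lengthscale**2)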

# Input parameters for Kernel function
# set covariance function parameters
variance = 1
lengthscale = 0.1

# Calculation of the Kernel matrix on the training data and prediction
K = compute_kernel(x, x, exponentiated_quadratic, variance=variance, lengthscale=lengthscale)
K_star = compute_kernel(x, x_pred, exponentiated_quadratic, variance=variance, lengthscale=lengthscale)
K_starstar = compute_kernel(x_pred, x_pred, exponentiated_quadratic, variance=variance, lengthscale=lengthscale)
# Plot of the kernel
fig, ax = plt.subplots(figsize=(8,8))
im = ax.imshow(np.vstack([np.hstack([K, K_star]), np.hstack([K_star.T, K_starstar])]), interpolation='none')
# Add lines separating the training and test blocks
# (imshow cell boundaries sit at half-integers, so the split is at x.shape[0]-0.5)
ax.axvline(x.shape[0]-0.5, color='w')
ax.axhline(x.shape[0]-0.5, color='w')
fig.colorbar(im)

plt.show()

# Prior plot: the GP is assumed to be zero-mean; 10 functions drawn from the prior
for i in range(10):
    # A small jitter on the diagonal keeps the covariance numerically positive semi-definite
    y_sample = np.random.multivariate_normal(mean=np.zeros(x_pred.size), cov=K_starstar + 1e-8*np.eye(num_pred_data))
    plt.plot(x_pred.flatten(), y_sample.flatten())

plt.show()

# Definition of the GP class
class GP():
    def __init__(self, X, y, sigma2, kernel, **kwargs):
        self.K = compute_kernel(X, X, kernel, **kwargs)
        self.X = X
        self.y = y
        self.sigma2 = sigma2
        self.kernel = kernel
        self.kernel_args = kwargs
        self.update_inverse()
    
    def update_inverse(self):
        # Precompute the inverse covariance and some quantities of interest
        ## NOTE: This is not the correct *numerical* way to compute this! It is for ease of use.
        self.Kinv = np.linalg.inv(self.K+self.sigma2*np.eye(self.K.shape[0]))
        # the log determinant of the covariance matrix
        self.logdetK = np.linalg.slogdet(self.K+self.sigma2*np.eye(self.K.shape[0]))[1]
        # The matrix inner product of the inverse covariance
        self.Kinvy = np.dot(self.Kinv, self.y)
        self.yKinvy = (self.y*self.Kinvy).sum()
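
# The quantities precomputed in update_inverse are exactly what is needed for
# the log marginal likelihood,
#   log p(y|X) = -0.5*(n*log(2*pi) + log|K + sigma2*I| + y^T (K + sigma2*I)^{-1} y).
# A sketch of such a method, attached to the class like posterior_f below:
def log_likelihood(self):
    n = self.y.shape[0]
    return -0.5*(n*np.log(2*np.pi) + self.logdetK + self.yKinvy)

GP.log_likelihood = log_likelihood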

# Function for calculating the posterior over f at the test inputs
def posterior_f(self, X_test):
    K_star = compute_kernel(self.X, X_test, self.kernel, **self.kernel_args)
    K_starstar = compute_kernel(X_test, X_test, self.kernel, **self.kernel_args)
    A = np.dot(self.Kinv, K_star)
    # Posterior mean: mu_f = K_*^T (K + sigma2*I)^{-1} y
    mu_f = np.dot(A.T, self.y)
    # Posterior covariance: C_f = K_** - K_*^T (K + sigma2*I)^{-1} K_*
    C_f = K_starstar - np.dot(A.T, K_star)
    return mu_f, C_f

# attach the new method to class GP():
GP.posterior_f = posterior_f
    
# set noise variance
sigma2 = 0.2

# Instantiate the GP model
model = GP(x, y, sigma2, exponentiated_quadratic, variance=variance, lengthscale=lengthscale)
# Calculate the posterior
mu_f, C_f = model.posterior_f(x_pred)
# Plot the covariance of the posterior
fig, ax = plt.subplots(figsize=(8,8))
im = ax.imshow(C_f, interpolation='none')
fig.colorbar(im)
plt.show()

# Plot the posterior mean with a confidence band (mean +/- 2 standard deviations)
var_f = np.diag(C_f)[:, None]
std_f = np.sqrt(var_f)

plt.plot(x, y, 'rx')
plt.plot(x_pred, mu_f, 'b-')
plt.plot(x_pred, mu_f+2*std_f, 'b--')
plt.plot(x_pred, mu_f-2*std_f, 'b--')
plt.show()
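
# Functions drawn from the posterior, analogous to the prior samples above
# (an added sketch; a small jitter keeps the covariance numerically positive semi-definite)
for i in range(10):
    y_sample = np.random.multivariate_normal(mean=mu_f.flatten(), cov=C_f + 1e-8*np.eye(C_f.shape[0]))
    plt.plot(x_pred.flatten(), y_sample.flatten())
plt.plot(x, y, 'rx')
plt.show()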

# Note that the kernel hyperparameters (and the noise variance) could be set by
# maximizing the log marginal likelihood of the training data, as sketched below.
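# A minimal sketch of that optimization (assumes scipy is available; the model
# is rebuilt at every candidate setting, which is simple but wasteful):
from scipy.optimize import minimize

def neg_log_likelihood(log_params):
    v, l = np.exp(log_params)  # work in log-space so the parameters stay positive
    return -GP(x, y, sigma2, exponentiated_quadratic, variance=v, lengthscale=l).log_likelihood()

result = minimize(neg_log_likelihood, np.log([variance, lengthscale]))
print('Optimized variance and lengthscale:', np.exp(result.x))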

# Alternatively, the GPy library can be used.
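
# A minimal GPy equivalent (a sketch; requires `pip install GPy`):
# import GPy
# kern = GPy.kern.RBF(input_dim=1, variance=1., lengthscale=0.1)
# m = GPy.models.GPRegression(x, y, kern)
# m.optimize()  # maximizes the log marginal likelihood
# m.plot()
# plt.show()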