import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import os, glob
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size

import datetime
import imageio

import matplotlib.pylab as pylab
# Module-wide matplotlib styling applied once at import time: larger
# legend, axis-label, tick-label and title fonts, plus a 10x5 inch default
# figure size for any figure that does not request its own size.
params = dict(
    [
        ('legend.fontsize', 14),
        ('figure.figsize', (10, 5)),
        ('axes.labelsize', 'x-large'),
        ('axes.titlesize', 'x-large'),
        ('xtick.labelsize', 'x-large'),
        ('ytick.labelsize', 'x-large'),
        ('figure.titlesize', 18),
    ]
)
pylab.rcParams.update(params)

def plotSphereHeat(T, imgNum, saveTo=None):
    """Draw a scalar field as face colours on a semi-transparent unit sphere.

    Parameters
    ----------
    T : array_like, 2-D
        One value per surface patch. Rows step through the polar angle theta
        over [0, pi]; columns step through the azimuth phi over [0, 2*pi].
        Values go straight through the ``afmhot`` colormap and are expected
        in [0, 1]. Values outside that range saturate to the colormap's end
        colours. The colorbar is fixed to the same 0..1 range.
    imgNum :
        Not used by this function; kept so existing callers keep working.
    saveTo : str or None
        If given, the figure is saved to this path and then closed.
        Otherwise it is displayed with ``plt.show()``.
    """
    # Force floating point: a colormap called on an *integer* array treats
    # the values as lookup-table indices, not as positions in [0, 1].
    T = np.asarray(T, dtype=float)
    ntheta, nphi = T.shape[0], T.shape[1]

    # ntheta x nphi patches need (ntheta+1) x (nphi+1) vertices (patch edges).
    theta = np.linspace(0, np.pi * 1, ntheta + 1)
    phi = np.linspace(0, np.pi * 2, nphi + 1)

    # Vertex coordinates of the unit sphere.
    X = np.outer(np.sin(theta), np.cos(phi))
    Y = np.outer(np.sin(theta), np.sin(phi))
    Z = np.outer(np.cos(theta), np.ones(nphi + 1))

    cmap = mpl.cm.afmhot
    # cmap(T) maps floats over [0, 1]; give the colorbar that same range so
    # its ticks agree with the surface colours.
    sm = mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=0, vmax=1), cmap=cmap)
    sm.set_array([])  # needed by older matplotlib before drawing a colorbar

    fig = plt.figure(figsize=(15, 12))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=cmap(T), alpha=0.5)

    # sm is not attached to any Axes, so say which Axes the colorbar takes
    # space from; recent matplotlib raises ValueError without ax/cax.
    fig.colorbar(sm, ax=ax, shrink=0.5)

    if saveTo is not None:
        fig.savefig(saveTo)
        plt.close(fig)  # close exactly this figure, not whatever is current
    else:
        plt.show()

def create_gif(imgDirectory, duration, output, xMin=0, xMax=None, yMin=0, yMax=None):
    """Assemble every PNG in a directory into a timestamped animated GIF.

    Frames are read in sorted filename order. Each frame is cropped to
    ``[yMin:yMax, xMin:xMax]`` before it is written.

    Parameters
    ----------
    imgDirectory : str
        Directory searched (non-recursively) for ``*.png`` frames.
    duration :
        Per-frame duration, forwarded unchanged to ``imageio.mimsave``.
        NOTE(review): its unit depends on the imageio version/plugin in use.
        Confirm whether it is seconds or milliseconds.
    output : str
        Output path prefix. ``-YYYY-mm-dd-HH-MM-SS.gif`` is appended to it.
    xMin, xMax, yMin, yMax : int or None
        Crop bounds, used as ordinary slice bounds. ``None`` (the default for
        the upper bounds) keeps everything up to the last column/row. An
        explicit ``-1`` still drops the last column/row, as before.

    Returns
    -------
    str
        Path of the GIF that was written.

    Raises
    ------
    ValueError
        If ``imgDirectory`` contains no PNG files.
    """
    # Lexicographic sort: frame numbers in filenames must be zero-padded for
    # this to match numeric order.
    filenames = sorted(glob.glob(os.path.join(imgDirectory, "*.png")))
    if not filenames:
        raise ValueError("no PNG frames found in %r" % (imgDirectory,))

    # Slice only the two spatial axes so 2-D (greyscale) frames work as well
    # as 3-D (RGB/RGBA) ones.
    images = [imageio.imread(filename)[yMin:yMax, xMin:xMax] for filename in filenames]

    # %m is the month; %M (minutes) was previously used here by mistake.
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    output_file = output + '-%s.gif' % timestamp
    imageio.mimsave(output_file, images, duration=duration)
    return output_file
    

def display_gif(fn):
    """Return an ``IPython.display.HTML`` object holding an ``<img>`` tag for *fn*.

    *fn* is formatted verbatim (unescaped) into the ``src`` attribute.
    IPython is imported lazily, so it is only required when this is called.
    """
    from IPython.display import HTML

    img_tag = '<img src="{}">'.format(fn)
    return HTML(img_tag)


def getHeatSource(t=0, w=1, theta0=np.pi / 3, intervals=50):
    """Return a clipped heat-source pattern sampled on a (theta, phi) grid.

    The value at each grid point is
    ``max(0, (cos(theta)*cos(theta0) + sin(theta)*sin(theta0)) * cos(phi - w*t))``.
    The pattern therefore peaks around ``theta = theta0`` and shifts in phi
    by ``w*t`` as time advances.

    Parameters
    ----------
    t : float
        Time at which to evaluate the pattern.
    w : float
        Angular speed of the phi shift (default 1).
    theta0 : float
        Polar angle where the polar factor peaks (default pi/3).
    intervals : int
        Number of theta samples; phi gets twice as many (default 50).

    Returns
    -------
    numpy.ndarray, shape (intervals, 2*intervals)
        Values in [0, 1]: rows follow theta, columns follow phi.
    """
    ntheta = intervals
    nphi = 2 * intervals

    # Both endpoints are included (theta = 0 and pi; phi = 0 and 2*pi).
    theta = np.linspace(0, np.pi * 1, ntheta)
    phi = np.linspace(0, np.pi * 2, nphi)

    # NOTE(review): the azimuthal factor multiplies the *whole* polar term,
    # so the value still varies with phi at the poles (sin(theta) == 0).
    # The usual insolation dot product would instead be
    # cos(theta)cos(theta0) + sin(theta)sin(theta0)cos(phi - w*t).
    # Confirm which model is intended.
    polar = np.cos(theta) * np.cos(theta0) + np.sin(theta) * np.sin(theta0)
    source = np.outer(polar, np.cos(phi - w * t))
    source[source < 0] = 0  # the unlit side contributes no heat

    return source