
import unicodedata
import torch 
import config
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from random import randrange


# Special vocabulary indices. EOS_token is appended to every encoded sequence
# by tensorFromSequence; SOS_token is presumably the decoder start symbol.
SOS_token, EOS_token = 0, 1
# Torch device on which this module creates tensors (taken from the project config).
device = config.device


def indexesFromSequence(lang, sequence):
    """Map every trigram in *sequence* to its vocabulary index.

    Args:
        lang: Object exposing a ``trigram2index`` mapping.
        sequence: Iterable of trigrams (iterated directly, not split on spaces).

    Returns:
        List of indices, one per element of *sequence*, in order.

    Raises:
        KeyError: if an element is missing from ``lang.trigram2index``.
    """
    return [
        lang.trigram2index[token]
        for token in sequence
    ]
    
def tensorFromSequence(lang, sequence):
    """Encode *sequence* as an EOS-terminated column tensor of indices.

    Returns:
        A ``torch.long`` tensor of shape ``(len(sequence) + 1, 1)`` on
        ``device``. It holds the indices from ``indexesFromSequence`` followed
        by ``EOS_token``.
    """
    ids = indexesFromSequence(lang, sequence) + [EOS_token]
    column = torch.tensor(ids, dtype=torch.long, device=device)
    return column.view(-1, 1)
    
def tensorFromPair(input_lang, output_lang, pair):
    """Convert a (source, target) pair into a pair of index tensors.

    Args:
        input_lang: Vocabulary object used to encode ``pair[0]``.
        output_lang: Vocabulary object used to encode ``pair[1]``.
        pair: Indexable whose first two items are the source and target
            sequences.

    Returns:
        ``(source_tensor, target_tensor)``, each produced by
        ``tensorFromSequence``.
    """
    # Named source/target rather than input/output so the builtin ``input``
    # is not shadowed.
    source_tensor = tensorFromSequence(input_lang, pair[0])
    target_tensor = tensorFromSequence(output_lang, pair[1])
    return source_tensor, target_tensor
    
def unicodeToAscii(s):
    """Remove combining accent marks from *s*.

    The text is NFD-normalised so that accented letters split into a base
    character plus combining marks. Every character in Unicode category
    ``Mn`` (non-spacing mark) is then dropped. Other non-ASCII characters
    (e.g. ``ß``) are kept, so the result is not guaranteed to be pure ASCII.
    The result is also left in decomposed form.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)
    

def normalizeString2(s):
    """Strip accents from *s* and split it into whitespace-separated tokens.

    Leading and trailing whitespace is trimmed first. Combining marks are
    removed via ``unicodeToAscii``. Note that the return value is a list of
    strings, not a single normalised string.
    """
    cleaned = unicodeToAscii(s.strip())
    return cleaned.split()
    

def showPlot(points1, points2, model_type):
    """Plot training and validation loss curves and save them as a PNG.

    Args:
        points1: Sequence of training-loss values, drawn as "Train Loss".
        points2: Sequence of validation-loss values, drawn as
            "Validation Loss" on the same axes.
        model_type: Currently unused; the output path is built from
            ``config.model_type2``. NOTE(review): confirm whether this
            argument was meant to name the file.

    The figure is written to ``config.model_type2 + 'plot.png'`` and then
    closed.
    """
    fig, ax = plt.subplots()
    try:
        ax.set_xlabel("Iterations")
        ax.set_ylabel("Loss")
        ax.plot(points1, label='Train Loss')
        ax.plot(points2, label='Validation Loss')
        ax.legend(framealpha=1, frameon=True)
        fig.savefig(config.model_type2 + 'plot.png')
    finally:
        # Close only the figure created here, and do so even if savefig
        # raises. plt.close("all") would also destroy figures the caller had
        # open.
        plt.close(fig)
    
def plot_bleu_scores(scores, model_type):
    """Plot one BLEU curve per n-gram order and save the figure as a PNG.

    Args:
        scores: Iterable of per-order score series. The first is labelled
            ``1-gram``, the second ``2-gram``, and so on.
        model_type: Currently unused; the output path is built from
            ``config.model_type2``. NOTE(review): confirm whether this
            argument was meant to name the file.

    The figure is written to ``config.model_type2 + "bleu.png"`` and then
    closed.
    """
    fig, ax = plt.subplots()
    try:
        ax.set_xlabel("Iterations")
        ax.set_ylabel("BLeU Scores")
        for order, series in enumerate(scores, start=1):
            ax.plot(series, label=f"{order}-gram")
        ax.legend(framealpha=1, frameon=True)
        fig.savefig(config.model_type2 + "bleu.png")
    finally:
        # Close only this figure, even on failure. Closing "all" would also
        # destroy unrelated figures the caller had open.
        plt.close(fig)


def listToString(s):
    """Concatenate the strings in *s* with no separator.

    Args:
        s: Iterable of ``str``. A plain string also works; its characters
            are simply re-joined.

    Returns:
        The concatenation, or ``""`` for an empty iterable.

    Raises:
        TypeError: if an element is not a ``str``.
    """
    # str.join runs in linear time; repeated ``+=`` can be quadratic.
    return "".join(s)
    
def showAttention(input_seq, output_seq, attentions, *, split=638,
                  out_dir="/home/ahtisham/AIMPID/structured/Attentions/"):
    """Save seaborn heatmaps of two diagonal blocks of an attention matrix.

    The matrix is cut at index ``split``:

    * ``attentions[0:split, 0:split]`` is saved as
      ``<out_dir><config.model_type>_<r>_temp2_Attention.png``.
    * ``attentions[split:-1, split:-1]`` is saved as
      ``<out_dir><config.model_type>_<r>_ temp_Attention.png``. The ``-1``
      end excludes the last row and column. For compatibility with the
      previous version, this block is saved again as
      ``<out_dir><config.model_type>_<r2>_Attention.png``.

    ``r`` and ``r2`` are independent ``randrange(100)`` draws, so files from
    different calls can overwrite each other.

    Args:
        input_seq: Unused; kept for call compatibility.
        output_seq: Unused; kept for call compatibility.
        attentions: 2-D torch tensor of attention weights.
        split: Row/column index at which the matrix is divided. The default
            of 638 was previously hard-coded. NOTE(review): it looks tied to
            one specific sequence length; confirm it for other inputs.
        out_dir: Prefix for the output paths (previously hard-coded). It is
            concatenated rather than joined, so it must end with a separator.
    """
    import seaborn as sns

    # detach()/cpu() let .numpy() work for CUDA tensors and tensors that
    # require grad; for a plain CPU tensor they are no-ops.
    matrix = attentions.detach().cpu().numpy()
    print(matrix)

    head = matrix[0:split, 0:split]
    tail = matrix[split:-1, split:-1]

    stem = out_dir + config.model_type + "_"
    tag = str(randrange(100))

    def _render(data, paths):
        # Draw *data* on its own figure, save it to every path in *paths*,
        # then show and release it.
        fig = plt.figure()
        fig.set_size_inches(50, 40)
        try:
            # A dedicated axes per heatmap. Previously the second heatmap was
            # drawn onto the first one's axes.
            ax = fig.add_subplot(111)
            sns.heatmap(data, cmap="YlGnBu", linecolor="red", ax=ax)
            # Save BEFORE plt.show(): with interactive backends show() may
            # leave the figure closed or cleared, so saving afterwards can
            # produce a blank image.
            for path in paths:
                fig.savefig(path, dpi=65)
            plt.show()
        finally:
            # Previously figures were never closed and accumulated across calls.
            plt.close(fig)

    _render(head, [stem + tag + "_" + "temp2_Attention.png"])
    _render(tail, [stem + tag + "_" + " temp_Attention.png",
                   stem + str(randrange(100)) + "_" + "Attention.png"])
    
    
