from __future__ import unicode_literals, print_function, division
import torch
import torch.nn as nn   
import torch.nn.functional as F 

import re 
import random 
import matplotlib.pyplot as plt

from dataset import prepareData
from models import Encoder, Decoder, AttentionDecoder
from utils import indexesFromSequence, tensorFromSequence, tensorFromPair
from train import *
import config

# Maximum sequence length handed to AttentionDecoder below.
# NOTE(review): presumably must match the value used when the saved
# checkpoints were trained — confirm before changing.
MAX_LENGTH = 339

device = config.device
print("Device : ", device)

# Special token indices.
# NOTE(review): presumably these must agree with the vocabularies built by
# dataset.prepareData — confirm.
SOS_token, EOS_token = 0, 1


# Build the input/output vocabularies and the raw list of sequence pairs.
input_lang, output_lang, pairs = prepareData('CovNormal', 'CovMutated', False)

# Uncomment to spot-check a couple of random training pairs
# for i in range(2):
#     print(random.choice(pairs))


# ---- Model hyperparameters ----
hidden_size = 256

# Encoder over the input-side trigram vocabulary.
encoder = Encoder(
    input_lang.num_trigrams,
    hidden_size,
    config.model_type,
    config.num_layers,
).to(device)
print(encoder)

# Alternative: a plain decoder without attention (check Decoder's signature,
# e.g. whether it also needs config.num_layers, before switching):
#     decoder = Decoder(output_lang.num_trigrams, hidden_size, config.model_type).to(device)

# Decoder with attention over the output-side trigram vocabulary.
# NOTE(review): 0.1 is presumably a dropout probability and MAX_LENGTH the
# maximum attended length — confirm against models.AttentionDecoder.
decoder = AttentionDecoder(
    hidden_size,
    output_lang.num_trigrams,
    0.1,
    MAX_LENGTH,
    config.model_type,
    config.num_layers,
).to(device)
print(decoder)

# Remove malformed pairs: keep only entries with exactly two elements
# (input sequence, target sequence). The kept entries are the same objects
# as in `pairs`, not copies. (The previous index loop also re-assigned each
# element to itself, which was a no-op for lists and a TypeError for tuples.)
final_pairs = [pair for pair in pairs if len(pair) == 2]

# Randomize the order before the fixed train/validation slicing below.
# NOTE(review): no random seed is set in this file. Unless one is set
# elsewhere (e.g. in config or train), the split differs on every run, and the
# pretrained checkpoints loaded later may be evaluated on pairs they were
# trained on. Both shuffles are kept so the RNG is consumed exactly as before.
random.shuffle(final_pairs)
random.shuffle(final_pairs)

# ---- Settings for the (currently commented-out) trainIter call below ----
print_ = 1
epochs = 2200
lr = 0.0001
plot = 20
bleu_plot = 100

# Fixed split of the shuffled pairs: indices [0, 150) train, [150, 223)
# validation. An earlier variant instead held out the last 1000 pairs for
# validation.
# NOTE(review): if fewer than 223 pairs survive filtering, the validation set
# silently shrinks (slicing never raises).
_TRAIN_END, _VAL_END = 150, 223
val_pairs = final_pairs[_TRAIN_END:_VAL_END]
final_pairs = final_pairs[:_TRAIN_END]

# NOTE(review): forwarded to trainIter as num_pairs. The meaning of the factor
# of 2 is not visible here — confirm in train.trainIter.
num_pairs = 2 * len(final_pairs)

# Train Loop        
#trainIter(final_pairs, encoder, decoder, epochs, print_, plot, input_lang, output_lang, lr, num_pairs, val_pairs, bleu_plot, config.model_type)

# Load the saved pretrained weights. Checkpoint file names are the model type
# followed by 'encoder.pt' / 'decoder.pt'.
# map_location remaps tensors saved on another device (e.g. CUDA) onto the
# current `device`. Without it, GPU-trained checkpoints fail to load on a
# CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
encoder.load_state_dict(
    torch.load(str(config.model_type) + 'encoder.pt', map_location=device)
)
decoder.load_state_dict(
    torch.load(str(config.model_type) + 'decoder.pt', map_location=device)
)

# Evaluation: switch both models to inference mode (e.g. dropout disabled).
encoder.eval()
decoder.eval()



# Evaluate the models on randomly chosen sequence pairs
# (encoder, decoder, pairs, input_lang, output_lang, model_type, n):
#evaluateRandomlyWithFileIO(encoder, decoder,val_pairs, input_lang, output_lang,config.model_type, n=30)

