from utils import tensorFromPair
import torch.optim as optim    
import random
import torch.nn as nn 

from utils import *
import config 

from evaluate import *

# Probability that a single train()/validation() call feeds the ground-truth
# target token (instead of the decoder's own greedy prediction) as the next
# decoder input. The choice is drawn once per call with random.random().
teacher_forcing_ratio: float = 0.3


def _encode(input, encoder, max_len):
    """Run ``encoder`` over every token of ``input`` and collect its outputs.

    Args:
        input: Source index tensor whose first dimension is the sequence
            length (presumably shaped as produced by ``tensorFromPair`` —
            TODO confirm).
        encoder: Encoder module exposing ``initHidden()`` and ``hidden_dims``
            and called as ``encoder(token, hidden) -> (output, hidden)``.
        max_len: Number of rows in the returned output buffer.

    Returns:
        ``(encoder_outputs, encoder_hidden)``. ``encoder_outputs`` is a
        ``(max_len, encoder.hidden_dims * 2)`` tensor. Its first
        ``input.size(0)`` rows hold ``encoder_output[0, 0]`` for each step,
        and the remaining rows stay zero. ``encoder_hidden`` is the hidden
        state after the last step.

    Raises:
        ValueError: If ``input`` has more than ``max_len`` tokens.
    """
    input_length = input.size(0)
    if input_length > max_len:
        # Without this check the buffer write below fails with an opaque IndexError.
        raise ValueError(
            f"input sequence length {input_length} exceeds max_len {max_len}")

    encoder_hidden = encoder.initHidden()
    # The "* 2" width presumably matches a bidirectional encoder — TODO confirm.
    encoder_outputs = torch.zeros(max_len, encoder.hidden_dims * 2, device=device)
    # Encode every token, including input[0]. Starting at 1 would drop the
    # first source token and leave row 0 of the attention memory all zeros.
    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]
    return encoder_outputs, encoder_hidden


def _decode_loss(target, decoder, encoder_outputs, encoder_hidden, criterion, use_attention):
    """Decode starting from SOS and return the summed ``criterion`` loss vs. ``target``.

    Teacher forcing is decided once per call, with probability
    ``teacher_forcing_ratio``:

    * With teacher forcing, ``target[di]`` is fed as the next decoder input.
    * Without it, the decoder's greedy (top-1) prediction is fed back,
      detached from the graph, and decoding stops as soon as it predicts
      ``EOS_token``.

    Args:
        target: Target index tensor whose first dimension is the sequence length.
        decoder: Decoder module. With ``use_attention`` it is called as
            ``decoder(input, hidden, encoder_outputs)`` and returns
            ``(output, hidden, attention)``. Otherwise it is called as
            ``decoder(input, hidden)`` and returns ``(output, hidden)``.
        encoder_outputs: Attention memory from ``_encode``.
        encoder_hidden: Final encoder hidden state, used as the initial
            decoder hidden state.
        criterion: Loss function applied to ``(decoder_output, target[di])``.
        use_attention: Selects the decoder calling convention described above.

    Returns:
        The loss tensor summed over the decoded steps.

    Raises:
        ValueError: If ``target`` is empty (there would be no loss tensor).
    """
    output_length = target.size(0)
    if output_length == 0:
        raise ValueError("target sequence is empty")

    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden = encoder_hidden.to(device)
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    loss = 0
    for di in range(output_length):
        if use_attention:
            decoder_output, decoder_hidden, _attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
        else:
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
        loss += criterion(decoder_output, target[di])

        if use_teacher_forcing:
            # Teacher forcing: feed the ground-truth token as the next input.
            decoder_input = target[di]
        else:
            _, topi = decoder_output.topk(1)
            # Detach so gradients do not flow back through the chosen input.
            decoder_input = topi.squeeze().detach()
            if decoder_input.item() == EOS_token:
                break
    return loss


def train(input, target, encoder, decoder, encoder_optim, decoder_optim, criterion, model_type, max_len, use_attention=True):
    """Run one optimisation step of ``encoder``/``decoder`` on a single pair.

    Args:
        input: Source index tensor.
        target: Target index tensor.
        encoder: Encoder module (see ``_encode``).
        decoder: Decoder module (see ``_decode_loss``).
        encoder_optim: Optimizer for the encoder parameters.
        decoder_optim: Optimizer for the decoder parameters.
        criterion: Per-step loss function.
        model_type: Not used by this function; kept for interface compatibility.
        max_len: Size of the encoder-output buffer. Longer inputs raise
            ``ValueError``.
        use_attention: Decoder calling convention. Defaults to True, which
            matches the previously hard-coded behaviour.

    Returns:
        The summed loss divided by ``target.size(0)``, as a Python float. The
        denominator is the full target length even when free-running
        decoding stops early at EOS.
    """
    encoder_optim.zero_grad()
    decoder_optim.zero_grad()

    encoder_outputs, encoder_hidden = _encode(input, encoder, max_len)
    loss = _decode_loss(target, decoder, encoder_outputs, encoder_hidden,
                        criterion, use_attention)
    loss.backward()

    encoder_optim.step()
    decoder_optim.step()

    return loss.item() / target.size(0)


def validation(input, target, encoder, decoder, criterion, max_len, use_attention=True):
    """Compute the per-target-token loss of ``encoder``/``decoder`` on one pair.

    This function does no backward pass and no optimizer step. Like ``train``,
    it applies teacher forcing at random with probability
    ``teacher_forcing_ratio``, so the reported value mixes teacher-forced and
    free-running decoding. It does not change the train/eval mode of the
    modules. NOTE(review): any dropout stays in whatever mode the caller left
    it in — confirm this is intended.

    Args:
        input: Source index tensor.
        target: Target index tensor.
        encoder: Encoder module (see ``_encode``).
        decoder: Decoder module (see ``_decode_loss``).
        criterion: Per-step loss function.
        max_len: Size of the encoder-output buffer.
        use_attention: Decoder calling convention. Defaults to True.

    Returns:
        The summed loss divided by ``target.size(0)``, as a Python float.
    """
    # Guard locally so no autograd graph is built even if a caller forgets
    # torch.no_grad(). Nesting inside the caller's no_grad is harmless.
    with torch.no_grad():
        encoder_outputs, encoder_hidden = _encode(input, encoder, max_len)
        loss = _decode_loss(target, decoder, encoder_outputs, encoder_hidden,
                            criterion, use_attention)
    return loss.item() / target.size(0)


def trainIter(pairs,encoder, decoder, iters, print_, plot, input_lang, output_lang, learning_rate, num_pairs, val_pairs, bleu_plot, model_type):
    """Train ``encoder``/``decoder`` for ``iters`` single-pair steps with periodic reporting.

    ``num_pairs`` training pairs are sampled with replacement from ``pairs`` up
    front. The loop then cycles through them in order, one pair per step.
    After each training step, one random validation pair is scored with
    ``validation`` under ``torch.no_grad()``.

    Periodic work:
      * Every ``print_`` steps, print the mean train and validation loss.
      * Every ``plot`` steps:
          - append the mean losses and call ``showPlot``;
          - save ``<model_type>encoder.pt`` and ``<model_type>decoder.pt``.
      * Every ``bleu_plot`` steps, score ``val_pairs`` with
        ``evaluate_pairs_bleu`` and call ``plot_bleu_scores``.
      * The state dicts are also saved once after the final step.

    Args:
        pairs: Raw training pairs accepted by ``tensorFromPair``.
        encoder: Encoder module, trained in place.
        decoder: Decoder module, trained in place.
        iters: Number of training steps.
        print_: Print interval, in steps.
        plot: Plot/checkpoint interval, in steps.
        input_lang: Source language object passed to ``tensorFromPair``.
        output_lang: Target language object passed to ``tensorFromPair``.
        learning_rate: Adam learning rate for both optimizers.
        num_pairs: How many training pairs to pre-sample.
        val_pairs: Raw validation pairs.
        bleu_plot: BLEU evaluation interval, in steps.
        model_type: Prefix for checkpoint file names; also passed to the
            plotting helpers and to ``train``.

    Raises:
        ValueError: If ``iters`` > 0 but ``num_pairs`` < 1 (nothing to train on).
    """
    if iters > 0 and num_pairs < 1:
        raise ValueError("num_pairs must be at least 1 to train")

    plot_loss = []
    plot_val_loss = []
    print_loss_tot = 0
    plot_loss_tot = 0
    val_loss_total_print = 0
    val_loss_total_plot = 0

    encoder_optim = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optim = optim.Adam(decoder.parameters(), lr=learning_rate)

    # One series per score returned by evaluate_pairs_bleu. Eleven series are
    # pre-allocated; more are added below if the helper returns more scores.
    tot_bscores = [[] for _ in range(11)]

    training_pairs = [tensorFromPair(input_lang, output_lang, random.choice(pairs)) for _ in range(num_pairs)]
    validation_pairs = [tensorFromPair(input_lang, output_lang, pair) for pair in val_pairs]
    criterion = nn.NLLLoss()

    def _save_checkpoint():
        # Persist both state dicts under the model_type-prefixed file names.
        torch.save(encoder.state_dict(), model_type + "encoder.pt")
        torch.save(decoder.state_dict(), model_type + "decoder.pt")

    for step in range(1, iters + 1):
        # Cycle through every pre-sampled pair. A fixed index
        # (num_pairs - 1) would train on the same pair every step.
        train_pair = training_pairs[(step - 1) % num_pairs]
        input_seq = train_pair[0].to(device)
        target_seq = train_pair[1].to(device)

        t_loss = train(input_seq, target_seq, encoder, decoder, encoder_optim, decoder_optim, criterion, model_type, config.MAX_LENGTH)
        print_loss_tot += t_loss
        plot_loss_tot += t_loss

        val_pair = random.choice(validation_pairs)
        val_input = val_pair[0].to(device)
        val_target = val_pair[1].to(device)
        with torch.no_grad():
            val_loss = validation(val_input, val_target, encoder, decoder, criterion, config.MAX_LENGTH)
        val_loss_total_print += val_loss
        val_loss_total_plot += val_loss

        if step % print_ == 0:
            print_loss_avg = print_loss_tot / print_
            print_loss_tot = 0

            val_loss_avg = val_loss_total_print / print_
            val_loss_total_print = 0

            print("Epoch : ", step, " ", step, "/", iters, "  Train loss: ", print_loss_avg, "  Val. Loss : ", val_loss_avg)

        if step % plot == 0:
            plot_loss.append(plot_loss_tot / plot)
            plot_loss_tot = 0

            plot_val_loss.append(val_loss_total_plot / plot)
            val_loss_total_plot = 0

            showPlot(plot_loss, plot_val_loss, model_type)
            # Checkpoint at the plot cadence. Serialising both state dicts on
            # every single-pair step would dominate the run time.
            _save_checkpoint()

        if step % bleu_plot == 0:
            bleu_scores = evaluate_pairs_bleu(val_pairs, encoder, decoder, input_lang, output_lang)
            while len(tot_bscores) < len(bleu_scores):
                tot_bscores.append([])
            for series, score in zip(tot_bscores, bleu_scores):
                series.append(score)
            plot_bleu_scores(tot_bscores, model_type)

        # NOTE(review): this tests the total `iters`, not the current `step`,
        # so it runs on every step of long runs and never for short ones.
        # Confirm that this is the intent.
        if iters >= 800:
            evaluateRandomlyWithFileIO(encoder, decoder, pairs, input_lang, output_lang, model_type, n=1)

    # Always leave the final weights on disk, even if iters is not a multiple of plot.
    if iters > 0:
        _save_checkpoint()