from __future__ import unicode_literals, print_function, division


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim  

import config

# Device specification read from the project-level ``config`` module and
# exposed as a module-level name.
device = config.device



class Encoder(nn.Module):
    """Recurrent encoder that consumes one token index per call.

    The token is embedded, passed through dropout and fed through a single
    time step of the selected recurrent cell.

    Args:
        embedding_dims: Number of entries in the embedding table (passed as
            ``num_embeddings`` to ``nn.Embedding``).
        hidden_dims: Size of the embedding vectors and of the recurrent
            hidden state.
        model_type: ``'GRU'`` (bidirectional, ``num_layers`` deep), ``'LSTM'``
            or ``'RNN'`` (both single-layer, unidirectional).
        num_layers: Number of stacked layers; only used by the GRU variant.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """

    def __init__(self, embedding_dims, hidden_dims, model_type, num_layers):
        super().__init__()
        self.hidden_dims = hidden_dims
        self.model_type = model_type
        self.num_layers = num_layers

        self.embedding = nn.Embedding(embedding_dims, hidden_dims)

        if self.model_type == 'GRU':
            self.rnn = nn.GRU(hidden_dims, hidden_dims, num_layers, bidirectional=True)
        elif self.model_type == 'LSTM':
            self.rnn = nn.LSTM(hidden_dims, hidden_dims)
        elif self.model_type == 'RNN':
            self.rnn = nn.RNN(hidden_dims, hidden_dims)
        else:
            # Fail at construction time instead of leaving ``self.rnn`` unset.
            raise ValueError(
                "model_type must be 'GRU', 'LSTM' or 'RNN', got {!r}".format(model_type)
            )
        self.dropout = nn.Dropout(0.1)

    def forward(self, input, hidden):
        """Run one encoding step.

        Args:
            input: Tensor holding a single token index.
            hidden: Recurrent state accepted by the underlying cell: a tensor
                for GRU/RNN, an ``(h, c)`` tuple for LSTM.

        Returns:
            ``(output, hidden)`` for GRU/RNN, or ``(output, hidden, c_states)``
            for LSTM.
        """
        # Reshape to (seq_len=1, batch=1, hidden_dims) as expected by the cell.
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        if self.model_type == 'LSTM':
            output, (hidden, c_states) = self.rnn(embedded, hidden)
            return output, hidden, c_states

        output, hidden = self.rnn(embedded, hidden)
        return output, hidden

    def _state_shape(self):
        """Return the (layers * directions, batch=1, hidden) state shape."""
        if self.model_type == 'GRU':
            # Only the GRU is built bidirectional and ``num_layers`` deep.
            return (self.num_layers * 2, 1, self.hidden_dims)
        # LSTM / RNN are constructed single-layer and unidirectional.
        return (1, 1, self.hidden_dims)

    def initHidden(self):
        """Create an all-zero initial state matching the configured cell.

        The state is allocated on the same device as the module's parameters.

        Returns:
            A tensor for GRU/RNN, or an ``(h0, c0)`` tuple for LSTM.
        """
        shape = self._state_shape()
        param_device = self.embedding.weight.device
        h0 = torch.zeros(shape, device=param_device)
        if self.model_type == 'LSTM':
            return h0, torch.zeros(shape, device=param_device)
        return h0



class Decoder(nn.Module):
    """Plain (attention-free) recurrent decoder operating one token at a time.

    Args:
        out_embedding_dims: Number of entries in the output embedding table;
            also the size of the log-probability vector that is returned.
        hidden_dims: Size of the embedding vectors and of the recurrent
            hidden state.
        model_type: ``'GRU'``, ``'LSTM'`` or ``'RNN'``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """

    def __init__(self, out_embedding_dims, hidden_dims, model_type):
        super().__init__()
        self.hidden_dims = hidden_dims
        self.out_embedding_dims = out_embedding_dims
        self.model_type = model_type

        self.embedding = nn.Embedding(self.out_embedding_dims, self.hidden_dims)

        # All variants are single-layer. The ``dropout`` argument of the torch
        # recurrent modules only acts between stacked layers, so it is omitted
        # here: on one layer it has no effect besides emitting a warning.
        if self.model_type == 'GRU':
            self.rnn = nn.GRU(self.hidden_dims, self.hidden_dims)
        elif self.model_type == 'LSTM':
            self.rnn = nn.LSTM(self.hidden_dims, self.hidden_dims)
        elif self.model_type == 'RNN':
            self.rnn = nn.RNN(self.hidden_dims, self.hidden_dims)
        else:
            raise ValueError(
                "model_type must be 'GRU', 'LSTM' or 'RNN', got {!r}".format(model_type)
            )

        self.dense = nn.Linear(self.hidden_dims, self.out_embedding_dims)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run one decoding step.

        Args:
            input: Tensor holding a single token index.
            hidden: Recurrent state accepted by the underlying cell: a tensor
                for GRU/RNN, an ``(h, c)`` tuple for LSTM.

        Returns:
            ``(log_probs, hidden)`` where ``log_probs`` has shape
            ``(1, out_embedding_dims)`` and ``hidden`` is the updated state in
            the same form the cell produced it (a tuple for LSTM), so it can
            be fed straight back into the next call.
        """
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = F.relu(embedded)

        # One code path serves every cell type: the state object returned by
        # the cell is passed through unchanged.
        output, hidden = self.rnn(embedded, hidden)
        log_probs = self.softmax(self.dense(output)[0])
        return log_probs, hidden
    
    
class AttentionDecoder(nn.Module):
    """Recurrent decoder with learned attention over encoder outputs.

    Args:
        hidden_dims: Size of the embedding vectors and of the recurrent
            hidden state.
        embedding_dims: Number of entries in the embedding table; also the
            size of the log-probability vector that is returned.
        dropout_val: Dropout probability applied to the embedded input.
        max_length: Number of attention slots, i.e. the number of encoder
            output rows the attention layer scores.
        model_type: ``'GRU'`` (bidirectional, ``num_layers`` deep), ``'RNN'``
            or ``'LSTM'`` (both single-layer, unidirectional).
        num_layers: Number of stacked layers; only used by the GRU variant.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """

    def __init__(self, hidden_dims, embedding_dims,   dropout_val, max_length, model_type, num_layers):
        super().__init__()
        self.hidden_dims = hidden_dims
        self.embedding_dims = embedding_dims
        self.dropout_val = dropout_val
        self.max_length = max_length
        self.model_type = model_type
        self.num_layers = num_layers

        if self.model_type not in ('GRU', 'RNN', 'LSTM'):
            raise ValueError(
                "model_type must be 'GRU', 'LSTM' or 'RNN', got {!r}".format(model_type)
            )

        # Only the GRU variant is bidirectional; layer widths that depend on
        # the recurrent output size must follow the number of directions.
        num_directions = 2 if self.model_type == 'GRU' else 1

        self.embedding = nn.Embedding(self.embedding_dims, self.hidden_dims)
        self.dropout = nn.Dropout(self.dropout_val)

        # Scores are computed from [embedded token ; first hidden-state slice].
        self.attention = nn.Linear(self.hidden_dims * 2, max_length)
        # Input is [embedded token ; attention context]. The context width is
        # assumed to be hidden_dims * num_directions, i.e. encoder outputs
        # produced by an encoder of the same model_type -- confirm at caller.
        self.attention_combine = nn.Linear(
            self.hidden_dims * (1 + num_directions), self.hidden_dims
        )

        if self.model_type == 'GRU':
            self.rnn = nn.GRU(self.hidden_dims, self.hidden_dims, num_layers, bidirectional=True)
        elif self.model_type == 'RNN':
            # ``dropout`` is omitted: it only acts between stacked layers and
            # is a no-op (plus a warning) on a single layer.
            self.rnn = nn.RNN(self.hidden_dims, self.hidden_dims)
        else:
            self.rnn = nn.LSTM(self.hidden_dims, self.hidden_dims)

        self.output = nn.Linear(self.hidden_dims * num_directions, self.embedding_dims)

    def forward(self, input, hidden, encoder_outputs):
        """Run one attention-decoding step.

        Args:
            input: Tensor holding a single token index.
            hidden: Recurrent state accepted by the underlying cell: a tensor
                for GRU/RNN, an ``(h, c)`` tuple for LSTM.
            encoder_outputs: 2-D tensor with ``max_length`` rows; it is
                batch-multiplied with the ``(1, max_length)`` attention
                weights.

        Returns:
            ``(log_probs, hidden, attention_weights)`` with ``log_probs`` of
            shape ``(1, embedding_dims)`` and ``attention_weights`` of shape
            ``(1, max_length)``.
        """
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        # For an LSTM the state is (h, c); attention is conditioned on h only.
        h_state = hidden[0] if isinstance(hidden, (tuple, list)) else hidden

        query = torch.cat((embedded[0], h_state[0]), 1)
        attention_weights = F.softmax(self.attention(query), dim=1)
        attention_applied = torch.bmm(
            attention_weights.unsqueeze(0), encoder_outputs.unsqueeze(0)
        )

        output = torch.cat((embedded[0], attention_applied[0]), 1)
        output = self.attention_combine(output).unsqueeze(0)

        output = F.relu(output)
        output, hidden = self.rnn(output, hidden)
        output = self.output(output[0])
        output = F.log_softmax(output, dim=1)
        return output, hidden, attention_weights

    def initHidden(self):
        """Create an all-zero initial state matching the configured cell.

        The state is allocated on the same device as the module's parameters.

        Returns:
            A tensor for GRU/RNN, or an ``(h0, c0)`` tuple for LSTM.
        """
        if self.model_type == 'GRU':
            shape = (self.num_layers * 2, 1, self.hidden_dims)
        else:
            # RNN / LSTM are constructed single-layer and unidirectional.
            shape = (1, 1, self.hidden_dims)

        param_device = self.embedding.weight.device
        h0 = torch.zeros(shape, device=param_device)
        if self.model_type == 'LSTM':
            return h0, torch.zeros(shape, device=param_device)
        return h0