import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
    
class ActorNet(nn.Module):
    """MLP with two ReLU hidden layers that outputs action probabilities.

    The output layer's activations are divided by ``temp`` and then passed
    through a softmax over the last dimension.
    """

    def __init__(self, state_size, n_action, hidden_size, temp):
        """Build the layers.

        Args:
            state_size: Number of input features.
            n_action: Number of outputs (one probability per action).
            hidden_size: Width of both hidden layers.
            temp: Divisor applied to the output activations before softmax.
        """
        super().__init__()
        self.layer1 = nn.Linear(state_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, n_action)
        self.temp = temp

    def forward(self, x):
        """Return softmax probabilities over the last dimension of the output."""
        hidden = x
        for layer in (self.layer1, self.layer2):
            hidden = torch.relu(layer(hidden))

        scaled_logits = self.output(hidden) / self.temp
        return torch.softmax(scaled_logits, dim=-1)

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Policy():
    """Discrete-action agent trained with a CVaR policy gradient.

    Episodes are accumulated with ``put_data``; ``GCVaR`` then performs one
    optimiser step using only the lowest-return fraction ``alpha`` of the
    first ``n_episodes`` buffered episodes and empties the buffers.
    """

    def __init__(self, state_size, n_action, hidden_size, pi_temp, alpha, actor_lr, discount, n_episodes):
        """Create the actor network, its Adam optimiser and empty buffers.

        Args:
            state_size: Number of features in a state.
            n_action: Number of discrete actions.
            hidden_size: Hidden width of the actor network.
            pi_temp: Softmax temperature passed to the actor network.
            alpha: Quantile level; fraction of episodes used per update.
            actor_lr: Learning rate of the Adam optimiser.
            discount: Discount factor used when computing episode returns.
            n_episodes: Number of buffered episodes consumed by ``GCVaR``.
        """
        self.n_action = n_action

        # Remember the device on the instance so every method builds its
        # tensors on the device the network was moved to.
        self.device = device
        self.actor_net = ActorNet(state_size, n_action, hidden_size, pi_temp).to(self.device)
        self.actor_optimizer = optim.Adam(self.actor_net.parameters(), lr=actor_lr)

        self.alpha = alpha
        self.discount = discount
        self.n_episodes = n_episodes
        self.name = 'GCVaR'

        # buffer
        self.state_buf, self.action_buf, self.reward_buf, self.done_buf = [], [], [], []

    def _action_probs(self, state):
        """Return the actor's action probabilities for ``state`` as a NumPy array."""
        with torch.no_grad():
            input_state = torch.FloatTensor(state).to(self.device)
            return self.actor_net(input_state).cpu().numpy()

    def select_action(self, state):
        """Sample an action from the policy.

        Returns:
            Tuple ``(action, action_prob)``: the sampled action index and the
            probability the policy assigned to it.
        """
        action_probs = self._action_probs(state)
        action = np.random.choice(np.arange(self.n_action), p=action_probs)
        return action, action_probs[action]

    def eval_action(self, state):
        """Pick the most probable action (greedy evaluation).

        Returns:
            Tuple ``(action, action_prob)``: the arg-max action index and its
            probability.
        """
        action_probs = self._action_probs(state)
        action = np.argmax(action_probs)
        return action, action_probs[action]

    def put_data(self, state_lst, action_lst, reward_lst, done_lst):
        """Append one episode to the buffers.

        ``GCVaR`` discards the last entry of ``state_lst`` before pairing
        states with actions, so ``state_lst`` is expected to hold one more
        entry than ``action_lst``.
        """
        self.state_buf.append(state_lst)
        self.action_buf.append(action_lst)
        self.reward_buf.append(reward_lst)
        self.done_buf.append(done_lst)

    def get_quantile(self, ret_lst):
        """Return the ``alpha``-quantile of ``ret_lst`` (``np.quantile``)."""
        return np.quantile(ret_lst, self.alpha)

    def _discounted_return(self, reward_lst):
        """Return the discounted sum of ``reward_lst`` (0.0 when empty)."""
        ret = 0.
        # Fold from the last reward backwards: G_t = r_t + discount * G_{t+1}.
        for reward in reversed(reward_lst):
            ret = reward + self.discount * ret
        return ret

    def _tail_size(self):
        """Return how many of the worst episodes enter the CVaR update.

        This is ``floor(n_episodes * alpha)``, with a small tolerance so that
        products such as ``100 * 0.29`` (28.999...96 in floating point) are
        not truncated one short, and never fewer than one episode so the
        update always has a gradient term.
        """
        tail = int(self.n_episodes * self.alpha + 1e-9)
        return min(self.n_episodes, max(1, tail))

    def GCVaR(self):
        ''' CVaR policy gradient

        Uses the first ``n_episodes`` buffered episodes: computes their
        discounted returns, keeps the lowest-return tail, weights each kept
        episode's summed log-probability by ``return - alpha_quantile``,
        takes one optimiser step and clears all buffers.
        '''
        # calculate return of each traj
        ret_lst = np.array([
            self._discounted_return(self.reward_buf[i])
            for i in range(self.n_episodes)
        ])

        # episode indices ordered from lowest to highest return
        sort_idx = np.argsort(ret_lst)
        quantile_alpha = self.get_quantile(ret_lst)

        # calculate policy gradient
        Rtau_mq_sum_logpi_lst = []
        for ep in sort_idx[:self._tail_size()]:
            # create tensor
            state_t = torch.FloatTensor(np.array(self.state_buf[ep])).to(self.device)
            action_t = torch.LongTensor(self.action_buf[ep]).to(self.device).view(-1, 1)
            # we do not need the last (state) for policy gradient
            state_t = state_t[:-1, :]

            # compute sum logpi
            action_prob = self.actor_net(state_t).gather(1, action_t)  # [traj_len, 1]
            # Guard against log(0): a softmax that underflows to exactly 0
            # would otherwise give -inf and NaN gradients.
            tiny = torch.finfo(action_prob.dtype).tiny
            sum_logpi = torch.log(action_prob.clamp(min=tiny)).sum(dim=0)  # [1]

            R_tau = ret_lst[ep]
            assert R_tau <= quantile_alpha
            # CVaR gradient; the weight is converted to a plain float so the
            # product is an ordinary tensor-scalar multiplication.
            Rtau_mq_sum_logpi_lst.append(float(R_tau - quantile_alpha) * sum_logpi)

        # update policy
        cvar_grad = torch.cat(Rtau_mq_sum_logpi_lst)  # one entry per tail episode
        actor_loss = - cvar_grad.mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        ''' clean buffer '''
        self.state_buf, self.action_buf, self.reward_buf, self.done_buf = [], [], [], []
        


