import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

from buffer import ReplayBuffer

class QRDQN(nn.Module):
    """MLP that predicts ``quantile_dim`` return quantiles for each of ``action_num`` actions.

    Two ReLU hidden layers feed a linear head. The head's output is reshaped to
    (batch, action_num, quantile_dim) and sorted ascending along the quantile
    axis, so the emitted quantile estimates never cross.
    """

    def __init__(self, state_dim, action_num, hidden_dim, quantile_dim):
        super().__init__()
        self.action_num = action_num
        self.quantile_dim = quantile_dim
        # Layer construction order is kept (l1, l2, out) so parameter
        # registration and random initialisation are unchanged.
        self.l1 = nn.Linear(state_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, action_num * quantile_dim)

    def forward(self, s):
        """Map states ``s`` to sorted quantiles of shape (batch, action_num, quantile_dim)."""
        hidden = F.relu(self.l2(F.relu(self.l1(s))))
        quantiles = self.out(hidden).view(-1, self.action_num, self.quantile_dim)
        # Sorting enforces monotone quantiles; only the sorted values are needed.
        return quantiles.sort(dim=-1).values

def soft_update(target, source, tau):
    """Polyak-average ``source``'s parameters into ``target`` in place.

    Each target parameter becomes ``(1 - tau) * target + tau * source``. The
    write goes through ``.data``, so autograd does not record the update.
    Parameters are paired positionally, so both modules must share one layout.
    """
    paired = zip(target.parameters(), source.parameters())
    for tgt, src in paired:
        mixed = tgt.data * (1.0 - tau) + src.data * tau
        tgt.data.copy_(mixed)


class Policy:
    """Risk-averse QR-DQN agent that acts greedily on a lower-tail CVaR estimate.

    The online and target networks (``QRDQN``) predict ``quantile_dim`` sorted
    return quantiles per action. Actions are ranked by the mean of the lowest
    ``cvar_alpha`` fraction of those quantiles. The online network is trained
    with the quantile-regression Huber loss. The target network tracks it
    through ``soft_update`` after every training step.

    Args:
        state_dim: size of the observation vector fed to the networks.
        action_num: number of discrete actions.
        hidden_dim: width of the networks' hidden layers.
        quantile_dim: number of quantiles ``n`` predicted per action.
        lr: Adam learning rate for the online network.
        discount: reward discount factor.
        cvar_alpha: fraction of the lowest quantiles averaged to score an
            action. Values that would select fewer than one quantile use the
            single lowest quantile.
    """

    def __init__(self, state_dim, action_num, hidden_dim, quantile_dim, lr, discount, cvar_alpha):
        self.q_net = QRDQN(state_dim, action_num, hidden_dim, quantile_dim)
        self.q_optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
        self.target_q_net = QRDQN(state_dim, action_num, hidden_dim, quantile_dim)
        # Start the target as an exact copy of the online network. Otherwise the
        # slow soft updates (tau = 0.005) spend hundreds of steps bootstrapping
        # targets from an unrelated random network.
        self.target_q_net.load_state_dict(self.q_net.state_dict())

        self.discount = discount
        self.cvar_alpha = cvar_alpha
        self.action_num = action_num
        self.quantile_dim = quantile_dim

        # NOTE(review): ReplayBuffer lives in buffer.py (not visible here); the
        # second argument presumably is the action dimension -- confirm.
        self.replay_buffer = ReplayBuffer(state_dim, 1)
        # Polyak coefficient passed to soft_update after every train() step.
        self.soft_update_tau = 0.005

    def get_quantiles(self, x):
        """Return the online network's quantiles for states ``x``."""
        return self.q_net(x)

    def _cvar(self, quantiles):
        """Average the lowest ``cvar_alpha`` fraction of quantiles for each action.

        ``quantiles`` has shape (batch, action_num, quantile_dim), sorted
        ascending along the last axis (QRDQN.forward sorts its output).
        Returns a tensor of shape (batch, action_num).
        """
        # int() truncates, so a small cvar_alpha * quantile_dim could give 0.
        # An empty slice's mean is NaN, so keep at least one quantile (and
        # never more than quantile_dim).
        k = min(self.quantile_dim, max(1, int(self.quantile_dim * self.cvar_alpha)))
        return quantiles[..., :k].mean(-1)

    def get_action(self, state):
        """Pick the action with the highest CVaR under the online network.

        For a batch of one state the leading axis is squeezed away and a 0-d
        array is returned. Otherwise one action index per row is returned.
        """
        with torch.no_grad():
            cvar = self._cvar(self.q_net(state)).squeeze(0)
            action = cvar.max(-1)[1]
        # .cpu() makes this work when the network lives on an accelerator.
        return action.cpu().numpy()

    def qr_loss(self, curr_v, target_v):
        """Quantile-regression Huber loss (kappa = 1).

        Args:
            curr_v: (batch, n) predicted quantiles of the taken actions.
            target_v: (batch, n) target quantile samples (expected to carry no
                gradient).

        Returns:
            Scalar loss. For each sample it is summed over target samples and
            averaged over predicted quantiles, then averaged over the batch.
        """
        n = self.quantile_dim
        # Pairwise layout [b, j, i]: j indexes target samples, i predicted quantiles.
        target_v = target_v.reshape(-1, n, 1).expand(-1, n, n)
        curr_v = curr_v.reshape(-1, 1, n).expand(-1, n, n)

        # Quantile midpoints tau_i = (2i + 1) / (2n). They broadcast along the
        # last (predicted-quantile) axis and are built on the inputs' device.
        tau = (torch.arange(n, dtype=curr_v.dtype, device=curr_v.device) + 0.5) / n
        error = target_v - curr_v
        # reduction="none" is essential: each pair (j, i) needs its own Huber
        # term. A scalar mean here reduces the loss to plain Huber regression
        # toward the mean and discards the quantile weighting.
        huber = F.smooth_l1_loss(curr_v, target_v, reduction="none")
        weight = (tau - (error.detach() < 0).to(curr_v.dtype)).abs()
        value_loss = weight * huber
        return value_loss.mean(dim=2).sum(dim=1).mean()

    def get_next_action(self, next_state):
        """Return the CVaR-greedy action of the target network for each next state."""
        with torch.no_grad():
            next_quantiles = self.target_q_net(next_state)
            return self._cvar(next_quantiles).max(1)[1]

    def train(self, bs):
        """Run one gradient step on ``bs`` sampled transitions, then soft-update the target.

        Returns None.
        """
        state, action, reward, next_state, not_done = self.replay_buffer.sample(bs)

        # Quantiles of the taken action: (bs, action_num, n) -> (bs, n).
        # gather needs int64 indices; .long() is a no-op if the buffer already
        # returns them. NOTE(review): assumes one action index per row.
        action_idx = action.long().reshape(-1, 1, 1).expand(-1, 1, self.quantile_dim)
        curr_quantiles = self.q_net(state).gather(1, action_idx).squeeze(1)

        with torch.no_grad():
            # One target forward pass serves both action selection and evaluation.
            next_quantiles = self.target_q_net(next_state)
            next_action = self._cvar(next_quantiles).max(1)[1]
            next_idx = next_action.reshape(-1, 1, 1).expand(-1, 1, self.quantile_dim)
            next_quantiles = next_quantiles.gather(1, next_idx).squeeze(1)
            # reward / not_done broadcast across the n quantile columns.
            target_quantiles = reward + self.discount * not_done * next_quantiles

        td_loss = self.qr_loss(curr_quantiles, target_quantiles)

        self.q_optimizer.zero_grad()
        td_loss.backward()
        self.q_optimizer.step()
        soft_update(self.target_q_net, self.q_net, self.soft_update_tau)


